kernels-bot committed on
Commit
12ce283
·
verified ·
1 Parent(s): 697dd7e

Uploaded using `kernel-builder` (batch 23/32).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch210-cxx11-cu128-x86_64-linux/_C.py +194 -0
  3. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +152 -19
  4. build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so +3 -0
  5. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  6. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/comm/barrier.cuh +83 -0
  7. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/compile.cuh +18 -0
  8. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/cute_tie.cuh +2 -0
  9. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/exception.cuh +43 -0
  10. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/math.cuh +153 -0
  11. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/tma_copy.cuh +92 -0
  12. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/types.cuh +43 -0
  13. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/utils.cuh +16 -149
  14. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh +137 -0
  15. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh +144 -0
  16. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/transform.cuh +24 -0
  17. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh +150 -195
  18. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +31 -25
  19. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh +457 -0
  20. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh +510 -0
  21. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh +514 -0
  22. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh +1380 -0
  23. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +9 -5
  24. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +125 -126
  25. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +205 -164
  26. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +35 -30
  27. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh +47 -40
  28. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +37 -28
  29. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +79 -82
  30. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +55 -46
  31. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +64 -63
  32. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +76 -155
  33. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +36 -29
  34. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh +30 -23
  35. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh +34 -21
  36. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh +260 -0
  37. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh +41 -0
  38. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm100.cuh +151 -0
  39. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm90.cuh +293 -0
  40. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh +251 -0
  41. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh +168 -0
  42. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tma.cuh +112 -0
  43. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/utils.cuh +53 -0
  44. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh +25 -0
  45. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh +300 -0
  46. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh +221 -0
  47. build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh +239 -0
  48. build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp +0 -904
  49. build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp +0 -270
  50. build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp +0 -279
.gitattributes CHANGED
@@ -46,3 +46,4 @@ build/torch211-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=
46
  build/torch211-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
47
  build/torch211-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
48
  build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
46
  build/torch211-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
47
  build/torch211-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
48
  build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_8546a43.abi3.so filter=lfs diff=lfs merge=lfs -text
49
+ build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch210-cxx11-cu128-x86_64-linux/_C.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ._ops import ops
4
+
5
+
6
+ def set_num_sms(num_sms: int):
7
+ ops.set_num_sms(num_sms)
8
+
9
+
10
+ def get_num_sms() -> int:
11
+ return ops.get_num_sms()
12
+
13
+
14
+ def set_tc_util(tc_util: int):
15
+ ops.set_tc_util(tc_util)
16
+
17
+
18
+ def get_tc_util() -> int:
19
+ return ops.get_tc_util()
20
+
21
+
22
+ def set_ignore_compile_dims(value: bool):
23
+ ops.set_ignore_compile_dims(value)
24
+
25
+
26
+ def set_block_size_multiple_of(value):
27
+ if isinstance(value, tuple):
28
+ block_m, block_n = value
29
+ else:
30
+ block_m = block_n = value
31
+ ops.set_block_size_multiple_of(block_m, block_n)
32
+
33
+
34
+ def set_pdl(enable_pdl: bool):
35
+ ops.set_pdl(enable_pdl)
36
+
37
+
38
+ def get_pdl() -> bool:
39
+ return ops.get_pdl()
40
+
41
+
42
+ def set_mk_alignment_for_contiguous_layout(value: int):
43
+ ops.set_mk_alignment_for_contiguous_layout(value)
44
+
45
+
46
+ def get_mk_alignment_for_contiguous_layout() -> int:
47
+ return ops.get_mk_alignment_for_contiguous_layout()
48
+
49
+
50
+ def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int:
51
+ return ops.get_theoretical_mk_alignment_for_contiguous_layout(
52
+ 0 if expected_m is None else expected_m,
53
+ expected_m is not None,
54
+ )
55
+
56
+
57
+ def get_tma_aligned_size(mn: int, element_size: int) -> int:
58
+ return ops.get_tma_aligned_size(mn, element_size).item()
59
+
60
+
61
+ def get_mn_major_tma_aligned_tensor(sf):
62
+ return ops.get_mn_major_tma_aligned_tensor(sf)
63
+
64
+
65
+ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
66
+ return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
67
+
68
+
69
+ def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
70
+ sf, ks_tensor, ks, gran_k
71
+ ):
72
+ ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu")
73
+ return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
74
+ sf, ks_tensor, ks_int, gran_k
75
+ )
76
+
77
+
78
+ def transform_sf_into_required_layout(
79
+ sf,
80
+ mn,
81
+ k,
82
+ recipe,
83
+ num_groups=None,
84
+ is_sfa=None,
85
+ disable_ue8m0_cast=False,
86
+ ):
87
+ if len(recipe) == 3:
88
+ r0, r1, r2 = recipe
89
+ recipe_len = 3
90
+ elif len(recipe) == 2:
91
+ r0, r1 = recipe
92
+ r2 = 0
93
+ recipe_len = 2
94
+ else:
95
+ raise ValueError("recipe must have length 2 or 3")
96
+
97
+ return ops.transform_sf_into_required_layout(
98
+ sf,
99
+ mn,
100
+ k,
101
+ r0,
102
+ r1,
103
+ r2,
104
+ recipe_len,
105
+ 0 if num_groups is None else num_groups,
106
+ num_groups is not None,
107
+ False if is_sfa is None else is_sfa,
108
+ is_sfa is not None,
109
+ disable_ue8m0_cast,
110
+ )
111
+
112
+
113
+ def get_token_alignment_for_mega_moe() -> int:
114
+ return ops.get_token_alignment_for_mega_moe()
115
+
116
+
117
+ def get_symm_buffer_size_for_mega_moe(
118
+ num_ranks,
119
+ num_experts,
120
+ num_max_tokens_per_rank,
121
+ num_topk,
122
+ hidden,
123
+ intermediate_hidden,
124
+ use_fp8_dispatch=True,
125
+ activation="swiglu",
126
+ ):
127
+ num_bytes = ops.get_symm_buffer_size_for_mega_moe(
128
+ num_ranks,
129
+ num_experts,
130
+ num_max_tokens_per_rank,
131
+ num_topk,
132
+ hidden,
133
+ intermediate_hidden,
134
+ use_fp8_dispatch,
135
+ activation,
136
+ )
137
+
138
+ def slice_input_buffers(buffer):
139
+ return tuple(
140
+ ops.get_symm_buffer_views_for_mega_moe(
141
+ buffer,
142
+ num_ranks,
143
+ num_experts,
144
+ num_max_tokens_per_rank,
145
+ num_topk,
146
+ hidden,
147
+ intermediate_hidden,
148
+ use_fp8_dispatch,
149
+ activation,
150
+ )
151
+ )
152
+
153
+ return num_bytes, slice_input_buffers
154
+
155
+
156
+ def fp8_fp4_mega_moe(
157
+ y,
158
+ l1_weights,
159
+ l2_weights,
160
+ cumulative_local_expert_recv_stats,
161
+ sym_buffer,
162
+ sym_buffer_ptrs,
163
+ rank_idx,
164
+ num_max_tokens_per_rank,
165
+ num_experts,
166
+ num_topk,
167
+ recipe,
168
+ activation,
169
+ activation_clamp,
170
+ fast_math,
171
+ ):
172
+ l1_weights_data, l1_weights_sf = l1_weights
173
+ l2_weights_data, l2_weights_sf = l2_weights
174
+ r0, r1, r2 = recipe
175
+ ops.fp8_fp4_mega_moe(
176
+ y,
177
+ l1_weights_data,
178
+ l1_weights_sf,
179
+ l2_weights_data,
180
+ l2_weights_sf,
181
+ cumulative_local_expert_recv_stats,
182
+ sym_buffer,
183
+ sym_buffer_ptrs,
184
+ rank_idx,
185
+ num_max_tokens_per_rank,
186
+ num_experts,
187
+ num_topk,
188
+ r0,
189
+ r1,
190
+ r2,
191
+ activation,
192
+ activation_clamp,
193
+ fast_math,
194
+ )
build/torch210-cxx11-cu128-x86_64-linux/__init__.py CHANGED
@@ -1,12 +1,18 @@
1
  import os
2
  import subprocess
 
3
  import torch
4
 
 
 
 
 
 
5
  # Import the compiled extension
6
- from ._ops import ops, add_op_namespace_prefix
7
  from . import utils
8
 
9
- __version__ = "2.3.0"
10
 
11
 
12
  # ── Register fake tensor implementations for torch.compile ──────────────────
@@ -32,6 +38,7 @@ for _op in [
32
  "m_grouped_bf16_gemm_nn_contiguous",
33
  "m_grouped_bf16_gemm_nt_masked",
34
  "fp8_gemm_nt_skip_head_mid",
 
35
  ]:
36
 
37
  @torch.library.register_fake(add_op_namespace_prefix(_op))
@@ -58,10 +65,41 @@ def get_tc_util() -> int:
58
  return ops.get_tc_util()
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def get_mk_alignment_for_contiguous_layout() -> int:
62
  return ops.get_mk_alignment_for_contiguous_layout()
63
 
64
 
 
 
 
 
 
 
 
65
  # Layout utilities
66
 
67
 
@@ -77,10 +115,12 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf):
77
  return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
78
 
79
 
80
- def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks):
 
 
81
  ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu")
82
  return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
83
- sf, ks_tensor, ks_int
84
  )
85
 
86
 
@@ -88,16 +128,20 @@ def transform_sf_into_required_layout(
88
  sf,
89
  mn,
90
  k,
91
- recipe=None,
92
- recipe_ab=None,
93
  num_groups=None,
94
- is_sfa=False,
95
  disable_ue8m0_cast=False,
96
  ):
97
- has_recipe = recipe is not None
98
- r0, r1, r2 = recipe if has_recipe else (0, 0, 0)
99
- has_recipe_ab = recipe_ab is not None
100
- rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0)
 
 
 
 
 
101
  has_ng = num_groups is not None
102
  ng = num_groups if has_ng else 0
103
  return ops.transform_sf_into_required_layout(
@@ -107,13 +151,11 @@ def transform_sf_into_required_layout(
107
  r0,
108
  r1,
109
  r2,
110
- has_recipe,
111
- rab0,
112
- rab1,
113
- has_recipe_ab,
114
  ng,
115
  has_ng,
116
- is_sfa,
 
117
  disable_ue8m0_cast,
118
  )
119
 
@@ -593,8 +635,37 @@ def fp8_mqa_logits(
593
  )
594
 
595
 
596
- def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms):
597
- return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
 
599
 
600
  def fp8_paged_mqa_logits(
@@ -606,6 +677,7 @@ def fp8_paged_mqa_logits(
606
  schedule_meta,
607
  max_context_len,
608
  clean_logits=False,
 
609
  ):
610
  return ops.fp8_paged_mqa_logits(
611
  q,
@@ -616,6 +688,38 @@ def fp8_paged_mqa_logits(
616
  schedule_meta,
617
  max_context_len,
618
  clean_logits,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  )
620
 
621
 
@@ -642,6 +746,14 @@ def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None):
642
  ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns)
643
 
644
 
 
 
 
 
 
 
 
 
645
  # Initialize the C++ runtime
646
 
647
 
@@ -683,6 +795,14 @@ if "DG_CUTLASS_INCLUDE" not in os.environ:
683
  _include, # legacy layout: include/cutlass
684
  os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout
685
  ]
 
 
 
 
 
 
 
 
686
  for _cutlass_include in _cutlass_include_candidates:
687
  if os.path.isdir(os.path.join(_cutlass_include, "cutlass")):
688
  os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include
@@ -703,8 +823,21 @@ def _ensure_initialized():
703
  global _initialized
704
  if _initialized:
705
  return
 
706
  _initialized = True
707
- ops.init(_lib_root, _find_cuda_home())
 
 
 
 
 
 
 
 
 
 
 
 
708
 
709
 
710
  # Try to initialize eagerly, but don't fail if CUDA is not found
 
1
  import os
2
  import subprocess
3
+ import sysconfig
4
  import torch
5
 
6
+ # Avoid holding a CUDA tensor in DeepGEMM's process-lifetime runtime singleton.
7
+ # In packaged/lazy-loaded use, that can outlive PyTorch's CUDA teardown and crash
8
+ # during interpreter shutdown.
9
+ os.environ.setdefault("DG_USE_TEMP_CUBLASLT_WORKSPACE", "1")
10
+
11
  # Import the compiled extension
12
+ from ._ops import ops as _ops, add_op_namespace_prefix
13
  from . import utils
14
 
15
+ __version__ = "2.5.0"
16
 
17
 
18
  # ── Register fake tensor implementations for torch.compile ──────────────────
 
38
  "m_grouped_bf16_gemm_nn_contiguous",
39
  "m_grouped_bf16_gemm_nt_masked",
40
  "fp8_gemm_nt_skip_head_mid",
41
+ "fp8_fp4_mega_moe",
42
  ]:
43
 
44
  @torch.library.register_fake(add_op_namespace_prefix(_op))
 
65
  return ops.get_tc_util()
66
 
67
 
68
+ def set_ignore_compile_dims(value: bool):
69
+ ops.set_ignore_compile_dims(value)
70
+
71
+
72
+ def set_block_size_multiple_of(value):
73
+ if isinstance(value, tuple):
74
+ block_m, block_n = value
75
+ else:
76
+ block_m = block_n = value
77
+ ops.set_block_size_multiple_of(block_m, block_n)
78
+
79
+
80
+ def set_pdl(enable_pdl: bool):
81
+ ops.set_pdl(enable_pdl)
82
+
83
+
84
+ def get_pdl() -> bool:
85
+ return ops.get_pdl()
86
+
87
+
88
+ def set_mk_alignment_for_contiguous_layout(alignment: int):
89
+ ops.set_mk_alignment_for_contiguous_layout(alignment)
90
+
91
+
92
  def get_mk_alignment_for_contiguous_layout() -> int:
93
  return ops.get_mk_alignment_for_contiguous_layout()
94
 
95
 
96
+ def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int:
97
+ return ops.get_theoretical_mk_alignment_for_contiguous_layout(
98
+ 0 if expected_m is None else expected_m,
99
+ expected_m is not None,
100
+ )
101
+
102
+
103
  # Layout utilities
104
 
105
 
 
115
  return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
116
 
117
 
118
+ def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
119
+ sf, ks_tensor, ks, gran_k
120
+ ):
121
  ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu")
122
  return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(
123
+ sf, ks_tensor, ks_int, gran_k
124
  )
125
 
126
 
 
128
  sf,
129
  mn,
130
  k,
131
+ recipe,
 
132
  num_groups=None,
133
+ is_sfa=None,
134
  disable_ue8m0_cast=False,
135
  ):
136
+ if len(recipe) == 3:
137
+ r0, r1, r2 = recipe
138
+ recipe_len = 3
139
+ elif len(recipe) == 2:
140
+ r0, r1 = recipe
141
+ r2 = 0
142
+ recipe_len = 2
143
+ else:
144
+ raise ValueError("recipe must have length 2 or 3")
145
  has_ng = num_groups is not None
146
  ng = num_groups if has_ng else 0
147
  return ops.transform_sf_into_required_layout(
 
151
  r0,
152
  r1,
153
  r2,
154
+ recipe_len,
 
 
 
155
  ng,
156
  has_ng,
157
+ False if is_sfa is None else is_sfa,
158
+ is_sfa is not None,
159
  disable_ue8m0_cast,
160
  )
161
 
 
635
  )
636
 
637
 
638
+ def fp8_fp4_mqa_logits(
639
+ q,
640
+ kv,
641
+ weights,
642
+ cu_seq_len_k_start,
643
+ cu_seq_len_k_end,
644
+ clean_logits=True,
645
+ max_seqlen_k=0,
646
+ logits_dtype=torch.float32,
647
+ ):
648
+ if isinstance(q, tuple):
649
+ q_data, q_sf = q
650
+ else:
651
+ q_data, q_sf = q, None
652
+ kv_data, kv_sf = kv
653
+ return ops.fp8_fp4_mqa_logits(
654
+ q_data,
655
+ q_sf,
656
+ kv_data,
657
+ kv_sf,
658
+ weights,
659
+ cu_seq_len_k_start,
660
+ cu_seq_len_k_end,
661
+ clean_logits,
662
+ max_seqlen_k,
663
+ logits_dtype,
664
+ )
665
+
666
+
667
+ def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None):
668
+ return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices)
669
 
670
 
671
  def fp8_paged_mqa_logits(
 
677
  schedule_meta,
678
  max_context_len,
679
  clean_logits=False,
680
+ indices=None,
681
  ):
682
  return ops.fp8_paged_mqa_logits(
683
  q,
 
688
  schedule_meta,
689
  max_context_len,
690
  clean_logits,
691
+ indices,
692
+ )
693
+
694
+
695
+ def fp8_fp4_paged_mqa_logits(
696
+ q,
697
+ kv_cache,
698
+ weights,
699
+ context_lens,
700
+ block_table,
701
+ schedule_meta,
702
+ max_context_len,
703
+ clean_logits=False,
704
+ logits_dtype=torch.float32,
705
+ indices=None,
706
+ ):
707
+ if isinstance(q, tuple):
708
+ q_data, q_sf = q
709
+ else:
710
+ q_data, q_sf = q, None
711
+ return ops.fp8_fp4_paged_mqa_logits(
712
+ q_data,
713
+ q_sf,
714
+ kv_cache,
715
+ weights,
716
+ context_lens,
717
+ block_table,
718
+ schedule_meta,
719
+ max_context_len,
720
+ clean_logits,
721
+ logits_dtype,
722
+ indices,
723
  )
724
 
725
 
 
746
  ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns)
747
 
748
 
749
+ from .mega import (
750
+ SymmBuffer,
751
+ get_symm_buffer_for_mega_moe,
752
+ transform_weights_for_mega_moe,
753
+ fp8_fp4_mega_moe,
754
+ )
755
+
756
+
757
  # Initialize the C++ runtime
758
 
759
 
 
795
  _include, # legacy layout: include/cutlass
796
  os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout
797
  ]
798
+ for _site_packages in {
799
+ sysconfig.get_paths().get("purelib"),
800
+ sysconfig.get_paths().get("platlib"),
801
+ }:
802
+ if _site_packages:
803
+ _cutlass_include_candidates.append(
804
+ os.path.join(_site_packages, "cutlass_library", "source", "include")
805
+ )
806
  for _cutlass_include in _cutlass_include_candidates:
807
  if os.path.isdir(os.path.join(_cutlass_include, "cutlass")):
808
  os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include
 
823
  global _initialized
824
  if _initialized:
825
  return
826
+ _ops.init(_lib_root, _find_cuda_home())
827
  _initialized = True
828
+
829
+
830
+ class _InitializedOps:
831
+ def __init__(self, raw_ops):
832
+ self._raw_ops = raw_ops
833
+
834
+ def __getattr__(self, name):
835
+ if name != "init":
836
+ _ensure_initialized()
837
+ return getattr(self._raw_ops, name)
838
+
839
+
840
+ ops = _InitializedOps(_ops)
841
 
842
 
843
  # Try to initialize eagerly, but don't fail if CUDA is not found
build/torch210-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_388adb9.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2bff23699a1ab0aa2a92bab110612828e10cd623f2f626002ca4a1eba38668e
3
+ size 3381200
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _deep_gemm_cuda_8546a43
3
- ops = torch.ops._deep_gemm_cuda_8546a43
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_deep_gemm_cuda_8546a43::{op_name}"
 
1
  import torch
2
+ from . import _deep_gemm_cuda_388adb9
3
+ ops = torch.ops._deep_gemm_cuda_388adb9
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_deep_gemm_cuda_388adb9::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/comm/barrier.cuh ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+
5
+ #include <deep_gemm/ptx/ld_st.cuh>
6
+ #include <deep_gemm/layout/sym_buffer.cuh>
7
+ #include <deep_gemm/layout/mega_moe.cuh>
8
+
9
+ namespace deep_gemm::comm {
10
+
11
+ CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
12
+ // Perform cluster_sync with `barrier.cluster.arrive.relaxed`
13
+ // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee
14
+ cute::cluster_arrive_relaxed();
15
+ cute::cluster_wait();
16
+ }
17
+
18
+ template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
19
+ CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
20
+ const uint32_t& sm_idx, const uint32_t& thread_idx,
21
+ const sync_scope_t& sync_scope) {
22
+ // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()`
23
+ static constexpr uint32_t kFinishSumTag = 0x80000000u;
24
+ sync_scope();
25
+ if (thread_idx == 0) {
26
+ const auto count_ptr = workspace.get_grid_sync_count_ptr<kGridSyncIndex>();
27
+ const auto old_value = ptx::atomic_add_rel(
28
+ count_ptr, sm_idx == 0 ? (kFinishSumTag - (kNumSMs - 1)) : 1);
29
+ uint32_t new_value;
30
+ do {
31
+ new_value = ptx::ld_acq(count_ptr);
32
+ } while (((new_value ^ old_value) & kFinishSumTag) == 0);
33
+ }
34
+ sync_scope();
35
+ }
36
+
37
+ template <uint32_t kNumRanks, uint32_t kNumSMs, uint32_t kNumThreads, uint32_t kGridSyncIndex, uint32_t kTag, typename sync_scope_t>
38
+ CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace,
39
+ const layout::SymBuffer<kNumRanks>& sym_buffer,
40
+ const uint32_t& sm_idx, const uint32_t& thread_idx,
41
+ const sync_scope_t& sync_scope,
42
+ const bool& sync_prologue = true,
43
+ const bool& sync_epilogue = true) {
44
+ DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads");
45
+
46
+ // Grid sync before NVLink signaling
47
+ if (sync_prologue)
48
+ grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
49
+
50
+ // NVLink cross-rank barrier, only SM 0 participates
51
+ if (sm_idx == 0) {
52
+ auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr();
53
+ const auto status = (*counter_ptr) & 3;
54
+ const auto signal_phase = status & 1, signal_sign = status >> 1;
55
+ auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase);
56
+
57
+ // Send signals to remote ranks
58
+ if (thread_idx < kNumRanks)
59
+ ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1);
60
+ sync_scope();
61
+
62
+ // Update status and wait arrival (with 30s timeout, at 2 GHz)
63
+ constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll;
64
+ if (thread_idx == 0) {
65
+ ptx::red_add(counter_ptr, 1);
66
+ const int target = signal_sign ? 0 : static_cast<int>(kNumRanks);
67
+ const auto start_clock = clock64();
68
+ while (ptx::ld_acq_sys(signal_ptr) != target) {
69
+ if (clock64() - start_clock >= kNumTimeoutCycles) {
70
+ printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n",
71
+ sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag);
72
+ DG_DEVICE_ASSERT(false and "NVLink barrier timeout");
73
+ }
74
+ }
75
+ }
76
+ }
77
+
78
+ // Grid sync after NVLink completion
79
+ if (sync_epilogue)
80
+ grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
81
+ }
82
+
83
+ } // namespace deep_gemm::comm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/compile.cuh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/detail/helper_macros.hpp>
4
+
5
+ #if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__)
6
+ #define DG_IN_CUDA_COMPILATION
7
+ #endif
8
+
9
+ #if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__))
10
+ #define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__
11
+ #define CUTLASS_DEVICE_NOINLINE __device__
12
+ #elif defined(__CUDACC_RTC__)
13
+ #define CUTLASS_HOST_DEVICE_NOINLINE __device__
14
+ #define CUTLASS_DEVICE_NOINLINE __device__
15
+ #else
16
+ #define CUTLASS_HOST_DEVICE_NOINLINE
17
+ #define CUTLASS_DEVICE_NOINLINE
18
+ #endif
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/cute_tie.cuh CHANGED
@@ -1,5 +1,7 @@
1
  #pragma once
2
 
 
 
3
  namespace cute {
4
 
5
  struct ignore_t {
 
1
  #pragma once
2
 
3
+ #include <cute/int_tuple.hpp>
4
+
5
  namespace cute {
6
 
7
  struct ignore_t {
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/exception.cuh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda/std/cstdint>
4
+ #include <deep_gemm/common/compile.cuh>
5
+
6
+ #ifdef __CLION_IDE__
7
+
8
+ CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) {
9
+ asm volatile("trap;");
10
+ }
11
+
12
+ #define printf host_device_printf
13
+ #endif
14
+
15
+ #ifndef DG_DEVICE_ASSERT
16
+ #define DG_DEVICE_ASSERT(cond) \
17
+ do { \
18
+ if (not (cond)) { \
19
+ printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
20
+ asm("trap;"); \
21
+ } \
22
+ } while (0)
23
+ #endif
24
+
25
+ #ifndef DG_TRAP_ONLY_DEVICE_ASSERT
26
+ #define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
27
+ do { \
28
+ if (not (cond)) \
29
+ asm("trap;"); \
30
+ } while (0)
31
+ #endif
32
+
33
+ #ifndef DG_STATIC_ASSERT
34
+ #define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
35
+ #endif
36
+
37
+ #ifndef DG_UNIFIED_ASSERT
38
+ #ifdef DG_IN_CUDA_COMPILATION
39
+ #define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond)
40
+ #else
41
+ #define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond)
42
+ #endif
43
+ #endif
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/math.cuh ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda/std/cstdint>
4
+ #include <deep_gemm/common/compile.cuh>
5
+ #include <deep_gemm/common/exception.cuh>
6
+
7
+ namespace deep_gemm::math {
8
+
9
+ /// Pointer operations
10
+ template <typename dtype_t = void>
11
+ CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) {
12
+ return reinterpret_cast<dtype_t*>(static_cast<uint8_t*>(ptr) + num_bytes);
13
+ }
14
+
15
+ /// Math functions
16
+ template <typename T>
17
+ CUTLASS_HOST_DEVICE T ceil_div(T a, T b) {
18
+ return (a + b - 1) / b;
19
+ }
20
+
21
+ template <typename T>
22
+ CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) {
23
+ return (a + b - 1) / b;
24
+ }
25
+
26
+ template <typename T, bool kDoCeilAlignment = true>
27
+ CUTLASS_HOST_DEVICE T align(T a, T b) {
28
+ return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b;
29
+ }
30
+
31
+ template <typename T>
32
+ CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) {
33
+ return constexpr_ceil_div(a, b) * b;
34
+ }
35
+
36
+ template <typename T>
37
+ CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) {
38
+ return b == 0 ? a : constexpr_gcd(b, a % b);
39
+ }
40
+
41
+ template <typename T>
42
+ CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) {
43
+ return a < b ? a : b;
44
+ }
45
+
46
+ template <typename T>
47
+ CUTLASS_DEVICE void swap(T& a, T& b) {
48
+ T temp = a;
49
+ a = b;
50
+ b = temp;
51
+ }
52
+
53
+ #ifdef DG_IN_CUDA_COMPILATION
54
+ CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) {
55
+ #if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
56
+ return __ffma2_rn(a, b, c);
57
+ #else
58
+ return make_float2(
59
+ __fmaf_rn(a.x, b.x, c.x),
60
+ __fmaf_rn(a.y, b.y, c.y)
61
+ );
62
+ #endif
63
+ }
64
+
65
+ CUTLASS_HOST_DEVICE float fast_rcp(const float& x) {
66
+ #if defined(__CUDA_ARCH__)
67
+ float ret;
68
+ asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x));
69
+ return ret;
70
+ #else
71
+ return 1.0f / x;
72
+ #endif
73
+ }
74
+
75
+ /// Casting
76
+ template <typename old_t>
77
+ CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) {
78
+ auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
79
+ return *reinterpret_cast<int*>(&bf16x2);
80
+ }
81
+
82
+ CUTLASS_DEVICE float fast_pow2(const int& x) {
83
+ uint32_t bits_x = (x + 127) << 23;
84
+ return *reinterpret_cast<float*>(&bits_x);
85
+ }
86
+
87
+ CUTLASS_DEVICE int fast_log2_ceil(float x) {
88
+ const auto bits = *reinterpret_cast<uint32_t*>(&x);
89
+ const auto exp = bits >> 23;
90
+ const auto man = bits & ((1 << 23) - 1);
91
+ return exp - 127 + (man != 0);
92
+ }
93
+
94
+ template <bool kUseUE8M0 = true>
95
+ CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) {
96
+ DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0");
97
+ const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0};
98
+ const auto scaled = __fmul2_rn(amax, finfo_factor);
99
+ const auto exp_x = fast_log2_ceil(scaled.x);
100
+ const auto exp_y = fast_log2_ceil(scaled.y);
101
+ sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x);
102
+ sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y);
103
+ }
104
+
105
+ /// Reduction
106
+ CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) {
107
+ #pragma unroll
108
+ for (uint32_t offset = 1; offset < 32; offset <<= 1) {
109
+ const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset);
110
+ if (lane_idx >= offset)
111
+ value += synced;
112
+ }
113
+ return value;
114
+ }
115
+
116
+ // Operation functors
117
+ template <typename T> struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } };
118
+ template <typename T> struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } };
119
+ template <typename T> struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } };
120
+ template <typename T> struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } };
121
+ template <typename T> struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } };
122
+
123
+ // Unified reduction function
124
+ template <uint32_t kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
125
+ CUTLASS_DEVICE T warp_reduce(T value, Op op) {
126
+ DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
127
+ kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
128
+ "Invalid number of lanes");
129
+ constexpr uint32_t mask = 0xffffffff;
130
+ if constexpr (kIntergroupReduce) {
131
+ if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
132
+ if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
133
+ if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
134
+ if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
135
+ if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
136
+ } else {
137
+ if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
138
+ if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
139
+ if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
140
+ if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
141
+ if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
142
+ }
143
+ return value;
144
+ }
145
+
146
+ // Convenience aliases
147
+ template <uint32_t kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
148
+ CUTLASS_DEVICE T warp_reduce_sum(T value) {
149
+ return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
150
+ }
151
+ #endif
152
+
153
+ } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/tma_copy.cuh ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/copy_sm90_tma.hpp>
4
+ #include <cute/arch/copy_sm100_tma.hpp>
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ #include <deep_gemm/common/exception.cuh>
8
+
9
+ namespace deep_gemm::tma {
10
+
11
// Number of `dtype_t` elements along the inner dimension covered by one TMA copy:
// the whole `BLOCK_INNER` when no swizzling is used (`kSwizzleMode == 0`), otherwise
// the number of elements that fit in `kSwizzleMode` bytes.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if constexpr (kSwizzleMode == 0) {
        return BLOCK_INNER;
    } else {
        return kSwizzleMode / sizeof(dtype_t);
    }
}
15
+
16
+ template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
17
+ uint32_t kSwizzleMode,
18
+ typename dtype_t, bool kIs3DTMA = false>
19
+ CUTLASS_DEVICE void
20
+ copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
21
+ dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
22
+ const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
23
+ DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
24
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
25
+ constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
26
+
27
+ if constexpr (not kIs3DTMA) {
28
+ if (num_tma_multicast == 1) {
29
+ #pragma unroll
30
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
31
+ cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
32
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
33
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
34
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
35
+ }
36
+ } else {
37
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
38
+ // 2-CTA function will send signals to the leader CTA only
39
+ #pragma unroll
40
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
41
+ cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
42
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
43
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
44
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
45
+ }
46
+ #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
47
+ if (cute::block_rank_in_cluster() == 0) {
48
+ #pragma unroll
49
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
50
+ cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
51
+ (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
52
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
53
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
54
+ }
55
+ }
56
+ #endif
57
+ }
58
+ } else {
59
+ if (num_tma_multicast == 1) {
60
+ #pragma unroll
61
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
62
+ cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
63
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
64
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
65
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
66
+ }
67
+ } else {
68
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
69
+ // 2-CTA function will send signals to the leader CTA only
70
+ #pragma unroll
71
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
72
+ cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
73
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
74
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
75
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
76
+ }
77
+ #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
78
+ if (cute::block_rank_in_cluster() == 0) {
79
+ #pragma unroll
80
+ for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
81
+ cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
82
+ (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
83
+ smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
84
+ inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
85
+ }
86
+ }
87
+ #endif
88
+ }
89
+ }
90
+ }
91
+
92
+ } // namespace deep_gemm::tma
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/types.cuh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/mma_sm100_desc.hpp>
4
+
5
+ namespace deep_gemm {
6
+
7
// Input element format of the MMA; `get_element_size` maps each kind to a per-element size.
enum class MmaKind {
    BF16     = 0,
    MXFP8FP4 = 1,
};
11
+
12
+ constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) {
13
+ switch (mma_kind) {
14
+ case MmaKind::BF16: return 2;
15
+ case MmaKind::MXFP8FP4: return 1;
16
+ default: return 0;
17
+ }
18
+ }
19
+
20
// Problem layout handled by a GEMM kernel.
// NOTES: the numeric values are spelled out explicitly; keep them stable when adding entries
enum class GemmType {
    Normal                           = 0,
    MGroupedContiguous               = 1,
    MGroupedMasked                   = 2,
    KGroupedContiguous               = 3,
    Batched                          = 4,
    MGroupedContiguousWithPsumLayout = 5,
};
28
+
29
+ constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) {
30
+ switch (gemm_type) {
31
+ case GemmType::MGroupedContiguous: return true;
32
+ case GemmType::MGroupedContiguousWithPsumLayout: return true;
33
+ default: return false;
34
+ }
35
+ }
36
+
37
// Kernel variant selector.
// NOTE(review): the names suggest 1D/2D scaling-factor layouts and a variant without
// scaling factors; confirm against the kernels that consume this enum
enum class KernelType {
    Kernel1D1D = 0,
    Kernel1D2D = 1,
    KernelNoSF = 2,
};
42
+
43
+ } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/common/utils.cuh CHANGED
@@ -1,167 +1,24 @@
1
  #pragma once
2
 
3
- #include <cuda_bf16.h>
4
- #include <cuda_fp8.h>
5
  #include <cuda/std/cstdint>
6
- #include <cuda/std/utility>
7
- #include <cute/container/tuple.hpp>
8
 
9
- #include "cute_tie.cuh"
10
 
11
- #ifdef __CLION_IDE__
12
-
13
- __host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
14
- asm volatile("trap;");
15
- }
16
-
17
- #define printf host_device_printf
18
- #endif
19
-
20
- #ifndef DG_DEVICE_ASSERT
21
- #define DG_DEVICE_ASSERT(cond) \
22
- do { \
23
- if (not (cond)) { \
24
- printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
25
- asm("trap;"); \
26
- } \
27
- } while (0)
28
- #endif
29
-
30
- #ifndef DG_TRAP_ONLY_DEVICE_ASSERT
31
- #define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
32
- do { \
33
- if (not (cond)) \
34
- asm("trap;"); \
35
- } while (0)
36
- #endif
37
-
38
- #ifndef DG_STATIC_ASSERT
39
- #define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
40
- #endif
41
-
42
- namespace deep_gemm {
43
 
44
  template <typename FuncT>
45
  struct PatternVisitor {
46
  FuncT func;
47
 
48
- __device__ __host__
49
  explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
50
 
51
- __device__ __host__
52
- auto operator [](const uint32_t& i) {
53
  return func(i);
54
  }
55
  };
56
 
57
- template <typename T>
58
- __device__ __host__ T ceil_div(T a, T b) {
59
- return (a + b - 1) / b;
60
- }
61
-
62
- template <typename T>
63
- __device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
64
- return (a + b - 1) / b;
65
- }
66
-
67
- template <typename T>
68
- __device__ __host__ T align(T a, T b) {
69
- return ceil_div(a, b) * b;
70
- }
71
-
72
- template <typename T>
73
- __device__ __host__ constexpr T constexpr_align(T a, T b) {
74
- return constexpr_ceil_div(a, b) * b;
75
- }
76
-
77
- template <typename T>
78
- __device__ __host__ constexpr T constexpr_gcd(T a, T b) {
79
- return b == 0 ? a : constexpr_gcd(b, a % b);
80
- }
81
-
82
- template<typename T>
83
- __forceinline__ __device__ void swap(T& a, T& b) {
84
- T temp = a;
85
- a = b;
86
- b = temp;
87
- }
88
-
89
- __forceinline__ __device__ uint32_t get_sm_idx() {
90
- uint32_t sm_idx;
91
- asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
92
- return sm_idx;
93
- }
94
-
95
- __forceinline__ __device__ uint32_t get_lane_idx() {
96
- uint32_t lane_id;
97
- asm ("mov.u32 %0, %laneid;" : "=r"(lane_id));
98
- return lane_id;
99
- }
100
-
101
- __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
102
- uint32_t ret;
103
- asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
104
- return ret;
105
- }
106
-
107
- __device__ __forceinline__ float2 ld_shared(const float2* ptr) {
108
- float2 ret;
109
- asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr)));
110
- return ret;
111
- }
112
-
113
- __device__ __forceinline__ float4 ld_shared(const float4* ptr) {
114
- float4 ret;
115
- asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
116
- return ret;
117
- }
118
-
119
- __device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
120
- uint4 ret;
121
- asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
122
- return ret;
123
- }
124
-
125
- __device__ __forceinline__ float ld_shared(const float* ptr) {
126
- float ret;
127
- asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr)));
128
- return ret;
129
- }
130
-
131
- __device__ __forceinline__ void st_shared(const float* ptr, float val) {
132
- asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val));
133
- }
134
-
135
- __device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
136
- asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y));
137
- }
138
-
139
- __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
140
- asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val));
141
- }
142
-
143
- __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
144
- asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y));
145
- }
146
-
147
- __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
148
- asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
149
- }
150
-
151
- __device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
152
- asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
153
- }
154
-
155
- template <typename old_t>
156
- __device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
157
- auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
158
- return *reinterpret_cast<int*>(&bf16x2);
159
- }
160
-
161
- __device__ __forceinline__ void prefetch_l1(void *ptr) {
162
- asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
163
- }
164
-
165
  template <uint32_t kNumBytes>
166
  struct Vectorized {
167
  static auto zeros() {
@@ -180,4 +37,14 @@ struct Vectorized {
180
  using vec_t = decltype(zeros());
181
  };
182
 
183
- } // namespace `deep_gemm`
 
 
 
 
 
 
 
 
 
 
 
1
  #pragma once
2
 
 
 
3
  #include <cuda/std/cstdint>
 
 
4
 
5
+ #include <deep_gemm/common/exception.cuh>
6
 
7
+ namespace deep_gemm::utils {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  template <typename FuncT>
10
  struct PatternVisitor {
11
  FuncT func;
12
 
13
+ CUTLASS_HOST_DEVICE
14
  explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
15
 
16
+ CUTLASS_HOST_DEVICE
17
+ auto operator [](const uint32_t& i) const {
18
  return func(i);
19
  }
20
  };
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  template <uint32_t kNumBytes>
23
  struct Vectorized {
24
  static auto zeros() {
 
37
  using vec_t = decltype(zeros());
38
  };
39
 
40
+ template <uint32_t kNumCols>
41
+ CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() {
42
+ DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
43
+ if constexpr (kNumCols <= 32) return 32;
44
+ if constexpr (kNumCols <= 64) return 64;
45
+ if constexpr (kNumCols <= 128) return 128;
46
+ if constexpr (kNumCols <= 256) return 256;
47
+ return 512;
48
+ }
49
+
50
+ } // namespace deep_gemm::utils
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd.cuh ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/atom/copy_traits_sm100.hpp>
4
+
5
+ #include <deep_gemm/common/math.cuh>
6
+ #include <deep_gemm/common/types.cuh>
7
+ #include <deep_gemm/common/utils.cuh>
8
+ #include <deep_gemm/ptx/ld_st.cuh>
9
+ #include <deep_gemm/ptx/tcgen05.cuh>
10
+
11
+ namespace deep_gemm::epilogue {
12
+
13
+ template <uint32_t BLOCK_M, uint32_t BLOCK_N,
14
+ uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
15
+ uint32_t kSwizzleCDMode,
16
+ uint32_t kNumTMAStoreStages,
17
+ uint32_t kNumUMMAStoreThreads,
18
+ GemmType kGemmType, bool kWithAccumulation,
19
+ typename cd_dtype_t,
20
+ typename epilogue_type_t,
21
+ typename pattern_cd_t>
22
+ CUTLASS_DEVICE void
23
+ sm100_store_cd(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
24
+ const uint32_t& tmem_base_addr,
25
+ const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
26
+ const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
27
+ const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
28
+ const cute::TmaDescriptor& tensor_map_cd) {
29
+ // TMA checks
30
+ constexpr uint32_t kNumBankGroupBytes = 16;
31
+ constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
32
+ DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
33
+ DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
34
+ DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
35
+ DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
36
+
37
+ // Share store pipeline between blocks
38
+ auto advance_store_pipeline = [&]() {
39
+ tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
40
+ };
41
+
42
+ // Iterate over M waves
43
+ constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M;
44
+ #pragma unroll
45
+ for (uint32_t w = 0; w < kNumMWaves; ++ w) {
46
+ // Issue every swizzled atom and pipeline STSM and TMA store
47
+ constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
48
+ #pragma unroll
49
+ for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
50
+ auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]);
51
+
52
+ // Wait shared memory to be released
53
+ if (epilogue_warp_idx == 0)
54
+ cute::tma_store_wait<kNumTMAStoreStages - 1>();
55
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
56
+
57
+ // The pipeline stage
58
+ const auto m_idx = base_m_idx + w * STORE_BLOCK_M;
59
+ const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(base_n_idx + s * STORE_BLOCK_N);
60
+
61
+ // Store into shared memory
62
+ #pragma unroll
63
+ for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
64
+ // Calculate the index of the bank group to be written in the atom
65
+ auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
66
+
67
+ // Reshape the atom in another view and swizzle
68
+ // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
69
+ // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
70
+ // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
71
+ constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
72
+ auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
73
+ auto col = kHasShortcut ? (i) : (bank_group_index % 8);
74
+ col ^= row % (kSwizzleCDMode / 16);
75
+
76
+ // Source and destination memory address
77
+ uint32_t tmem_addr = tmem_base_addr + // Accumulator offset
78
+ w * BLOCK_N + // Wave offset
79
+ s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
80
+ auto smem_ptr = smem_base_ptr + // Base pointer
81
+ epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
82
+ row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
83
+
84
+ // Load from tensor memory, store into shared memory
85
+ uint32_t values[kNumElemsPerBankGroup];
86
+ if constexpr (cute::is_same_v<cd_dtype_t, float>) {
87
+ // For FP32 output, read and store
88
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
89
+ cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
90
+ values[0], values[1], values[2], values[3]);
91
+ cutlass::arch::fence_view_async_tmem_load();
92
+ ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
93
+ } else {
94
+ // For BF16 output, read, cast and store
95
+ DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
96
+ cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
97
+ values[0], values[1], values[2], values[3],
98
+ values[4], values[5], values[6], values[7]);
99
+ cutlass::arch::fence_view_async_tmem_load();
100
+ ptx::st_shared(
101
+ smem_ptr,
102
+ math::cast_into_bf16_and_pack(values[0], values[1]),
103
+ math::cast_into_bf16_and_pack(values[2], values[3]),
104
+ math::cast_into_bf16_and_pack(values[4], values[5]),
105
+ math::cast_into_bf16_and_pack(values[6], values[7])
106
+ );
107
+ }
108
+ }
109
+
110
+ // Notify tensor memory empty (only at the leader CTA) arrival ASAP
111
+ // NOTES: only the last stage needs to do this
112
+ if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
113
+ ptx::tcgen05_before_thread_sync();
114
+ tmem_empty_barrier->arrive(0u);
115
+ }
116
+
117
+ // Synchronize all threads and issue TMA
118
+ cute::tma_store_fence();
119
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
120
+ if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
121
+ if constexpr (kGemmType == GemmType::Batched) {
122
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
123
+ cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
124
+ cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx);
125
+ } else {
126
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
127
+ cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
128
+ cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx);
129
+ }
130
+ cute::tma_store_arrive();
131
+ }
132
+ __syncwarp();
133
+ }
134
+ }
135
+ }
136
+
137
+ } // namespace deep_gemm::epilogue
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/atom/copy_traits_sm100.hpp>
4
+
5
+ #include <deep_gemm/common/math.cuh>
6
+ #include <deep_gemm/common/types.cuh>
7
+ #include <deep_gemm/common/utils.cuh>
8
+ #include <deep_gemm/ptx/ld_st.cuh>
9
+ #include <deep_gemm/ptx/tcgen05.cuh>
10
+
11
+ namespace deep_gemm::epilogue {
12
+
13
+ template <uint32_t BLOCK_M, uint32_t BLOCK_N,
14
+ uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
15
+ uint32_t kSwizzleCDMode,
16
+ uint32_t kNumTMAStoreStages,
17
+ uint32_t kNumUMMAStoreThreads,
18
+ GemmType kGemmType, bool kWithAccumulation,
19
+ typename cd_dtype_t,
20
+ typename epilogue_type_t,
21
+ typename pattern_cd_t>
22
+ CUTLASS_DEVICE void
23
+ sm100_store_cd_swap_ab(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
24
+ const uint32_t& tmem_base_addr,
25
+ const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
26
+ const uint32_t& effective_m,
27
+ const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
28
+ const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
29
+ const cute::TmaDescriptor& tensor_map_cd) {
30
+ // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows,
31
+ // implying STORE_BLOCK_N must be 128.
32
+ DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows");
33
+
34
+ // TMA checks
35
+ constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t);
36
+ constexpr uint32_t kNumBankGroupBytes = 16;
37
+ constexpr uint32_t kNumSwizzleAtomRows = 8;
38
+ DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled");
39
+ DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
40
+ DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
41
+ DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling");
42
+ DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling");
43
+
44
+ // Share store pipeline between blocks
45
+ auto advance_store_pipeline = [&]() {
46
+ tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
47
+ };
48
+
49
+ // Iterate over M blocks
50
+ const auto num_stores = effective_m / STORE_BLOCK_M;
51
+ for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) {
52
+ // Wait shared memory to be released
53
+ if (epilogue_warp_idx == 0)
54
+ cute::tma_store_wait<kNumTMAStoreStages - 1>();
55
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
56
+
57
+ // Store into shared memory
58
+ #pragma unroll
59
+ for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) {
60
+ uint32_t tmem_addr = tmem_base_addr +
61
+ s * STORE_BLOCK_M + // Store stage offset
62
+ i * kNumSwizzleAtomRows; // In-block offset
63
+ uint32_t values[kNumSwizzleAtomRows];
64
+
65
+ // Warps cooperatively write an atomic block to shared memory
66
+ DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes");
67
+ constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32;
68
+ uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode;
69
+ uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode;
70
+ auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset;
71
+
72
+ if constexpr (cute::is_same_v<cd_dtype_t, float>) {
73
+ // NOTES: Swizzling is not required in this case, but used here for consistency with other cases
74
+ cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3],
75
+ values[4], values[5], values[6], values[7]);
76
+ uint32_t col = lane_idx / 4;
77
+
78
+ #pragma unroll
79
+ for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) {
80
+ auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
81
+ + (col ^ row) * kNumBankGroupBytes
82
+ + (lane_idx % 4) * sizeof(float);
83
+ ptx::st_shared(reinterpret_cast<uint32_t*>(smem_ptr), values[row]);
84
+ }
85
+ } else {
86
+ // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements
87
+ // Start from lane index 0
88
+ cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
89
+ values[0], values[1], values[2], values[3]);
90
+ // Start from lane index 16
91
+ cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
92
+ values[4], values[5], values[6], values[7]);
93
+ cutlass::arch::fence_view_async_tmem_load();
94
+
95
+ // Destination shared memory address
96
+ uint32_t row = lane_idx % 8;
97
+ uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8;
98
+ auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
99
+ + (col ^ row) * kNumBankGroupBytes;
100
+
101
+ // Store matrix with transposition
102
+ ptx::SM90_U32x4_STSM_T<int>::copy(math::cast_into_bf16_and_pack(values[0], values[1]),
103
+ math::cast_into_bf16_and_pack(values[2], values[3]),
104
+ math::cast_into_bf16_and_pack(values[4], values[5]),
105
+ math::cast_into_bf16_and_pack(values[6], values[7]),
106
+ smem_ptr);
107
+ }
108
+ }
109
+
110
+ // Notify tensor memory empty (only at the leader CTA) arrival ASAP
111
+ // NOTES: only the last stage needs to do this
112
+ if (s == num_stores - 1) {
113
+ ptx::tcgen05_before_thread_sync();
114
+ tmem_empty_barrier->arrive(0u);
115
+ }
116
+
117
+ // Synchronize all threads and issue TMA
118
+ cute::tma_store_fence();
119
+ cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
120
+ if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
121
+ #pragma unroll
122
+ for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) {
123
+ auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM;
124
+ uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M;
125
+ uint32_t n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N_ATOM>(base_n_idx + i * STORE_BLOCK_N_ATOM);
126
+
127
+ // Issue 2D or 3D TMA store
128
+ if constexpr (kGemmType == GemmType::Batched) {
129
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
130
+ cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
131
+ cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx);
132
+ } else {
133
+ using cute_tma_t = cute::conditional_t<kWithAccumulation,
134
+ cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
135
+ cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx);
136
+ }
137
+ }
138
+ cute::tma_store_arrive();
139
+ }
140
+ __syncwarp();
141
+ }
142
+ }
143
+
144
+ } // namespace deep_gemm::epilogue
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/epilogue/transform.cuh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/exception.cuh>
4
+
5
+ namespace deep_gemm::epilogue::transform {
6
+
7
+ struct EpilogueIdentity {
8
+ template <uint32_t STORE_BLOCK_N>
9
+ CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
10
+ return n_idx;
11
+ }
12
+ };
13
+
14
+ template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
15
+ struct EpilogueHeadSplits: EpilogueIdentity {
16
+ template <uint32_t STORE_BLOCK_N>
17
+ CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
18
+ DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and
19
+ kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
20
+ return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
21
+ }
22
+ };
23
+
24
+ } // namespace deep_gemm::epilogue::transform
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh CHANGED
@@ -4,14 +4,18 @@
4
 
5
  #include <cutlass/arch/barrier.h>
6
 
7
- #include <deep_gemm/common/scheduler.cuh>
8
- #include <deep_gemm/common/utils.cuh>
9
- #include <deep_gemm/common/sm100_utils.cuh>
 
 
 
 
 
 
10
 
11
  namespace deep_gemm {
12
 
13
- using namespace deep_gemm::sm100;
14
-
15
  template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
16
  uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
17
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
@@ -21,9 +25,10 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
21
  uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
22
  uint32_t kNumMulticast, bool kIsMulticastOnA,
23
  uint32_t kNumSMs,
 
24
  GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
25
  uint64_t kTensorCoreUtilControl>
26
- __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
27
  sm100_bf16_gemm_impl(int* grouped_layout,
28
  uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
29
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout,
48
  if constexpr (kWithAccumulation)
49
  DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
50
 
51
- // Configs
52
  constexpr uint32_t LAYOUT_AD_M = 128;
53
- constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
54
- constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
55
- constexpr uint32_t kNumTMAStoreStages = 2;
56
- DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
57
- DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
58
- DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode");
59
-
60
- // Overwrite shape constants if the compiler gives
61
- shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
62
- shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
63
- shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
64
-
65
- // Utils
66
- bool is_leader_cta = cute::block_rank_in_cluster() == 0;
67
- const auto warp_idx = cutlass::canonical_warp_idx_sync();
68
- const auto lane_idx = get_lane_idx();
69
-
70
- // Align to 1024 bytes for swizzle-128B
71
- extern __shared__ __align__(1024) uint8_t smem_buffer[];
72
-
73
- // 2-CTA MMA
74
  constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
75
  constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
76
- constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
77
- constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
78
- constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
79
- DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
80
- DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
81
  DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
 
 
 
 
 
 
 
 
 
 
 
 
82
  DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
83
 
84
  // Share memory sizes
85
- constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
86
  constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
87
  constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
88
  constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
@@ -91,41 +86,54 @@ sm100_bf16_gemm_impl(int* grouped_layout,
91
  DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
92
 
93
  // NOTES: Make sure we have enough shared memory for UMMA padding
94
- static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
95
- DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
96
-
97
- // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
98
- // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
99
- constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
100
 
101
  // Real tensor memory size and offsets
102
- constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
103
- constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
 
 
 
 
 
 
 
 
 
104
 
105
  // Prefetch TMA descriptors at the very beginning
106
- if (warp_idx == 0 and cute::elect_one_sync()) {
107
  cute::prefetch_tma_descriptor(&tensor_map_a);
108
  cute::prefetch_tma_descriptor(&tensor_map_b);
109
  cute::prefetch_tma_descriptor(&tensor_map_cd);
110
  }
111
 
 
 
 
 
 
 
 
 
112
  // D/A/B shared memory
113
- auto smem_cd = PatternVisitor([&](const uint32_t& i) {
114
  return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
115
  });
116
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
117
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
118
  });
119
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
120
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
121
  });
122
 
123
  // Fill barriers
124
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
125
- auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
126
- auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
127
- auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
128
- auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
129
  auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
130
 
131
  // Fill the tensor memory pointer
@@ -159,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
159
  }
160
  kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
161
 
 
 
 
162
  // Block scheduler
163
  uint32_t m_block_idx, n_block_idx;
164
- auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
 
165
 
166
  // Pipeline and TMA phases
167
  uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
@@ -178,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout,
178
  // TMA load warp
179
  // Persistently schedule over blocks
180
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
181
- const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
 
 
 
 
182
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
183
  // Wait consumer release
184
  empty_barriers[stage_idx]->wait(phase ^ 1);
185
 
186
  // Compute offsets
187
  // NOTES: the group is always concatenated with the outer dimension
188
- uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
189
  shape_m, BLOCK_M, m_block_idx);
190
- uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
191
  shape_n, BLOCK_N, n_block_idx, m_block_idx);
192
 
193
  // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
@@ -195,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout,
195
  DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
196
  kMajorA == cute::UMMA::Major::K, "Invalid major");
197
  uint32_t k_idx = k_block_idx * BLOCK_K;
198
- uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
199
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
200
- uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
201
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
202
 
203
  // Add 2 CTA offsets
204
  if constexpr (kNumMulticast > 1) {
205
- m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
206
  n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
207
  }
208
 
@@ -210,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
210
  constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
211
  const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
212
  if constexpr (kMajorA == cute::UMMA::Major::K)
213
- tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
214
  &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
215
  if constexpr (kMajorA == cute::UMMA::Major::MN)
216
- tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
217
  &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
218
  if constexpr (kMajorB == cute::UMMA::Major::K)
219
- tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
220
  &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
221
  if constexpr (kMajorB == cute::UMMA::Major::MN)
222
- tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
223
  &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
224
 
225
  // Arrive at full barriers
@@ -235,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
235
  // MMA issue warp
236
  // NOTES: only the leader CTA will do this
237
  // Make instruction descriptor
238
- // TODO: refactor `UMMA_M` calculation
239
- constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
240
- constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
241
- constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
242
- auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
243
 
244
  DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
245
  // Merged stages only happens in NT normal GEMM cases
246
  constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
247
- auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
248
- auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
249
  uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
250
  uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
251
 
@@ -262,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
262
  auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
263
  auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
264
  tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
265
- tcgen05_after_thread_sync();
266
 
267
  // UMMA and empty barrier arrival alias
268
  auto umma_arrive = [](const uint64_t* barrier) {
@@ -279,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout,
279
  // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
280
  if (do_tmem_full_arrive)
281
  umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
 
282
  };
283
 
 
 
 
 
 
 
284
  // Launch MMAs
285
- const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
286
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
287
  // Wait TMA arrival
288
  full_barriers[stage_idx]->wait(phase);
289
- tcgen05_after_thread_sync();
290
 
291
  // Issue UMMA in the leader CTA
292
- using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
293
- const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
294
- const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
295
- const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
296
  if (cute::elect_one_sync()) {
297
  #pragma unroll
298
  for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
299
  uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
300
- b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
301
- #pragma unroll
302
- for (uint32_t w = 0; w < kNumMWaves; ++ w) {
303
- DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
304
- a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
305
- mma_t::fma(a_desc, b_desc,
306
- accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
307
- k_block_idx > 0 or k > 0,
308
- runtime_instr_desc);
 
309
  }
310
  }
311
  }
 
312
 
313
  // Commit to the mbarrier object
314
  // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
@@ -319,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
319
  if constexpr (kTensorCoreUtilControl < 100) {
320
  // For utilization control
321
  umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
 
322
 
323
  // Wait for last UMMA to be done
324
  tensor_core_full_barrier->wait(tensor_core_phase);
325
  tensor_core_phase ^= 1;
326
 
327
  // Sleep for certain cycles
328
- constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull;
329
  constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
330
- const auto& start_clock = clock64();
331
  if (cute::elect_one_sync())
332
  while (clock64() - start_clock < kNumDummyCycles) {}
333
  __syncwarp();
@@ -336,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
336
  }
337
 
338
  // To safely deconstruct barriers, we need another round of waits
339
- const auto& iter_idx = scheduler.current_iter - 1;
340
  if (kNumMulticast > 1 and iter_idx >= 0) {
341
- const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
342
  tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
343
  }
344
  } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
@@ -348,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout,
348
  // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
349
  // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
350
  // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
351
- DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
352
-
353
- // TMA checks
354
- constexpr uint32_t kNumBankGroupBytes = 16;
355
- constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
356
- DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
357
- DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
358
 
359
  // Share store pipeline between blocks
360
  uint32_t tma_stage_idx = 0;
361
- auto advance_store_pipeline = [&]() {
362
- tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
363
- };
364
 
365
  // Persistently schedule over blocks
366
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
@@ -369,108 +385,47 @@ sm100_bf16_gemm_impl(int* grouped_layout,
369
 
370
  // Wait UMMA arrival
371
  tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
372
- tcgen05_after_thread_sync();
373
 
374
  // Load from tensor memory into registers, and write shared memory with STSM
375
- DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
376
- DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
377
-
378
- // Iterate over M waves
379
- #pragma unroll
380
- for (uint32_t w = 0; w < kNumMWaves; ++ w) {
381
- // Issue every swizzled atom and pipeline STSM and TMA store
382
- constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
383
- #pragma unroll
384
- for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
385
- // Wait shared memory to be released
386
- if (epilogue_warp_idx == 0)
387
- cute::tma_store_wait<kNumTMAStoreStages - 1>();
388
- cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
389
-
390
- // The pipeline stage
391
- const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
392
- const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
393
-
394
- // Store into shared memory
395
- #pragma unroll
396
- for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
397
- // Calculate the index of the bank group to be written in the atom
398
- auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
399
-
400
- // Reshape the atom in another view and swizzle
401
- // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
402
- // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
403
- // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
404
- constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
405
- auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
406
- auto col = kHasShortcut ? (i) : (bank_group_index % 8);
407
- col ^= row % (kSwizzleCDMode / 16);
408
-
409
- // Source and destination memory address
410
- uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
411
- w * BLOCK_N + // Wave offset
412
- s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
413
- auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
414
- epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
415
- row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
416
-
417
- // Load from tensor memory, store into shared memory
418
- uint32_t values[kNumElemsPerBankGroup];
419
- if constexpr (cute::is_same_v<cd_dtype_t, float>) {
420
- // For FP32 output, read and store
421
- DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
422
- cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
423
- values[0], values[1], values[2], values[3]);
424
- cutlass::arch::fence_view_async_tmem_load();
425
- st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
426
- } else {
427
- // For BF16 output, read, cast and store
428
- DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
429
- cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
430
- values[0], values[1], values[2], values[3],
431
- values[4], values[5], values[6], values[7]);
432
- cutlass::arch::fence_view_async_tmem_load();
433
- st_shared(smem_ptr,
434
- cast_into_bf16_and_pack(values[0], values[1]),
435
- cast_into_bf16_and_pack(values[2], values[3]),
436
- cast_into_bf16_and_pack(values[4], values[5]),
437
- cast_into_bf16_and_pack(values[6], values[7]));
438
- }
439
- }
440
-
441
- // Notify tensor memory empty (only at the leader CTA) arrival ASAP
442
- // NOTES: only the last stage needs to do this
443
- if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
444
- tcgen05_before_thread_sync();
445
- tmem_empty_barriers[accum_stage_idx]->arrive(0u);
446
- }
447
- __syncwarp();
448
-
449
- // Synchronize all threads and issue TMA
450
- cute::tma_store_fence();
451
- cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
452
- if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
453
- if constexpr (kGemmType == GemmType::Batched) {
454
- using cute_tma_t = cute::conditional_t<kWithAccumulation,
455
- cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
456
- cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
457
- n_idx, m_idx, scheduler.current_group_idx);
458
- } else {
459
- using cute_tma_t = cute::conditional_t<kWithAccumulation,
460
- cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
461
- cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
462
- }
463
- cute::tma_store_arrive();
464
- }
465
- }
466
  }
467
  }
468
-
469
- // Deallocate tensor memory by the last UMMA store warp
470
- // NOTES: warp 0 is waiting TMA store
471
- if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
472
- Allocator().free(0, kNumTmemCols);
473
  }
 
 
 
 
 
 
 
 
474
  #else
475
  if (blockIdx.x == 0 and threadIdx.x == 0)
476
  DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
 
4
 
5
  #include <cutlass/arch/barrier.h>
6
 
7
+ #include <deep_gemm/scheduler/gemm.cuh>
8
+ #include <deep_gemm/common/math.cuh>
9
+ #include <deep_gemm/common/tma_copy.cuh>
10
+ #include <deep_gemm/epilogue/sm100_store_cd.cuh>
11
+ #include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
12
+ #include <deep_gemm/epilogue/transform.cuh>
13
+ #include <deep_gemm/mma/sm100.cuh>
14
+ #include <deep_gemm/ptx/tcgen05.cuh>
15
+ #include <deep_gemm/ptx/utils.cuh>
16
 
17
  namespace deep_gemm {
18
 
 
 
19
  template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
20
  uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
21
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
 
25
  uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
26
  uint32_t kNumMulticast, bool kIsMulticastOnA,
27
  uint32_t kNumSMs,
28
+ bool kSwapAB,
29
  GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
30
  uint64_t kTensorCoreUtilControl>
31
+ CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
32
  sm100_bf16_gemm_impl(int* grouped_layout,
33
  uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
34
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
 
53
  if constexpr (kWithAccumulation)
54
  DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
55
 
56
+ // MMA Configs
57
  constexpr uint32_t LAYOUT_AD_M = 128;
58
+ constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
59
+ constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
60
+ constexpr uint32_t UMMA_K = 16;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
62
  constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
63
+ DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
 
 
 
 
64
  DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
65
+ DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
66
+ (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
67
+
68
+ // Epilogue configs
69
+ // Always enable pipeline for better performance
70
+ constexpr uint32_t kNumEpilogueStages = 2;
71
+ constexpr uint32_t kNumTMAStoreStages = 2;
72
+ // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
73
+ // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
74
+ constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
75
+ constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
76
+ constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
77
  DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
78
 
79
  // Shared memory sizes
80
+ constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
81
  constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
82
  constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
83
  constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
 
86
  DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
87
 
88
  // NOTES: Make sure we have enough shared memory for UMMA padding
89
+ static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
90
+ DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA");
 
 
 
 
91
 
92
  // Real tensor memory size and offsets
93
+ constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N;
94
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols>();
95
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
96
+
97
+ // Synchronize the cluster before 2-CTA TMEM allocation
98
+ kNumMulticast > 1 ? cute::cluster_sync() : void();
99
+
100
+ // Utils
101
+ bool is_leader_cta = cute::block_rank_in_cluster() == 0;
102
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
103
+ const auto lane_idx = ptx::get_lane_idx();
104
 
105
  // Prefetch TMA descriptors at the very beginning
106
+ if (warp_idx == 0) {
107
  cute::prefetch_tma_descriptor(&tensor_map_a);
108
  cute::prefetch_tma_descriptor(&tensor_map_b);
109
  cute::prefetch_tma_descriptor(&tensor_map_cd);
110
  }
111
 
112
+ // Override the runtime shapes with the compile-time constants when they are provided (non-zero)
113
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
114
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
115
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
116
+
117
+ // Align to 1024 bytes for swizzle-128B
118
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
119
+
120
  // D/A/B shared memory
121
+ auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
122
  return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
123
  });
124
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
125
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
126
  });
127
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
128
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
129
  });
130
 
131
  // Fill barriers
132
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
133
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
134
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
135
+ auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
136
+ auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
137
  auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
138
 
139
  // Fill the tensor memory pointer
 
167
  }
168
  kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
169
 
170
+ // Wait for primary kernel completion
171
+ cudaGridDependencySynchronize();
172
+
173
  // Block scheduler
174
  uint32_t m_block_idx, n_block_idx;
175
+ auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
176
+ shape_m, shape_n, shape_k, grouped_layout);
177
 
178
  // Pipeline and TMA phases
179
  uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
 
190
  // TMA load warp
191
  // Persistently schedule over blocks
192
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
193
+ // Use dynamic load block M, when swap-AB is enabled
194
+ const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
195
+
196
+ // For the k-grouped layout, the number of K blocks is variable
197
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
198
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
199
  // Wait consumer release
200
  empty_barriers[stage_idx]->wait(phase ^ 1);
201
 
202
  // Compute offsets
203
  // NOTES: the group is always concatenated with the outer dimension
204
+ uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
205
  shape_m, BLOCK_M, m_block_idx);
206
+ uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
207
  shape_n, BLOCK_N, n_block_idx, m_block_idx);
208
 
209
  // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
 
211
  DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
212
  kMajorA == cute::UMMA::Major::K, "Invalid major");
213
  uint32_t k_idx = k_block_idx * BLOCK_K;
214
+ uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
215
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
216
+ uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
217
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
218
 
219
  // Add 2 CTA offsets
220
  if constexpr (kNumMulticast > 1) {
221
+ m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
222
  n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
223
  }
224
 
 
226
  constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
227
  const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
228
  if constexpr (kMajorA == cute::UMMA::Major::K)
229
+ tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
230
  &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
231
  if constexpr (kMajorA == cute::UMMA::Major::MN)
232
+ tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
233
  &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
234
  if constexpr (kMajorB == cute::UMMA::Major::K)
235
+ tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
236
  &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
237
  if constexpr (kMajorB == cute::UMMA::Major::MN)
238
+ tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
239
  &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
240
 
241
  // Arrive at full barriers
 
251
  // MMA issue warp
252
  // NOTES: only the leader CTA will do this
253
  // Make instruction descriptor
254
+ auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
255
+ UMMA_M, UMMA_N, kMajorB, kMajorA>()
256
+ : cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
257
+ UMMA_M, UMMA_N, kMajorA, kMajorB>();
 
258
 
259
  DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
260
  // Merged stages only happens in NT normal GEMM cases
261
  constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
262
+ auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
263
+ auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
264
  uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
265
  uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
266
 
 
277
  auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
278
  auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
279
  tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
280
+ ptx::tcgen05_after_thread_sync();
281
 
282
  // UMMA and empty barrier arrival alias
283
  auto umma_arrive = [](const uint64_t* barrier) {
 
294
  // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
295
  if (do_tmem_full_arrive)
296
  umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
297
+ __syncwarp();
298
  };
299
 
300
+ // Dynamic update of UMMA N based on effective M, when swap-AB is enabled
301
+ if constexpr (kSwapAB) {
302
+ uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
303
+ mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
304
+ }
305
+
306
  // Launch MMAs
307
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
308
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
309
  // Wait TMA arrival
310
  full_barriers[stage_idx]->wait(phase);
311
+ ptx::tcgen05_after_thread_sync();
312
 
313
  // Issue UMMA in the leader CTA
314
+ using mma_t = cute::conditional_t<kNumMulticast == 1, ptx::SM100_MMA_F16BF16_SS, ptx::SM100_MMA_F16BF16_2x1SM_SS>;
315
+ const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
316
+ const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
317
+ const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
318
  if (cute::elect_one_sync()) {
319
  #pragma unroll
320
  for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
321
  uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
322
+ a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(
323
+ a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
324
+ b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(
325
+ b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
326
+ if (kSwapAB) {
327
+ mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
328
+ k_block_idx > 0 or k > 0, runtime_instr_desc);
329
+ } else {
330
+ mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
331
+ k_block_idx > 0 or k > 0, runtime_instr_desc);
332
  }
333
  }
334
  }
335
+ __syncwarp();
336
 
337
  // Commit to the mbarrier object
338
  // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
 
343
  if constexpr (kTensorCoreUtilControl < 100) {
344
  // For utilization control
345
  umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
346
+ __syncwarp();
347
 
348
  // Wait for last UMMA to be done
349
  tensor_core_full_barrier->wait(tensor_core_phase);
350
  tensor_core_phase ^= 1;
351
 
352
  // Sleep for certain cycles
353
+ constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull;
354
  constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
355
+ const auto start_clock = clock64();
356
  if (cute::elect_one_sync())
357
  while (clock64() - start_clock < kNumDummyCycles) {}
358
  __syncwarp();
 
361
  }
362
 
363
  // To safely deconstruct barriers, we need another round of waits
364
+ const auto iter_idx = scheduler.current_iter - 1;
365
  if (kNumMulticast > 1 and iter_idx >= 0) {
366
+ const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
367
  tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
368
  }
369
  } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
 
373
  // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
374
  // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
375
  // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
376
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
 
 
 
 
 
 
377
 
378
  // Share store pipeline between blocks
379
  uint32_t tma_stage_idx = 0;
 
 
 
380
 
381
  // Persistently schedule over blocks
382
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
 
385
 
386
  // Wait UMMA arrival
387
  tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
388
+ ptx::tcgen05_after_thread_sync();
389
 
390
  // Load from tensor memory into registers, and write shared memory with STSM
391
+ const auto tmem_base_addr = accum_stage_idx * UMMA_N;
392
+ const auto base_m_idx = scheduler.template get_global_idx<
393
+ (not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
394
+ const auto base_n_idx = n_block_idx * BLOCK_N;
395
+
396
+ if constexpr (kSwapAB) {
397
+ const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
398
+ epilogue::sm100_store_cd_swap_ab<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
399
+ kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
400
+ kGemmType, kWithAccumulation,
401
+ cd_dtype_t, epilogue::transform::EpilogueIdentity>
402
+ (smem_cd, tma_stage_idx, tmem_base_addr,
403
+ base_m_idx, base_n_idx, scheduler.current_group_idx,
404
+ effective_m,
405
+ epilogue_warp_idx, lane_idx,
406
+ tmem_empty_barriers[accum_stage_idx],
407
+ tensor_map_cd);
408
+ } else {
409
+ epilogue::sm100_store_cd<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
410
+ kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
411
+ kGemmType, kWithAccumulation,
412
+ cd_dtype_t, epilogue::transform::EpilogueIdentity>
413
+ (smem_cd, tma_stage_idx, tmem_base_addr,
414
+ base_m_idx, base_n_idx, scheduler.current_group_idx,
415
+ epilogue_warp_idx, lane_idx,
416
+ tmem_empty_barriers[accum_stage_idx],
417
+ tensor_map_cd);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  }
419
  }
 
 
 
 
 
420
  }
421
+
422
+ // TODO: Remove redundant synchronization
423
+ kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
424
+
425
+ // Deallocate tensor memory
426
+ if (warp_idx == 0)
427
+ Allocator().free(0, kNumTmemCols);
428
+
429
  #else
430
  if (blockIdx.x == 0 and threadIdx.x == 0)
431
  DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh CHANGED
@@ -5,18 +5,19 @@
5
  #include <cutlass/arch/barrier.h>
6
 
7
  #include <deep_gemm/common/utils.cuh>
8
- #include <deep_gemm/common/sm100_utils.cuh>
 
 
 
9
 
10
  namespace deep_gemm {
11
 
12
- using namespace deep_gemm::sm100;
13
-
14
  template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
15
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
16
  uint32_t kSplitFactor,
17
  uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
18
  uint32_t kNumStages, uint32_t kNumThreads>
19
- __global__ void __launch_bounds__(kNumThreads, 1)
20
  sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
21
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
22
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
30
 
31
  // Utils
32
  const auto warp_idx = cutlass::canonical_warp_idx_sync();
33
- const auto lane_idx = get_lane_idx();
34
  DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
35
  DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
36
 
@@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
51
  }
52
 
53
  // Real tensor memory size and offsets
54
- constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
55
 
56
  // Fill D/A/B
57
- auto smem_cd = PatternVisitor([&](const uint32_t& i) {
58
  return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
59
  });
60
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
61
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
62
  });
63
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
64
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
65
  });
66
 
67
  // Fill barriers
68
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
69
  kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
70
- auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
71
- auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
72
  auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
73
 
74
  // Fill the tensor memory pointer
@@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
93
  __syncthreads();
94
 
95
  // Block indices
96
- const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
97
- const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
98
  const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
99
  const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
100
  const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
101
  const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
102
  const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
103
 
 
 
 
104
  if (warp_idx == 0) {
105
  // TMA load warp
106
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
@@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
115
 
116
  // Issue TMAs
117
  if (cute::elect_one_sync()) {
118
- tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
119
- tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
120
  }
121
 
122
  // Arrive at full barriers
@@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
134
  auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
135
 
136
  DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
137
- auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
138
- auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
139
  uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
140
  uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
141
 
@@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
147
  "Invalid MMA instruction shape");
148
 
149
  // Wait tensor memory empty barrier arrival
150
- tcgen05_after_thread_sync();
151
 
152
  // Launch MMAs
153
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
154
  // Wait TMA arrival
155
  const auto& stage_idx = s % kNumStages;
156
  full_barriers[stage_idx]->wait((s / kNumStages) & 1);
157
- tcgen05_after_thread_sync();
158
 
159
  // Issue UMMA in the leader CTA
160
  const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
@@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
163
  if (cute::elect_one_sync()) {
164
  #pragma unroll
165
  for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
166
- a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
167
- b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
168
- SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
 
 
169
  }
170
  }
171
 
@@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
180
  // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
181
  // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
182
  if (warp_idx == 2)
183
- DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
184
 
185
  // TMA checks
186
  constexpr uint32_t kNumBankGroupBytes = 16;
@@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
191
 
192
  // Wait UMMA arrival
193
  tmem_full_barrier->wait(0);
194
- tcgen05_after_thread_sync();
195
 
196
  // Load from tensor memory into registers, and write shared memory with STSM
197
  DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
@@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
239
  cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
240
  values[0], values[1], values[2], values[3]);
241
  cutlass::arch::fence_view_async_tmem_load();
242
- st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
243
  }
244
 
245
  // Synchronize all threads and issue TMA
 
5
  #include <cutlass/arch/barrier.h>
6
 
7
  #include <deep_gemm/common/utils.cuh>
8
+ #include <deep_gemm/mma/sm100.cuh>
9
+ #include <deep_gemm/ptx/ld_st.cuh>
10
+ #include <deep_gemm/ptx/tcgen05.cuh>
11
+ #include <deep_gemm/ptx/utils.cuh>
12
 
13
  namespace deep_gemm {
14
 
 
 
15
  template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
16
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
17
  uint32_t kSplitFactor,
18
  uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
19
  uint32_t kNumStages, uint32_t kNumThreads>
20
+ CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1)
21
  sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
22
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
23
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
 
31
 
32
  // Utils
33
  const auto warp_idx = cutlass::canonical_warp_idx_sync();
34
+ const auto lane_idx = ptx::get_lane_idx();
35
  DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
36
  DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
37
 
 
52
  }
53
 
54
  // Real tensor memory size and offsets
55
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_N>();
56
 
57
  // Fill D/A/B
58
+ auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
59
  return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
60
  });
61
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
62
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
63
  });
64
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
65
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
66
  });
67
 
68
  // Fill barriers
69
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
70
  kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
71
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
72
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
73
  auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
74
 
75
  // Fill the tensor memory pointer
 
94
  __syncthreads();
95
 
96
  // Block indices
97
+ const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
98
+ const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
99
  const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
100
  const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
101
  const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
102
  const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
103
  const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
104
 
105
+ // Wait for primary kernel completion
106
+ cudaGridDependencySynchronize();
107
+
108
  if (warp_idx == 0) {
109
  // TMA load warp
110
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
 
119
 
120
  // Issue TMAs
121
  if (cute::elect_one_sync()) {
122
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
123
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
124
  }
125
 
126
  // Arrive at full barriers
 
138
  auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
139
 
140
  DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
141
+ auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
142
+ auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
143
  uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
144
  uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
145
 
 
151
  "Invalid MMA instruction shape");
152
 
153
  // Wait tensor memory empty barrier arrival
154
+ ptx::tcgen05_after_thread_sync();
155
 
156
  // Launch MMAs
157
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
158
  // Wait TMA arrival
159
  const auto& stage_idx = s % kNumStages;
160
  full_barriers[stage_idx]->wait((s / kNumStages) & 1);
161
+ ptx::tcgen05_after_thread_sync();
162
 
163
  // Issue UMMA in the leader CTA
164
  const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
 
167
  if (cute::elect_one_sync()) {
168
  #pragma unroll
169
  for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
170
+ a_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(
171
+ a_desc_base_lo, 0, k * UMMA_K);
172
+ b_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(
173
+ b_desc_base_lo, 0, k * UMMA_K);
174
+ ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
175
  }
176
  }
177
 
 
186
  // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
187
  // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
188
  if (warp_idx == 2)
189
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
190
 
191
  // TMA checks
192
  constexpr uint32_t kNumBankGroupBytes = 16;
 
197
 
198
  // Wait UMMA arrival
199
  tmem_full_barrier->wait(0);
200
+ ptx::tcgen05_after_thread_sync();
201
 
202
  // Load from tensor memory into registers, and write shared memory with STSM
203
  DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
 
245
  cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
246
  values[0], values[1], values[2], values[3]);
247
  cutlass::arch::fence_view_async_tmem_load();
248
+ ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
249
  }
250
 
251
  // Synchronize all threads and issue TMA
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cutlass/arch/reg_reconfig.h>
5
+
6
+ #include <cute/arch/cluster_sm90.hpp>
7
+ #include <cute/arch/copy_sm90_desc.hpp>
8
+
9
+ #include <deep_gemm/common/cute_tie.cuh>
10
+ #include <deep_gemm/common/utils.cuh>
11
+ #include <deep_gemm/mma/sm100.cuh>
12
+ #include <deep_gemm/ptx/ld_st.cuh>
13
+ #include <deep_gemm/ptx/tcgen05.cuh>
14
+ #include <deep_gemm/ptx/utils.cuh>
15
+
16
+ namespace deep_gemm {
17
+
18
+ template <uint32_t kNumHeads, uint32_t kHeadDim,
19
+ bool kIsCompressedLogits,
20
+ uint32_t BLOCK_Q, uint32_t BLOCK_KV,
21
+ uint32_t kNumQStages, uint32_t kNumKVStages,
22
+ uint32_t kNumSMs,
23
+ uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
24
+ typename logits_dtype_t,
25
+ uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
26
+ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
27
+ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
28
+ const uint32_t max_seqlen_k,
29
+ const uint32_t logits_stride,
30
+ const uint32_t* cu_seq_len_k_start,
31
+ const uint32_t* cu_seq_len_k_end,
32
+ logits_dtype_t* logits,
33
+ const __grid_constant__ cute::TmaDescriptor tensor_map_q,
34
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
35
+ const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
36
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
37
+ const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
38
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
39
+
40
+ // Utils
41
+ const auto sm_idx = blockIdx.x;
42
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
43
+ const auto warpgroup_idx = warp_idx / 4;
44
+ const auto lane_idx = ptx::get_lane_idx();
45
+ constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
46
+
47
+ // Prefetch TMA descriptors
48
+ if (warp_idx == kSpecWarpStart) {
49
+ cute::prefetch_tma_descriptor(&tensor_map_q);
50
+ cute::prefetch_tma_descriptor(&tensor_map_sf_q);
51
+ cute::prefetch_tma_descriptor(&tensor_map_weights);
52
+ cute::prefetch_tma_descriptor(&tensor_map_kv);
53
+ cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
54
+ }
55
+
56
+ // UMMA configs
57
+ static constexpr uint32_t kNumTmemStages = 3;
58
+ static constexpr uint32_t kNumUTCCPAlignedElems = 128;
59
+ static constexpr uint32_t UMMA_M = 128;
60
+ static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
61
+ static constexpr uint32_t UMMA_K = 64;
62
+ static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems);
63
+ static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems);
64
+ static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads;
65
+ DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
66
+ DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`");
67
+
68
+ // Shared memory configs
69
+ static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
70
+ static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2);
71
+ static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int);
72
+ static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2);
73
+ static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
74
+ static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
75
+
76
+ // Align to swizzling alignment bytes
77
+ extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
78
+ DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
79
+ DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
80
+
81
+ // Q and KV data on shared memory
82
+ auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
83
+ return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
84
+ });
85
+ auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
86
+ return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
87
+ });
88
+ const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
89
+ auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
90
+ return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
91
+ });
92
+ auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
93
+ return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
94
+ });
95
+ auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
96
+ return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
97
+ + SMEM_WEIGHT_SIZE_PER_STAGE * i);
98
+ });
99
+
100
+ // Barriers and TMEM pointer on shared memory
101
+ const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
102
+ auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
103
+ auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
104
+ auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
105
+ auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
106
+ const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
107
+ auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
108
+ auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
109
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
110
+
111
+ // Tensor memory configs
112
+ constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages;
113
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQ / 32 + kNumSFKV / 32>();
114
+ constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
115
+ constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32;
116
+ DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
117
+
118
+ // Initialize barriers
119
+ if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
120
+ #pragma unroll
121
+ for (uint32_t i = 0; i < kNumQStages; ++ i) {
122
+ full_q_barriers[i]->init(1);
123
+ empty_q_barriers[i]->init(kNumMathThreads + 32);
124
+ }
125
+ #pragma unroll
126
+ for (uint32_t i = 0; i < kNumKVStages; ++ i) {
127
+ full_kv_barriers[i]->init(1);
128
+ empty_kv_barriers[i]->init(1);
129
+ }
130
+ #pragma unroll
131
+ for (uint32_t i = 0; i < kNumTmemStages; ++i) {
132
+ full_tmem_barriers[i]->init(1);
133
+ empty_tmem_barriers[i]->init(128);
134
+ }
135
+ cutlass::arch::fence_barrier_init();
136
+ }
137
+
138
+ // Allocate tensor memory
139
+ if (warp_idx == kSpecWarpStart + 2)
140
+ cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
141
+ __syncthreads();
142
+
143
+ // Scheduler: `load_schedule(q_idx)` fills `seq_k_start`/`seq_k_end` for the `BLOCK_Q` rows of Q block `q_idx` and returns {first KV index to load, number of `BLOCK_KV`-sized KV blocks to visit}
144
+ const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
145
+ uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; // per-row KV ranges, overwritten by every `load_schedule` call
146
+ auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple<uint32_t, uint32_t> {
147
+ uint32_t start = cute::numeric_limits<uint32_t>::max(); // running minimum of the per-row starts
148
+ uint32_t end = cute::numeric_limits<uint32_t>::min(); // running maximum of the per-row ends
149
+ #pragma unroll
150
+ for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
151
+ const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1); // rows past the end of the last Q block reuse the last valid row
152
+ seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv); // clamp both bounds to `seq_len_kv`
153
+ seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv);
154
+ start = cute::min(start, seq_k_start[i]);
155
+ end = cute::max(end, seq_k_end[i]);
156
+ }
157
+ // TMA alignment requirements for SF KV: round the union start down to a multiple of 4
158
+ start = start / 4 * 4;
159
+ return {start, math::ceil_div(end - start, BLOCK_KV)}; // block count covers [aligned start, max end)
160
+ };
161
+
162
+ // Make Q, KV and TMEM pipeline: `make_pipeline(num_stages)` returns a stateful closure that owns its own iteration counter, starting at 0
163
+ auto make_pipeline = [](const uint32_t& num_stages) {
164
+ // Return current stage and phase, and advance pipeline by steps (the returned pair describes the position BEFORE advancing)
165
+ return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
166
+ uint32_t current_idx = iter_idx; // snapshot taken before the counter moves
167
+ iter_idx += step;
168
+ return {current_idx % num_stages, (current_idx / num_stages) & 1}; // {stage index, phase bit that flips on every full pass over the stages}
169
+ };
170
+ };
171
+ auto advance_q_pipeline = make_pipeline(kNumQStages);
172
+ auto advance_kv_pipeline = make_pipeline(kNumKVStages);
173
+ auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
174
+
175
+ // Register reconfigurations
176
+ constexpr uint32_t kNumSpecializedRegisters = 56;
177
+ constexpr uint32_t kNumMathRegisters = 224;
178
+
179
+ // Wait for primary kernel completion
180
+ cudaGridDependencySynchronize();
181
+
182
+ if (warp_idx == kSpecWarpStart) {
183
+ // TMA warp for loading Q
184
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
185
+
186
+ // Enumerate Q blocks
187
+ if (cute::elect_one_sync()) {
188
+ for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
189
+ // Wait Q consumer release
190
+ CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
191
+ empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
192
+
193
+ // Issue TMA Q
194
+ cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
195
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
196
+ smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads);
197
+ tma::copy<BLOCK_Q * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q);
198
+ tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q);
199
+ full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
200
+ }
201
+ }
202
+ __syncwarp();
203
+ } else if (warp_idx == kSpecWarpStart + 1) {
204
+ // TMA warp for loading KV cache
205
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
206
+
207
+ if (cute::elect_one_sync()) {
208
+ // Enumerate Q blocks
209
+ for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
210
+ // Load KV block ranges
211
+ CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
212
+
213
+ // Enumerate KV blocks
214
+ for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
215
+ // Wait KV consumer release
216
+ CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
217
+ empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
218
+
219
+ // Issue TMA KV
220
+ cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
221
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
222
+ smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV);
223
+ tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
224
+ smem_sf_kv[kv_stage_idx],
225
+ kv_start + kv_idx * BLOCK_KV, 0);
226
+ full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
227
+ }
228
+ }
229
+ }
230
+ } else if (warp_idx == kSpecWarpStart + 2) {
231
+ // UMMA warp
232
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
233
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
234
+
235
+ // UTCCP transposer: in-place transpose of the 128 32-bit words at `smem_ptr`, so that word (row r, column lane) of a 4x32 view ends up at (row lane, column r) of a 32x4 view
236
+ auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
237
+ DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
238
+ uint32_t values[4]; // the four words this lane moves, held in locals between the read and write passes
239
+ #pragma unroll
240
+ for (uint32_t i = 0; i < 4; ++ i)
241
+ values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); // read row `i ^ (lane_idx >> 3)`, column `lane_idx`; NOTE(review): the XOR row order is presumably there to avoid bank conflicts, confirm
242
+ __syncwarp(); // the transpose is in place, so the reads above must be done before the writes below begin
243
+ #pragma unroll
244
+ for (uint32_t i = 0; i < 4; ++ i)
245
+ ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); // write back to row `lane_idx`, column `i ^ (lane_idx >> 3)`
246
+ };
247
+
248
+ // Make UMMA desc
249
+ auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
250
+ UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
251
+ auto sf_desc = mma::sm100::make_sf_desc(nullptr);
252
+
253
+ // Enumerate Q blocks
254
+ for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
255
+ // Load KV block ranges
256
+ CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
257
+
258
+ // Wait TMA Q arrivals
259
+ CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
260
+ full_q_barriers[q_stage_idx]->wait(q_phase);
261
+
262
+ // Transpose and copy SF Q
263
+ #pragma unroll
264
+ for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) {
265
+ auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
266
+ utccp_required_smem_warp_transpose(smem_ptr);
267
+ cutlass::arch::fence_view_async_shared();
268
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
269
+ if (cute::elect_one_sync())
270
+ cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
271
+ __syncwarp();
272
+ }
273
+
274
+ // Enumerate KV blocks
275
+ for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
276
+ // Wait TMA KV arrivals
277
+ CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
278
+ full_kv_barriers[kv_stage_idx]->wait(kv_phase);
279
+
280
+ // Transpose
281
+ #pragma unroll
282
+ for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
283
+ auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
284
+ utccp_required_smem_warp_transpose(smem_ptr);
285
+ cutlass::arch::fence_view_async_shared();
286
+ }
287
+
288
+ // UMMA with SF
289
+ if (cute::elect_one_sync()) {
290
+ // Copy SF KV
291
+ #pragma unroll
292
+ for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
293
+ auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
294
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
295
+ cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
296
+ }
297
+
298
+ #pragma unroll
299
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
300
+ // Wait TMEM release
301
+ CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
302
+ uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
303
+
304
+ empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
305
+ ptx::tcgen05_after_thread_sync();
306
+
307
+ // Issue UMMA with SF
308
+ #pragma unroll
309
+ for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
310
+ auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
311
+ // TODO: generalize umma desc
312
+ DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
313
+ auto a_desc = mma::sm100::make_smem_desc(
314
+ cute::UMMA::LayoutType::SWIZZLE_64B,
315
+ smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
316
+ 8 * (kHeadDim / 2), 0);
317
+ auto b_desc = mma::sm100::make_smem_desc(
318
+ cute::UMMA::LayoutType::SWIZZLE_64B,
319
+ smem_q[q_stage_idx] + k * UMMA_K / 2,
320
+ 8 * (kHeadDim / 2), 0);
321
+ ptx::SM100_MMA_MXF4_SS::fma(
322
+ a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
323
+ kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
324
+ }
325
+ // TODO: move this into `deep_gemm/ptx/tcgen05.cuh`
326
+ asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
327
+ ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
328
+ }
329
+ }
330
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
331
+ }
332
+
333
+ // UMMA warp must also arrive on empty_q to prevent running ahead
334
+ // of math warps in the Q pipeline. Without this, UMMA can consume
335
+ // kNumQStages Q blocks before math warps release any, causing a
336
+ // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q
337
+ // -> Math waits full_tmem -> UMMA (already moved on).
338
+ empty_q_barriers[q_stage_idx]->arrive();
339
+ }
340
+ } else if (warp_idx == kSpecWarpStart + 3) {
341
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
342
+ } else if (warp_idx < kSpecWarpStart) {
343
+ // Math warpgroups for reduce
344
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
345
+
346
+ const auto math_warpgroup_idx = warpgroup_idx;
347
+ const auto math_thread_idx = threadIdx.x;
348
+
349
+ // Helper lambda for loading tensor memory: loads N 32-bit values from `tmem_addr` into `accum[0..N-1]`, where N is carried by the type of `num_elems_c`
350
+ auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
351
+ constexpr uint32_t N = decltype(num_elems_c)::value; // element count comes from the argument's type (callers pass `cute::Int<...>{}`)
352
+ DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
353
+ using Loader = cute::conditional_t<N == 32,
354
+ cute::SM100_TMEM_LOAD_32dp32b32x,
355
+ cute::SM100_TMEM_LOAD_32dp32b64x>; // N == 32 picks the 32x variant, N == 64 the 64x variant
356
+ [&]<size_t... Is>(cute::index_sequence<Is...>) {
357
+ Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...); // expands to N separate destination arguments, `accum` viewed as raw 32-bit words
358
+ }(cute::make_index_sequence<N>{});
359
+ cutlass::arch::fence_view_async_tmem_load(); // fence issued after the load and before the lambda returns
360
+ };
361
+
362
+ // Math warpgroups process TMEM stages alternately
363
+ // Advance pipeline to align with the assigned stage
364
+ advance_tmem_pipeline(math_warpgroup_idx);
365
+
366
+ // Local register buffers
367
+ float accum[kNumHeads];
368
+ float weights[BLOCK_Q][kNumHeads];
369
+
370
+ // Enumerate Q blocks
371
+ for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
372
+ // Load KV block ranges
373
+ CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
374
+
375
+ // Wait TMA Q arrivals
376
+ CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
377
+ full_q_barriers[q_stage_idx]->wait(q_phase);
378
+
379
+ // Read weights
380
+ // TODO: optimize bank conflicts
381
+ #pragma unroll
382
+ for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
383
+ #pragma unroll
384
+ for (uint32_t j = 0; j < kNumHeads; ++ j)
385
+ weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
386
+ }
387
+
388
+ // Enumerate KV blocks
389
+ for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
390
+ // Calculate KV offset in advance
391
+ auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx;
392
+
393
+ // Advance pipeline by `kNumMathWarpGroups` steps
394
+ // Wait UMMA arrival
395
+ CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
396
+ full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
397
+ ptx::tcgen05_after_thread_sync();
398
+
399
+ // Reduce over the head dim and store
400
+ #pragma unroll
401
+ for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
402
+ // Load accumulator from TMEM
403
+ uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
404
+ tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
405
+ tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
406
+
407
+ // Release TMEM empty
408
+ if (i == BLOCK_Q - 1) {
409
+ ptx::tcgen05_before_thread_sync();
410
+ empty_tmem_barriers[tmem_stage_idx]->arrive();
411
+ }
412
+
413
+ // Accumulate weighted ReLU in parallel
414
+ auto sum_0 = make_float2(0, 0);
415
+ auto sum_1 = make_float2(0, 0);
416
+
417
+ const auto transform = [&](const uint32_t& j, const float2& sum) {
418
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
419
+ auto b = make_float2(weights[i][j], weights[i][j + 1]);
420
+ return __ffma2_rn(a, b, sum);
421
+ };
422
+
423
+ #pragma unroll
424
+ for (uint32_t j = 0; j < kNumHeads; j += 4) {
425
+ sum_0 = transform(j, sum_0);
426
+ sum_1 = transform(j + 2, sum_1);
427
+ }
428
+
429
+ auto sum = __fadd2_rn(sum_0, sum_1);
430
+ auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
431
+
432
+ // Store into the global memory
433
+ // NOTES: we have redundant writes here, consider more carefully
434
+ // TODO: optimize performance
435
+ const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast<uint64_t>(logits_stride);
436
+ if constexpr (kIsCompressedLogits) {
437
+ if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
438
+ logits[q_offset + kv_offset - seq_k_start[i]] = result;
439
+ } else {
440
+ logits[q_offset + kv_offset] = result;
441
+ }
442
+ __syncwarp();
443
+ }
444
+ }
445
+
446
+ // Release last Q empty
447
+ empty_q_barriers[q_stage_idx]->arrive();
448
+ }
449
+
450
+ // Free tensor memory
451
+ cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
452
+ if (warp_idx == 0)
453
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
454
+ }
455
+ }
456
+
457
+ } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cutlass/arch/reg_reconfig.h>
5
+
6
+ #include <cute/arch/cluster_sm90.hpp>
7
+ #include <cute/arch/copy_sm90_desc.hpp>
8
+
9
+ #include <deep_gemm/common/cute_tie.cuh>
10
+ #include <deep_gemm/common/math.cuh>
11
+ #include <deep_gemm/common/tma_copy.cuh>
12
+ #include <deep_gemm/common/utils.cuh>
13
+ #include <deep_gemm/mma/sm100.cuh>
14
+ #include <deep_gemm/ptx/ld_st.cuh>
15
+ #include <deep_gemm/ptx/tcgen05.cuh>
16
+ #include <deep_gemm/ptx/utils.cuh>
17
+ #include <deep_gemm/scheduler/paged_mqa_logits.cuh>
18
+
19
+ namespace deep_gemm {
20
+
21
+ template <uint32_t kNextN, uint32_t kNumHeads,
22
+ uint32_t kHeadDim, uint32_t BLOCK_KV,
23
+ bool kIsContextLens2D, bool kIsVarlen,
24
+ uint32_t kNumQStages, uint32_t kNumKVStages,
25
+ uint32_t SPLIT_KV,
26
+ uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
27
+ typename logits_dtype_t,
28
+ uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
29
+ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
30
+ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
31
+ const uint32_t logits_stride, const uint32_t block_table_stride,
32
+ const uint32_t* context_lens, logits_dtype_t* logits,
33
+ const uint32_t* block_table, const uint32_t* indices,
34
+ const uint32_t* schedule_meta,
35
+ const __grid_constant__ cute::TmaDescriptor tensor_map_q,
36
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
37
+ const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
38
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
39
+ const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
40
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
41
+
42
+ // Utils
43
+ const auto sm_idx = blockIdx.x;
44
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
45
+ const auto warpgroup_idx = warp_idx / 4;
46
+ const auto lane_idx = ptx::get_lane_idx();
47
+ constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
48
+
49
+ // Prefetch TMA descriptors
50
+ if (warp_idx == kSpecWarpStart) {
51
+ cute::prefetch_tma_descriptor(&tensor_map_q);
52
+ cute::prefetch_tma_descriptor(&tensor_map_sf_q);
53
+ cute::prefetch_tma_descriptor(&tensor_map_weights);
54
+ cute::prefetch_tma_descriptor(&tensor_map_kv);
55
+ cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
56
+ }
57
+
58
+ // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
59
+ static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
60
+ static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
61
+ static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
62
+
63
+ // UMMA configs
64
+ static constexpr uint32_t kNumTmemStages = 3;
65
+ static constexpr uint32_t kNumUTCCPAlignedElems = 128;
66
+ static constexpr uint32_t UMMA_M = 128;
67
+ static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
68
+ static constexpr uint32_t UMMA_K = 64;
69
+ static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems);
70
+ static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems);
71
+ static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads;
72
+ DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
73
+ DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`");
74
+
75
+ // Shared memory configs
76
+ static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
77
+ static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2);
78
+ static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int);
79
+ static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2);
80
+ static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
81
+ static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
82
+
83
+ // Align to swizzling alignment bytes
84
+ extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
85
+ DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
86
+ DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
87
+
88
+ // Q and KV data on shared memory
89
+ auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
90
+ return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
91
+ });
92
+ auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
93
+ return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
94
+ });
95
+ const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
96
+ auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
97
+ return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
98
+ });
99
+ auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
100
+ return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
101
+ });
102
+ auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
103
+ return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
104
+ + SMEM_WEIGHT_SIZE_PER_STAGE * i);
105
+ });
106
+
107
+ // Barriers and TMEM pointer on shared memory
108
+ const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
109
+ auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
110
+ auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
111
+ auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
112
+ auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
113
+ const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
114
+ auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
115
+ auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
116
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
117
+
118
+ // Tensor memory configs
119
+ constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages;
120
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQAtom / 32 + kNumSFKV / 32>();
121
+ constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
122
+ constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32;
123
+ DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
124
+
125
+ // Initialize barriers
126
+ if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
127
+ #pragma unroll
128
+ for (uint32_t i = 0; i < kNumQStages; ++ i) {
129
+ full_q_barriers[i]->init(1);
130
+ empty_q_barriers[i]->init(kNumMathThreads + 32);
131
+ }
132
+ cutlass::arch::fence_barrier_init();
133
+ }
134
+ if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
135
+ #pragma unroll
136
+ for (uint32_t i = 0; i < kNumKVStages; ++ i) {
137
+ full_kv_barriers[i]->init(1);
138
+ empty_kv_barriers[i]->init(1);
139
+ }
140
+ cutlass::arch::fence_barrier_init();
141
+ }
142
+ if (warp_idx == kSpecWarpStart + 2) {
143
+ if (cute::elect_one_sync()) {
144
+ #pragma unroll
145
+ for (uint32_t i = 0; i < kNumTmemStages; ++i) {
146
+ full_tmem_barriers[i]->init(1);
147
+ empty_tmem_barriers[i]->init(128);
148
+ }
149
+ cutlass::arch::fence_barrier_init();
150
+ }
151
+ // Allocate tensor memory
152
+ cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
153
+ }
154
+ __syncthreads();
155
+
156
+ // Wait for primary kernel completion
157
+ cudaGridDependencySynchronize();
158
+
159
+ // Scheduler
160
+ constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
161
+ using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
162
+ DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
163
+
164
+ // Make Q, KV and TMEM pipeline
165
+ auto make_pipeline = [](const uint32_t& num_stages) {
166
+ // Return current stage and phase, and advance pipeline by steps
167
+ return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
168
+ uint32_t current_idx = iter_idx;
169
+ iter_idx += step;
170
+ return {current_idx % num_stages, (current_idx / num_stages) & 1};
171
+ };
172
+ };
173
+ auto advance_q_pipeline = make_pipeline(kNumQStages);
174
+ auto advance_kv_pipeline = make_pipeline(kNumKVStages);
175
+ auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
176
+
177
+ // Register reconfigurations
178
+ constexpr uint32_t kNumSpecializedRegisters = 56;
179
+ constexpr uint32_t kNumMathRegisters = 224;
180
+
181
+ if (warp_idx == kSpecWarpStart) {
182
+ // TMA warp for loading Q
183
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
184
+
185
+ if (cute::elect_one_sync()) {
186
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
187
+
188
+ // Persistently schedule over blocks
189
+ // Initialize outside valid range to indicate no previous task
190
+ uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
191
+ uint32_t q_atom_idx, _, __;
192
+ while (scheduler.fetch_next_task(q_atom_idx, _, __)) {
193
+ // Issue TMA Q when (q_idx, atom_idx) changes
194
+ if (q_atom_idx != last_q_atom_idx) {
195
+ // Wait Q consumer release
196
+ CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
197
+ empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
198
+
199
+ // Issue TMA Q
200
+ const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
201
+ cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
202
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
203
+ smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
204
+ tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
205
+ tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
206
+ full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
207
+ }
208
+ last_q_atom_idx = q_atom_idx;
209
+ }
210
+ }
211
+ __syncwarp();
212
+ } else if (warp_idx == kSpecWarpStart + 1) {
213
+ // TMA warp for loading KV cache
214
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
215
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
216
+
217
+ // Persistently schedule over blocks
218
+ uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
219
+ uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
220
+ uint32_t q_atom_idx, kv_idx, num_kv;
221
+ while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) {
222
+ // Reset block table cache on kv restart
223
+ if (q_atom_idx != last_q_atom_idx)
224
+ kv_block_idx_ptr = 32;
225
+ last_q_atom_idx = q_atom_idx;
226
+
227
+ // Coalesced load of block table
228
+ if (kv_block_idx_ptr == 32) {
229
+ kv_block_idx_ptr = 0;
230
+ const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
231
+ kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
232
+ ? block_table[block_table_offset + kv_idx + lane_idx] : 0;
233
+ }
234
+ __syncwarp();
235
+
236
+ // Broadcast KV block indices
237
+ int kv_block_idx[kNumBlocksPerSplit];
238
+ #pragma unroll
239
+ for (int i = 0; i < kNumBlocksPerSplit; ++ i)
240
+ kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
241
+ kv_block_idx_ptr += kNumBlocksPerSplit;
242
+ DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`");
243
+
244
+ // Wait KV consumer release
245
+ CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
246
+
247
+ // Issue TMA KV
248
+ if (cute::elect_one_sync()) {
249
+ empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
250
+ #pragma unroll
251
+ for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
252
+ cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
253
+ static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
254
+ smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i,
255
+ 0, 0, kv_block_idx[i]);
256
+ tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
257
+ smem_sf_kv[kv_stage_idx] + BLOCK_KV * i,
258
+ 0, kv_block_idx[i]);
259
+ }
260
+ full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
261
+ }
262
+ }
263
+ } else if (warp_idx == kSpecWarpStart + 2) {
264
+ // UMMA warp
265
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
266
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
267
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
268
+
269
+ // UTCCP transposer
270
+ auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
271
+ DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
272
+ uint32_t values[4];
273
+ #pragma unroll
274
+ for (uint32_t i = 0; i < 4; ++ i)
275
+ values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
276
+ __syncwarp();
277
+ #pragma unroll
278
+ for (uint32_t i = 0; i < 4; ++ i)
279
+ ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
280
+ };
281
+
282
+ // Make UMMA desc
283
+ auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
284
+ UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
285
+ auto sf_desc = mma::sm100::make_sf_desc(nullptr);
286
+
287
+ // Persistently schedule over blocks
288
+ uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
289
+ uint32_t q_atom_idx, kv_idx, _;
290
+ while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
291
+ // Wait TMA Q arrivals
292
+ uint32_t q_stage_idx, q_phase;
293
+ if (q_atom_idx != last_q_atom_idx) {
294
+ CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase);
295
+
296
+ // Release previous Q empty (UMMA warp must participate to prevent
297
+ // running ahead of math warps in the Q pipeline)
298
+ if (last_q_atom_idx != batch_size * kNumNextNAtoms)
299
+ empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
300
+
301
+ full_q_barriers[q_stage_idx]->wait(q_phase);
302
+
303
+ // Transpose and copy SF Q
304
+ #pragma unroll
305
+ for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) {
306
+ auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
307
+ utccp_required_smem_warp_transpose(smem_ptr);
308
+ cutlass::arch::fence_view_async_shared();
309
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
310
+ if (cute::elect_one_sync())
311
+ cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
312
+ __syncwarp();
313
+ }
314
+ }
315
+ last_q_atom_idx = q_atom_idx;
316
+
317
+ // Wait TMA KV arrivals
318
+ CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
319
+ full_kv_barriers[kv_stage_idx]->wait(kv_phase);
320
+
321
+ // Transpose
322
+ #pragma unroll
323
+ for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
324
+ auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
325
+ utccp_required_smem_warp_transpose(smem_ptr);
326
+ cutlass::arch::fence_view_async_shared();
327
+ }
328
+
329
+ // UMMA with SF
330
+ if (cute::elect_one_sync()) {
331
+ // Copy SF KV
332
+ #pragma unroll
333
+ for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
334
+ auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
335
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
336
+ cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
337
+ }
338
+
339
+ #pragma unroll
340
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
341
+ // Wait TMEM release
342
+ CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
343
+ uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
344
+
345
+ empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
346
+ ptx::tcgen05_after_thread_sync();
347
+
348
+ // Issue UMMA with SF
349
+ #pragma unroll
350
+ for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
351
+ auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
352
+ // TODO: generalize UMMA desc
353
+ DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
354
+ auto a_desc = mma::sm100::make_smem_desc(
355
+ cute::UMMA::LayoutType::SWIZZLE_64B,
356
+ smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
357
+ 8 * (kHeadDim / 2), 0);
358
+ auto b_desc = mma::sm100::make_smem_desc(
359
+ cute::UMMA::LayoutType::SWIZZLE_64B,
360
+ smem_q[q_stage_idx] + k * UMMA_K / 2,
361
+ 8 * (kHeadDim / 2), 0);
362
+ ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
363
+ kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
364
+ }
365
+ // TODO: move this PTX into headers
366
+ asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
367
+ ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
368
+ }
369
+ }
370
+ cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
371
+ }
372
+ } else if (warp_idx == kSpecWarpStart + 3) {
373
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
374
+ } else if (warp_idx < kSpecWarpStart) {
375
+ // Math warpgroups for reduce
376
+ cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
377
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
378
+
379
+ const auto math_warpgroup_idx = warpgroup_idx;
380
+ const auto math_thread_idx = warp_idx * 32 + lane_idx;
381
+
382
+ // Helper lambda for loading tensor memory
383
+ auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
384
+ constexpr int N = decltype(num_elems_c)::value;
385
+ DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
386
+ using Loader = cute::conditional_t<N == 32,
387
+ cute::SM100_TMEM_LOAD_32dp32b32x,
388
+ cute::SM100_TMEM_LOAD_32dp32b64x>;
389
+ [&]<size_t... Is>(cute::index_sequence<Is...>) {
390
+ Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
391
+ }(cute::make_index_sequence<N>{});
392
+ cutlass::arch::fence_view_async_tmem_load();
393
+ };
394
+
395
+ // Math warpgroups process TMEM stages alternately
396
+ // Advance pipeline to align with the assigned stage
397
+ advance_tmem_pipeline(math_warpgroup_idx);
398
+
399
+ // Local register buffers
400
+ float accum[kNumHeads];
401
+ float weights[kNextNAtom][kNumHeads];
402
+
403
+ // Persistently schedule over blocks
404
+ uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
405
+ uint32_t q_atom_idx, kv_idx, _;
406
+ bool is_paired_atom = false;
407
+ while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
408
+ if (q_atom_idx != last_q_atom_idx) {
409
+ CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
410
+
411
+ // Release last Q empty
412
+ if (last_q_atom_idx != batch_size * kNumNextNAtoms)
413
+ empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
414
+
415
+ // Wait TMA Q arrivals
416
+ full_q_barriers[q_stage_idx]->wait(q_phase);
417
+
418
+ // Read weights
419
+ #pragma unroll
420
+ for (uint32_t i = 0; i < kNextNAtom; ++ i) {
421
+ #pragma unroll
422
+ for (uint32_t j = 0; j < kNumHeads; j += 4) {
423
+ float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j));
424
+ weights[i][j + 0] = raw.x;
425
+ weights[i][j + 1] = raw.y;
426
+ weights[i][j + 2] = raw.z;
427
+ weights[i][j + 3] = raw.w;
428
+ }
429
+ }
430
+
431
+ // Check if this atom pairs two tokens from the same sequence
432
+ if constexpr (kIsVarlen) {
433
+ is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
434
+ }
435
+ }
436
+ last_q_atom_idx = q_atom_idx;
437
+
438
+ // Calculate KV offset in advance
439
+ auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
440
+
441
+ // Advance pipeline by `kNumMathWarpGroups` steps
442
+ // Wait UMMA arrival
443
+ CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
444
+ full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
445
+ ptx::tcgen05_after_thread_sync();
446
+
447
+ // Reduce over the head dim and store
448
+ const auto reduce_and_store = [&](auto num_iters_c) {
449
+ constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
450
+
451
+ // Only loop over valid iterations
452
+ #pragma unroll
453
+ for (uint32_t i = 0; i < kNumIters; ++ i) {
454
+ // Load accumulator from TMEM
455
+ uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
456
+ tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
457
+ tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
458
+
459
+ // Accumulate weighted ReLU in parallel
460
+ auto sum_0 = make_float2(0, 0);
461
+ auto sum_1 = make_float2(0, 0);
462
+
463
+ const auto transform = [&](const uint32_t& j, const float2& sum) {
464
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
465
+ auto b = make_float2(weights[i][j], weights[i][j + 1]);
466
+ return __ffma2_rn(a, b, sum);
467
+ };
468
+
469
+ #pragma unroll
470
+ for (uint32_t j = 0; j < kNumHeads; j += 4) {
471
+ sum_0 = transform(j, sum_0);
472
+ sum_1 = transform(j + 2, sum_1);
473
+ }
474
+
475
+ auto sum = __fadd2_rn(sum_0, sum_1);
476
+ auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
477
+
478
+ // Store into the global memory
479
+ logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
480
+ __syncwarp();
481
+ }
482
+
483
+ // Release TMEM empty
484
+ ptx::tcgen05_before_thread_sync();
485
+ empty_tmem_barriers[tmem_stage_idx]->arrive();
486
+ };
487
+
488
+ if constexpr (kIsVarlen) {
489
+ if (is_paired_atom)
490
+ reduce_and_store(cute::Int<kNextNAtom>{});
491
+ else
492
+ reduce_and_store(cute::Int<1>{});
493
+ } else if constexpr (kPadOddN) {
494
+ if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
495
+ reduce_and_store(cute::Int<1>{});
496
+ else
497
+ reduce_and_store(cute::Int<kNextNAtom>{});
498
+ } else {
499
+ reduce_and_store(cute::Int<kNextNAtom>{});
500
+ }
501
+ }
502
+
503
+ // Free tensor memory
504
+ cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
505
+ if (warp_idx == 0)
506
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
507
+ }
508
+ }
509
+
510
+ } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #pragma clang diagnostic push
3
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
4
+
5
+ #include <cutlass/arch/barrier.h>
6
+
7
+ #include <deep_gemm/common/math.cuh>
8
+ #include <deep_gemm/common/tma_copy.cuh>
9
+ #include <deep_gemm/epilogue/transform.cuh>
10
+ #include <deep_gemm/epilogue/sm100_store_cd.cuh>
11
+ #include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
12
+ #include <deep_gemm/mma/sm100.cuh>
13
+ #include <deep_gemm/scheduler/gemm.cuh>
14
+ #include <deep_gemm/ptx/utils.cuh>
15
+
16
+ namespace deep_gemm {
17
+
18
+ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
19
+ uint32_t kGranKA, uint32_t kGranKB,
20
+ uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
21
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
22
+ uint32_t kNumGroups,
23
+ uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
24
+ uint32_t kNumStages,
25
+ uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
26
+ uint32_t kNumMulticast, bool kIsMulticastOnA,
27
+ uint32_t kNumSMs,
28
+ bool kSwapAB,
29
+ GemmType kGemmType, bool kWithAccumulation,
30
+ typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
31
+ typename epilogue_type_t>
32
+ CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
33
+ sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
34
+ uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
35
+ const __grid_constant__ cute::TmaDescriptor tensor_map_a,
36
+ const __grid_constant__ cute::TmaDescriptor tensor_map_b,
37
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
38
+ const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
39
+ const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
40
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
41
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
42
+ using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
43
+
44
+ // GEMM with accumulation must have FP32 output
45
+ if constexpr (kWithAccumulation)
46
+ DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
47
+
48
+ // MMA Configs
49
+ constexpr uint32_t LAYOUT_AD_M = 128;
50
+ constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
51
+ constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
52
+ constexpr uint32_t UMMA_K = 32;
53
+ constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
54
+ constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
55
+ DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
56
+ DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
57
+ DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
58
+ (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
59
+
60
+ // SF configs
61
+ constexpr uint32_t kNumUTCCPAlignedElems = 128;
62
+ constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
63
+ constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
64
+ constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
65
+ constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
66
+ DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
67
+ DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
68
+ DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB");
69
+
70
+ // Epilogue configs
71
+ // Always enable pipeline for better performance
72
+ constexpr uint32_t kNumEpilogueStages = 2;
73
+ constexpr uint32_t kNumTMAStoreStages = 2;
74
+ // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
75
+ // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
76
+ constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
77
+ constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
78
+ constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
79
+ DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
80
+
81
+ // Share memory sizes
82
+ constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
83
+ constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
84
+ constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
85
+ constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
86
+ constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
87
+ constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
88
+ DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
89
+ "Shared memory of A/B must be aligned to 1024 bytes");
90
+ // NOTES: Make sure we have enough shared memory for UMMA padding
91
+ constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
92
+ DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
93
+
94
+ // Tensor memory size and offsets
95
+ constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
96
+ constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
97
+ constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
98
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
99
+ constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
100
+ constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
101
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
102
+
103
+ // Synchronize the cluster before 2-CTA TMEM allocation
104
+ kNumMulticast > 1 ? cute::cluster_sync() : void();
105
+
106
+ // Utils
107
+ const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
108
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
109
+ const auto lane_idx = ptx::get_lane_idx();
110
+
111
+ // Prefetch TMA descriptors at the very beginning
112
+ if (warp_idx == 0) {
113
+ cute::prefetch_tma_descriptor(&tensor_map_a);
114
+ cute::prefetch_tma_descriptor(&tensor_map_b);
115
+ cute::prefetch_tma_descriptor(&tensor_map_sfa);
116
+ cute::prefetch_tma_descriptor(&tensor_map_sfb);
117
+ cute::prefetch_tma_descriptor(&tensor_map_cd);
118
+ }
119
+
120
+ // Overwrite the runtime shapes with the shape constants when the compiler provides them (i.e., they are non-zero)
121
+ shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
122
+ shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
123
+ shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
124
+ const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4);
125
+ const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4);
126
+
127
+ // Align to 1024 bytes for swizzle-128B
128
+ extern __shared__ __align__(1024) uint8_t smem_buffer[];
129
+
130
+ // D/A/B shared memory
131
+ auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
132
+ return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
133
+ });
134
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
135
+ return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
136
+ });
137
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
138
+ return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
139
+ });
140
+
141
+ // SFA/SFB shared memory
142
+ auto sf_start_ptr = reinterpret_cast<uint8_t*>(smem_b[kNumStages]);
143
+ auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
144
+ return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
145
+ });
146
+ auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
147
+ return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
148
+ });
149
+
150
+ // Barriers and tensor memory pointer
151
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_sfb[kNumStages]);;
152
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
153
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
154
+ auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
155
+ auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
156
+ auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
157
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
158
+
159
+ // Initialize barriers
160
+ if (warp_idx == 1 and cute::elect_one_sync()) {
161
+ #pragma unroll
162
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
163
+ // Arrive at all CTAs
164
+ full_barriers[i]->init(1);
165
+ empty_barriers[i]->init(1);
166
+ // Arrive only at the leader CTA
167
+ with_sf_full_barriers[i]->init(kNumMulticast * 32);
168
+ }
169
+ #pragma unroll
170
+ for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
171
+ // Arrive at all CTAs
172
+ tmem_full_barriers[i]->init(1);
173
+ // Arrive only at the leader CTA
174
+ tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
175
+ }
176
+
177
+ // Make initialized barrier visible in async proxy
178
+ cutlass::arch::fence_barrier_init();
179
+ } else if (warp_idx == 2) {
180
+ // Allocate tensor memory
181
+ Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
182
+ }
183
+ kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
184
+
185
+ // Wait for primary kernel completion
186
+ cudaGridDependencySynchronize();
187
+
188
+ // Block scheduler
189
+ uint32_t m_block_idx, n_block_idx;
190
+ auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs, kGranKA * 4>(
191
+ shape_m, shape_n, shape_k, grouped_layout);
192
+
193
+ // Pipeline and TMA phases
194
+ uint32_t stage_idx = 0, phase = 0;
195
+ auto advance_pipeline = [&](uint32_t& k_block_idx) {
196
+ ++ k_block_idx;
197
+
198
+ // Flip the phase only when the stage index wraps around to the first stage
199
+ stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
200
+ phase ^= stage_idx == 0;
201
+ };
202
+
203
+ // Dispatch warps into different roles
204
+ if (warp_idx == 0 and cute::elect_one_sync()) {
205
+ // TMA load warp
206
+ // Persistently schedule over blocks
207
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
208
+ // Use dynamic load block M, when swap-AB is enabled
209
+ const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
210
+
211
+ // For k-grouped layout, the number of block K is variable
212
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
213
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
214
+ // Wait consumer release
215
+ empty_barriers[stage_idx]->wait(phase ^ 1);
216
+
217
+ // Compute offsets
218
+ // NOTES: the group is always concatenated with the outer dimension
219
+ uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
220
+ shape_m, BLOCK_M, m_block_idx);
221
+ uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
222
+ shape_n, BLOCK_N, n_block_idx, m_block_idx);
223
+
224
+ // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
225
+ // And for all m-grouped GEMMs, A must be K-major
226
+ DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
227
+ kMajorA == cute::UMMA::Major::K, "Invalid major");
228
+ uint32_t k_idx = k_block_idx * BLOCK_K;
229
+ uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
230
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
231
+ uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
232
+ shape_k, BLOCK_K, k_block_idx, m_block_idx);
233
+
234
+ // Add 2 CTA offsets
235
+ if constexpr (kNumMulticast > 1) {
236
+ m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
237
+ n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
238
+ }
239
+
240
+ // Issue TMAs
241
+ constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
242
+ const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
243
+ if constexpr (kMajorA == cute::UMMA::Major::K)
244
+ tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
245
+ &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
246
+ if constexpr (kMajorA == cute::UMMA::Major::MN)
247
+ tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
248
+ &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
249
+ if constexpr (kMajorB == cute::UMMA::Major::K)
250
+ tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
251
+ &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
252
+ if constexpr (kMajorB == cute::UMMA::Major::MN)
253
+ tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
254
+ &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
255
+ auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
256
+ SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
257
+
258
+ // Issue SFA and SFB TMAs at certain stages
259
+ // No swizzling, so one TMA for one SF is enough
260
+ if (k_block_idx % kNumSFAStagesPerLoad == 0) {
261
+ uint32_t sfa_m_idx = m_block_idx * BLOCK_M;
262
+ uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>(
263
+ shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad));
264
+ tma::copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
265
+ num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
266
+ }
267
+ if (k_block_idx % kNumSFBStagesPerLoad == 0) {
268
+ uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
269
+ uint32_t sfb_k_idx = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(
270
+ shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx);
271
+ tma::copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
272
+ num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
273
+ }
274
+
275
+ // Arrive at full barriers
276
+ full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
277
+ }
278
+ }
279
+ } else if (warp_idx == 1 and is_leader_cta) {
280
+ // MMA issue warp
281
+ // NOTES: only the leader CTA will do this
282
+ // Make instruction descriptor
283
+ auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled<b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
284
+ UMMA_M, UMMA_N, kMajorB, kMajorA>()
285
+ : cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
286
+ UMMA_M, UMMA_N, kMajorA, kMajorB>();
287
+ auto sf_desc = mma::sm100::make_sf_desc(nullptr);
288
+
289
+ DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
290
+ auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
291
+ auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
292
+ uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
293
+ uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
294
+
295
+ // Checks for MMA instructions
296
+ // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
297
+ DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
298
+ (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
299
+ (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
300
+ "Invalid MMA instruction shape");
301
+
302
+ // Persistently schedule over blocks
303
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
304
+ // Wait tensor memory empty barrier arrival
305
+ auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
306
+ auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
307
+ tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
308
+ ptx::tcgen05_after_thread_sync();
309
+
310
+ // Empty barrier arrival
311
+ auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
312
+ auto umma_arrive = [](const uint64_t* barrier) {
313
+ if constexpr (kNumMulticast == 1) {
314
+ cutlass::arch::umma_arrive(barrier);
315
+ } else {
316
+ constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
317
+ cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
318
+ }
319
+ };
320
+ umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
321
+
322
+ // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
323
+ if (do_tmem_full_arrive)
324
+ umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
325
+ __syncwarp();
326
+ };
327
+
328
+ // Dynamic update of UMMA N based on effective M, when swap-AB is enabled
329
+ if constexpr (kSwapAB) {
330
+ uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
331
+ mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
332
+ }
333
+
334
+ // Launch MMAs
335
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
336
+ #pragma unroll 4
337
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
338
+ // Wait TMA and SF-transpose arrival
339
+ with_sf_full_barriers[stage_idx]->wait(phase);
340
+ ptx::tcgen05_after_thread_sync();
341
+
342
+ const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx);
343
+ const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
344
+ if (cute::elect_one_sync()) {
345
+ // Do SF copy at certain stages
346
+ // TODO: process shared memory descriptor by addition
347
+ using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
348
+ cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
349
+ const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
350
+ if (sfa_stage_in_group_idx == 0) {
351
+ #pragma unroll
352
+ for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
353
+ auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
354
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
355
+ cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
356
+ }
357
+ }
358
+ const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
359
+ if (sfb_stage_in_group_idx == 0) {
360
+ #pragma unroll
361
+ for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
362
+ auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
363
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
364
+ cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
365
+ }
366
+ }
367
+
368
+ // Issue UMMA
369
+ using mma_t = cute::conditional_t<
370
+ kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>;
371
+ #pragma unroll
372
+ for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
373
+ const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
374
+ const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
375
+ const auto runtime_instr_desc = kSwapAB ?
376
+ mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id):
377
+ mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
378
+
379
+ a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
380
+ b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
381
+ if constexpr (kSwapAB) {
382
+ mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
383
+ k_block_idx > 0 or k > 0, runtime_instr_desc,
384
+ kTmemStartColOfSFB, kTmemStartColOfSFA);
385
+ } else {
386
+ mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
387
+ k_block_idx > 0 or k > 0, runtime_instr_desc,
388
+ kTmemStartColOfSFA, kTmemStartColOfSFB);
389
+ }
390
+ }
391
+ }
392
+ __syncwarp();
393
+
394
+ // Commit to the mbarrier object
395
+ // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
396
+ empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
397
+ }
398
+ }
399
+
400
+ // To safely deconstruct barriers, we need another round of waits
401
+ const auto iter_idx = scheduler.current_iter - 1;
402
+ if (kNumMulticast > 1 and iter_idx >= 0) {
403
+ const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
404
+ tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
405
+ }
406
+ } else if (warp_idx == 2) {
407
+ // UTCCP transposer
408
+ auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
409
+ DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
410
+ uint32_t values[4];
411
+ #pragma unroll
412
+ for (uint32_t i = 0; i < 4; ++ i)
413
+ values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
414
+ __syncwarp();
415
+ #pragma unroll
416
+ for (uint32_t i = 0; i < 4; ++ i)
417
+ ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
418
+ };
419
+
420
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
421
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
422
+ for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
423
+ // Wait TMA arrival
424
+ full_barriers[stage_idx]->wait(phase);
425
+
426
+ // Transpose for UTCCP at certain stages
427
+ if (k_block_idx % kNumSFAStagesPerLoad == 0) {
428
+ #pragma unroll
429
+ for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
430
+ utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
431
+ // TODO: figure out whether the proxy fence is valid for 2-CTA cases
432
+ cutlass::arch::fence_view_async_shared();
433
+ }
434
+ if (k_block_idx % kNumSFBStagesPerLoad == 0) {
435
+ #pragma unroll
436
+ for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
437
+ utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
438
+ // TODO: figure out whether the proxy fence is valid for 2-CTA cases
439
+ cutlass::arch::fence_view_async_shared();
440
+ }
441
+
442
+ // Arrive
443
+ with_sf_full_barriers[stage_idx]->arrive(0u);
444
+ }
445
+ }
446
+ } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
447
+ // Epilogue warp groups
448
+ const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
449
+
450
+ // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
451
+ // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
452
+ // NOTES: we also forbid two CTAs from sharing the same SM and its tensor memory
453
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
454
+
455
+ // Share store pipeline between blocks
456
+ uint32_t tma_stage_idx = 0;
457
+
458
+ // Persistently schedule over blocks
459
+ while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
460
+ auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
461
+ auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
462
+
463
+ // Wait UMMA arrival
464
+ tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
465
+ ptx::tcgen05_after_thread_sync();
466
+
467
+ const auto tmem_base_addr = accum_stage_idx * UMMA_N;
468
+ const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
469
+ const auto base_n_idx = n_block_idx * BLOCK_N;
470
+
471
+ if constexpr (kSwapAB) {
472
+ const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
473
+ epilogue::sm100_store_cd_swap_ab<
474
+ BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
475
+ kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
476
+ kGemmType, kWithAccumulation,
477
+ cd_dtype_t, epilogue_type_t>
478
+ (smem_cd, tma_stage_idx, tmem_base_addr,
479
+ base_m_idx, base_n_idx, scheduler.current_group_idx,
480
+ effective_m,
481
+ epilogue_warp_idx, lane_idx,
482
+ tmem_empty_barriers[accum_stage_idx],
483
+ tensor_map_cd);
484
+ } else {
485
+ epilogue::sm100_store_cd<
486
+ BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
487
+ kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
488
+ kGemmType, kWithAccumulation,
489
+ cd_dtype_t, epilogue_type_t>
490
+ (smem_cd, tma_stage_idx, tmem_base_addr,
491
+ base_m_idx, base_n_idx, scheduler.current_group_idx,
492
+ epilogue_warp_idx, lane_idx,
493
+ tmem_empty_barriers[accum_stage_idx],
494
+ tensor_map_cd);
495
+ }
496
+ }
497
+ }
498
+
499
+ // TODO: Remove redundant synchronization
500
+ kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
501
+
502
+ // Deallocate tensor memory
503
+ if (warp_idx == 0)
504
+ Allocator().free(0, kNumTmemCols);
505
+
506
+ #else
507
+ if (blockIdx.x == 0 and threadIdx.x == 0)
508
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
509
+ #endif
510
+ }
511
+
512
+ }; // namespace deep_gemm
513
+
514
+ #pragma clang diagnostic pop
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh ADDED
@@ -0,0 +1,1380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cstdint>
4
+ #include <cutlass/arch/barrier.h>
5
+ #include <cutlass/arch/reg_reconfig.h>
6
+
7
+ #include <deep_gemm/common/math.cuh>
8
+ #include <deep_gemm/common/tma_copy.cuh>
9
+ #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/comm/barrier.cuh>
11
+ #include <deep_gemm/layout/sym_buffer.cuh>
12
+ #include <deep_gemm/layout/mega_moe.cuh>
13
+ #include <deep_gemm/mma/sm100.cuh>
14
+ #include <deep_gemm/scheduler/mega_moe.cuh>
15
+ #include <deep_gemm/ptx/tcgen05.cuh>
16
+ #include <deep_gemm/ptx/tma.cuh>
17
+ #include <deep_gemm/ptx/utils.cuh>
18
+
19
+ namespace deep_gemm {
20
+
21
+ template <
22
+ uint32_t kNumMaxTokensPerRank,
23
+ uint32_t kHidden, uint32_t kIntermediateHidden,
24
+ uint32_t kNumExperts, uint32_t kNumTopk,
25
+ uint32_t kNumExpertsPerWave,
26
+ uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
27
+ uint32_t STORE_BLOCK_M,
28
+ uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N,
29
+ uint32_t kNumMaxPoolTokens,
30
+ uint32_t kNumPaddedSFPoolTokens,
31
+ uint32_t kNumStages,
32
+ uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads,
33
+ uint32_t kNumEpilogueThreads,
34
+ uint32_t kNumSMs, uint32_t kNumRanks,
35
+ float kActivationClamp,
36
+ bool kFastMath,
37
+ uint32_t L1_SHAPE_N = kIntermediateHidden * 2,
38
+ uint32_t L1_SHAPE_K = kHidden,
39
+ uint32_t L2_SHAPE_N = kHidden,
40
+ uint32_t L2_SHAPE_K = kIntermediateHidden,
41
+ uint32_t kNumDispatchWarps = kNumDispatchThreads / 32,
42
+ uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32,
43
+ uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32,
44
+ uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4,
45
+ uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads,
46
+ uint32_t kNumTokensPerWarp = 32 / kNumTopk,
47
+ uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks
48
+ >
49
+ CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
50
+ sm100_fp8_fp4_mega_moe_impl(void* y,
51
+ int* cumulative_local_expert_recv_stats,
52
+ const uint32_t num_tokens,
53
+ const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
54
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
55
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf,
56
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights,
57
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf,
58
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output,
59
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts,
60
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf,
61
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights,
62
+ const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) {
63
+ #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
64
+ using Barrier = cutlass::arch::ClusterTransactionBarrier;
65
+ using Allocator = cute::TMEM::Allocator2Sm;
66
+
67
+ // Template checks
68
+ DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads");
69
+ DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads");
70
+ DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads");
71
+ DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks");
72
+
73
+ // Thread indices
74
+ const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
75
+ const uint32_t sm_idx = blockIdx.x;
76
+ const uint32_t thread_idx = threadIdx.x;
77
+ const uint32_t warp_idx = cutlass::canonical_warp_idx_sync();
78
+ const uint32_t lane_idx = ptx::get_lane_idx();
79
+
80
+ // Prefetch TMA descriptors at the very beginning
81
+ if (warp_idx == 0) {
82
+ cute::prefetch_tma_descriptor(&tensor_map_l1_acts);
83
+ cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf);
84
+ cute::prefetch_tma_descriptor(&tensor_map_l1_weights);
85
+ cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf);
86
+ cute::prefetch_tma_descriptor(&tensor_map_l1_output);
87
+ cute::prefetch_tma_descriptor(&tensor_map_l2_acts);
88
+ cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf);
89
+ cute::prefetch_tma_descriptor(&tensor_map_l2_weights);
90
+ cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf);
91
+ }
92
+
93
+ // Workspaces
94
+ const auto workspace = layout::Workspace(
95
+ sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
96
+
97
+ // Token and buffer layouts
98
+ constexpr auto fp8_token_layout = layout::Data(kHidden);
99
+ constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
100
+ constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden);
101
+ constexpr auto fp8_sf_layout = layout::Data(kHidden / 32);
102
+ constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32);
103
+ constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false);
104
+ constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false);
105
+ constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
106
+
107
+ // Registered inputs
108
+ const auto input_token_buffer = layout::Buffer(
109
+ fp8_token_layout, 1, kNumMaxTokensPerRank,
110
+ workspace.get_end_ptr());
111
+ const auto input_sf_buffer = layout::Buffer(
112
+ fp8_sf_layout, 1, kNumMaxTokensPerRank,
113
+ input_token_buffer.get_end_ptr());
114
+ const auto input_topk_idx_buffer = layout::Buffer(
115
+ input_topk_idx_layout, 1, kNumMaxTokensPerRank,
116
+ input_sf_buffer.get_end_ptr());
117
+ const auto input_topk_weights_buffer = layout::Buffer(
118
+ input_topk_weights_layout, 1, kNumMaxTokensPerRank,
119
+ input_topk_idx_buffer.get_end_ptr());
120
+
121
+ // SF and its buffer configs
122
+ constexpr uint32_t kGranK = 32;
123
+ constexpr uint32_t kNumUTCCPAlignedElems = 128;
124
+ DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M");
125
+ DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB");
126
+
127
+ // UTCCP 4x32 transpose index mapping within each 128-element group
128
+ const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) {
129
+ const uint32_t idx = token_idx_in_expert % BLOCK_M;
130
+ return token_idx_in_expert / BLOCK_M * SF_BLOCK_M +
131
+ (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u);
132
+ };
133
+
134
+ // L1 inputs
135
+ const auto l1_token_buffer = layout::Buffer(
136
+ fp8_token_layout, 1, kNumMaxPoolTokens,
137
+ input_topk_weights_buffer.get_end_ptr());
138
+ const auto l1_sf_buffer = layout::Buffer(
139
+ fp8_sf_layout, 1, kNumPaddedSFPoolTokens,
140
+ l1_token_buffer.get_end_ptr());
141
+ const auto l1_topk_weights_buffer = layout::Buffer(
142
+ l1_topk_weights_layout, 1, kNumMaxPoolTokens,
143
+ l1_sf_buffer.get_end_ptr());
144
+
145
+ // L2 inputs
146
+ const auto l2_token_buffer = layout::Buffer(
147
+ fp8_intermediate_token_layout, 1, kNumMaxPoolTokens,
148
+ l1_topk_weights_buffer.get_end_ptr()
149
+ );
150
+ const auto l2_sf_buffer = layout::Buffer(
151
+ fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens,
152
+ l2_token_buffer.get_end_ptr()
153
+ );
154
+
155
+ // Combine inputs
156
+ const auto combine_token_buffer = layout::Buffer(
157
+ bf16_token_layout, kNumTopk, kNumMaxTokensPerRank,
158
+ l2_sf_buffer.get_end_ptr()
159
+ );
160
+
161
+ // Data types
162
+ // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1)
163
+ using a_dtype_t = cutlass::float_e4m3_t;
164
+ using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
165
+
166
+ // MMA configs
167
+ // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major
168
+ constexpr uint32_t LAYOUT_AD_M = 128;
169
+ constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2;
170
+ constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB
171
+ constexpr uint32_t UMMA_K = 32;
172
+ constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
173
+ constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
174
+ DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
175
+ DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N");
176
+ DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
177
+
178
+ // Swizzle configs
179
+ constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t);
180
+ constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t);
181
+ constexpr uint32_t kSwizzleCDMode = 128;
182
+ DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N");
183
+
184
+ // Epilogue configs
185
+ constexpr uint32_t kNumEpilogueStages = 2;
186
+ constexpr uint32_t kNumTMAStoreStages = 2;
187
+
188
+ // Shared memory
189
+ constexpr uint32_t kSharedMemoryAlignment = 1024;
190
+ extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[];
191
+
192
+ // Shared memory sizes
193
+ // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage)
194
+ constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2;
195
+ constexpr uint32_t SMEM_EXPERT_COUNT_SIZE =
196
+ math::constexpr_align<uint32_t>(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment);
197
+ constexpr uint32_t SMEM_SEND_BUFFER_SIZE =
198
+ math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
199
+ constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
200
+ constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
201
+ constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
202
+ constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
203
+ constexpr uint32_t SMEM_CD_L1_SIZE =
204
+ kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages;
205
+ constexpr uint32_t SMEM_CD_L2_SIZE =
206
+ kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16);
207
+ constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
208
+ constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages;
209
+ constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE =
210
+ SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
211
+ DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and
212
+ SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and
213
+ SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0,
214
+ "Shared memory of CD/A/B must be aligned to 1024 bytes");
215
+
216
+ // Tensor memory size
217
+ constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
218
+ constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
219
+ constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
220
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
221
+ constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
222
+ constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
223
+ DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
224
+
225
+ // Assign shared memory for dispatch warps
226
+ const auto smem_expert_count = reinterpret_cast<uint32_t*>(smem_buffer);
227
+ const auto smem_send_buffers = layout::Buffer(
228
+ fp8_token_layout, kNumDispatchWarps, 1,
229
+ math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE));
230
+
231
+ // GEMM shared memory: C/D, A, B
232
+ // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes
233
+ auto smem_gemm_base = math::advance_ptr(
234
+ smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE
235
+ );
236
+
237
+ // D/A/B shared memory
238
+ auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) {
239
+ return math::advance_ptr<uint8_t>(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE);
240
+ });
241
+ auto smem_cd_l2 = smem_cd[0];
242
+ auto smem_a = utils::PatternVisitor([=](const uint32_t& i) {
243
+ return math::advance_ptr<a_dtype_t>(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
244
+ });
245
+ auto smem_b = utils::PatternVisitor([=](const uint32_t& i) {
246
+ return math::advance_ptr<b_dtype_t>(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
247
+ });
248
+
249
+ // SF shared memory: SFA and SFB per pipeline stage
250
+ auto sf_start_ptr = math::advance_ptr<uint8_t>(smem_gemm_base,
251
+ SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
252
+ auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
253
+ return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
254
+ });
255
+ auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
256
+ return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
257
+ });
258
+
259
+ // Epilogue amax reduction shared memory
260
+ auto smem_amax_reduction = reinterpret_cast<float2*>(smem_sfb[kNumStages]);
261
+
262
+ // Barriers and tensor memory pointer
263
+ auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2);
264
+ auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
265
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); });
266
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); });
267
+ auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); });
268
+ auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); });
269
+ auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); });
270
+ auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
271
+
272
+ // A cluster sync is essential for 2CTA tensor memory allocation
273
+ comm::cluster_sync_with_relaxed_arrive();
274
+
275
+ // Initialization
276
+ if (warp_idx == 0) {
277
+ // Clean shared memory
278
+ if (cute::elect_one_sync())
279
+ ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t));
280
+ } else if (warp_idx == 1) {
281
+ // Init m-barriers for dispatch
282
+ #pragma unroll
283
+ for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32)
284
+ dispatch_barriers[i]->init(1);
285
+ cutlass::arch::fence_barrier_init();
286
+ } else if (warp_idx == 2) {
287
+ // Init GEMM barriers
288
+ if (cute::elect_one_sync()) {
289
+ #pragma unroll
290
+ for (uint32_t i = 0; i < kNumStages; ++ i) {
291
+ // Arrive at all CTAs
292
+ full_barriers[i]->init(2 * 2);
293
+ empty_barriers[i]->init(1);
294
+ }
295
+ #pragma unroll
296
+ for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
297
+ // Arrive at all CTAs
298
+ tmem_full_barriers[i]->init(1);
299
+ // Arrive only at the leader CTA
300
+ tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads);
301
+ }
302
+ #pragma unroll
303
+ for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i)
304
+ combine_barriers[i]->init(1);
305
+ }
306
+ cutlass::arch::fence_barrier_init();
307
+ } else if (warp_idx == 3) {
308
+ // Allocate tensor memory
309
+ Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
310
+ }
311
+ // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
312
+ // and `barrier.cluster.wait.aligned` is by default `.acquire`
313
+ comm::cluster_sync_with_relaxed_arrive();
314
+
315
+ // Task scheduler
316
+ auto scheduler = sched::MegaMoEScheduler<
317
+ BLOCK_M, BLOCK_N, BLOCK_K,
318
+ L1_SHAPE_N, L1_SHAPE_K,
319
+ L2_SHAPE_N, L2_SHAPE_K,
320
+ kNumExpertsPerRank,
321
+ kNumExpertsPerWave,
322
+ kNumSMs, kNumRanks>(workspace);
323
+
324
+ // MMA pipeline and TMA phases
325
+ uint32_t stage_idx = 0, phase = 0;
326
+ auto advance_pipeline = [&](uint32_t& k_block_idx) {
327
+ ++ k_block_idx;
328
+
329
+ // Flip the phase only when the stage index wraps back around to the first stage
330
+ stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
331
+ phase ^= stage_idx == 0;
332
+ };
333
+
334
+ // Intra-SM Barrier indices
335
+ constexpr uint32_t kDispatchBarrierIdx = 0;
336
+ constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1;
337
+ constexpr uint32_t kEpilogueFullBarrierIdx = 2;
338
+ constexpr uint32_t kEpilogueWGBarrierStartIdx = 3;
339
+
340
+ // NVLink barrier tags
341
+ constexpr uint32_t kBeforeDispatchPullBarrierTag = 1;
342
+ constexpr uint32_t kBeforeCombineReduceBarrierTag = 2;
343
+ constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3;
344
+
345
+ // Adjust registers
346
+ constexpr uint32_t kNumDispatchRegisters = 48;
347
+ constexpr uint32_t kNumNonEpilogueRegisters = 40;
348
+ constexpr uint32_t kNumEpilogueRegisters = 208;
349
+ DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads +
350
+ kNumNonEpilogueRegisters * kNumNonEpilogueThreads +
351
+ kNumEpilogueRegisters * kNumEpilogueThreads <= 64512,
352
+ "Too many registers");
353
+
354
+ // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts)
355
+ constexpr uint32_t kDispatchGridSyncIndex = 0;
356
+ constexpr uint32_t kEpilogueGridSyncIndex = 1;
357
+
358
+ // Different warp roles
359
+ if (warp_idx < kNumDispatchWarps) {
360
+ // Adjust registers
361
+ cutlass::arch::warpgroup_reg_dealloc<kNumDispatchRegisters>();
362
+
363
+ // Dispatch warps
364
+ DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk");
365
+ constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk;
366
+ const auto read_topk_idx = [&](const auto& process) {
367
+ // TODO: figure out better unrolling
368
+ // Now, `unroll` is better than `unroll 8`
369
+ #pragma unroll
370
+ for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp;
371
+ i < num_tokens;
372
+ i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) {
373
+ // Allocate slots for each token-topk
374
+ int expert_idx = -1;
375
+ if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) {
376
+ expert_idx = static_cast<int>(
377
+ __ldg(input_topk_idx_buffer.get_base_ptr<int64_t>() + i * kNumTopk + lane_idx));
378
+ if (expert_idx >= 0)
379
+ process(i * kNumTopk + lane_idx, expert_idx);
380
+ }
381
+ __syncwarp();
382
+ }
383
+ };
384
+
385
+ // Count experts' tokens
386
+ read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) {
387
+ atomicAdd_block(smem_expert_count + expert_idx, 1);
388
+ });
389
+ ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
390
+
391
+ // Get SM offset (~6.5 us)
392
+ #pragma unroll
393
+ for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) {
394
+ const uint64_t send_value = (1ull << 32) | static_cast<uint64_t>(smem_expert_count[i]);
395
+ smem_expert_count[i] = static_cast<uint32_t>(
396
+ ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value));
397
+ }
398
+ ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
399
+
400
+ // Write source indices (~2 us with 512 tokens)
401
+ read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) {
402
+ const auto dst_rank_idx = expert_idx / kNumExpertsPerRank;
403
+ const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1);
404
+ const auto dst_ptr = workspace.get_src_token_topk_idx_ptr(
405
+ expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx);
406
+ *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx;
407
+ });
408
+
409
+ // Grid sync
410
+ comm::grid_sync<kNumSMs, kDispatchGridSyncIndex>(
411
+ workspace, sm_idx, thread_idx,
412
+ [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }
413
+ );
414
+
415
+ // Write expert count
416
+ if (sm_idx == 0) {
417
+ #pragma unroll
418
+ for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) {
419
+ const auto dst_rank_idx = i / kNumExpertsPerRank;
420
+ const auto dst_local_expert_idx = i % kNumExpertsPerRank;
421
+ const auto expert_status = *workspace.get_expert_send_count_ptr(i);
422
+ *sym_buffer.map(
423
+ workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx),
424
+ dst_rank_idx) = expert_status & 0xffffffff;
425
+ ptx::atomic_add_sys(
426
+ sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx),
427
+ expert_status);
428
+ }
429
+ }
430
+ ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
431
+
432
+ // Barrier before pulling
433
+ comm::nvlink_barrier<kNumRanks, kNumSMs, kNumDispatchThreads,
434
+ kDispatchGridSyncIndex, kBeforeDispatchPullBarrierTag>(
435
+ workspace, sym_buffer, sm_idx, thread_idx,
436
+ [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); },
437
+ /* After the grid sync above, there is no more writes by other SMs (except 0) */ false,
438
+ /* After the NVLink barrier, there is a grid sync */ true
439
+ );
440
+
441
+ // Ensure the epilogue barrier cannot run with the pull barrier
442
+ ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
443
+
444
+ // Pull token data and SF from remote ranks into local L1 buffer
445
+ uint32_t pull_mbarrier_phase = 0;
446
+ const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0);
447
+ const auto pull_mbarrier = dispatch_barriers[warp_idx];
448
+
449
+ // Cache expert token counts in registers (same pattern as scheduler)
450
+ scheduler.fetch_expert_recv_count();
451
+
452
+ // Per-rank counts for current expert (re-loaded when expert changes)
453
+ constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u);
454
+ int current_expert_idx = -1;
455
+ uint32_t stored_rank_count[kNumRanksPerLane] = {};
456
+ uint32_t expert_start_idx = 0, expert_end_idx = 0;
457
+ uint32_t expert_pool_block_offset = 0;
458
+
459
+ constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps;
460
+ for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) {
461
+ // Advance expert until within the range
462
+ int old_expert_idx = current_expert_idx;
463
+ while (token_idx >= expert_end_idx) {
464
+ if (++ current_expert_idx >= kNumExpertsPerRank)
465
+ break;
466
+
467
+ // Update pool block offset for the new expert
468
+ expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M);
469
+
470
+ // Move start and end to the next expert
471
+ expert_start_idx = expert_end_idx;
472
+ expert_end_idx += scheduler.get_num_tokens(current_expert_idx);
473
+ }
474
+
475
+ // Stop once every local expert has been exhausted (no tokens left to pull)
476
+ if (current_expert_idx >= kNumExpertsPerRank)
477
+ break;
478
+
479
+ // Load per-rank counts when expert changes
480
+ if (old_expert_idx != current_expert_idx) {
481
+ old_expert_idx = current_expert_idx;
482
+ #pragma unroll
483
+ for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) {
484
+ const uint32_t j = i * 32 + lane_idx;
485
+ // TODO: this is not coalesced
486
+ stored_rank_count[i] = j < kNumRanks ?
487
+ static_cast<uint32_t>(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0;
488
+ }
489
+ }
490
+
491
+ // Round-robin rank selection via iterative min-peeling
492
+ uint32_t current_rank_in_expert_idx;
493
+ uint32_t remaining[kNumRanksPerLane];
494
+ #pragma unroll
495
+ for (uint32_t i = 0; i < kNumRanksPerLane; ++ i)
496
+ remaining[i] = stored_rank_count[i];
497
+ uint32_t offset = 0;
498
+ uint32_t token_idx_in_expert = token_idx - expert_start_idx;
499
+ uint32_t slot_idx = token_idx_in_expert;
500
+ uint32_t token_idx_in_rank;
501
+ while (true) {
502
+ // Compute active count and min across all ranks
503
+ // NOTES: reduce within each lane first, then warp-reduce once
504
+ uint32_t num_actives_in_lane = 0;
505
+ uint32_t min_in_lane = 0xffffffff;
506
+ #pragma unroll
507
+ for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) {
508
+ num_actives_in_lane += remaining[i] > 0;
509
+ if (remaining[i] > 0)
510
+ min_in_lane = cute::min(min_in_lane, remaining[i]);
511
+ }
512
+ const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane);
513
+ const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane);
514
+
515
+ // Hit in the current round
516
+ const uint32_t num_round_tokens = length * num_active_ranks;
517
+ if (slot_idx < num_round_tokens) {
518
+ const uint32_t slot_idx_in_round = slot_idx % num_active_ranks;
519
+ uint32_t num_seen_ranks = 0;
520
+ current_rank_in_expert_idx = 0;
521
+ #pragma unroll
522
+ for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) {
523
+ const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0);
524
+ const uint32_t num_active_lanes = __popc(mask);
525
+ if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes)
526
+ current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1);
527
+ num_seen_ranks += num_active_lanes;
528
+ }
529
+ token_idx_in_rank = offset + (slot_idx / num_active_ranks);
530
+ break;
531
+ }
532
+
533
+ // Move into the next round
534
+ slot_idx -= num_round_tokens;
535
+ offset += length;
536
+ #pragma unroll
537
+ for (uint32_t i = 0; i < kNumRanksPerLane; ++ i)
538
+ remaining[i] -= cute::min(remaining[i], length);
539
+ }
540
+
541
+ // Read source token-topk index (written by remote dispatch via NVLink)
542
+ const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr(
543
+ current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank);
544
+ const uint32_t src_token_idx = src_token_topk_idx / kNumTopk;
545
+ const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk;
546
+
547
+ // TMA load token from remote rank into shared memory
548
+ if (cute::elect_one_sync()) {
549
+ ptx::tma_load_1d(
550
+ pull_buffer.get_base_ptr(),
551
+ sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(),
552
+ current_rank_in_expert_idx),
553
+ pull_mbarrier, kHidden);
554
+ }
555
+ __syncwarp();
556
+
557
+ // Load and store SF (overlaps with TMA token load)
558
+ constexpr uint32_t kNumSFUint32 = kHidden / 128;
559
+ DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF");
560
+ const auto remote_sf_ptr = sym_buffer.map(
561
+ input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr<uint32_t>(),
562
+ current_rank_in_expert_idx);
563
+ const auto local_sf_ptr = l1_sf_buffer.get_base_ptr<uint32_t>();
564
+ const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M +
565
+ transform_sf_token_idx(token_idx_in_expert);
566
+ #pragma unroll
567
+ for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) {
568
+ const uint32_t j = i * 32 + lane_idx;
569
+ if (j < kNumSFUint32)
570
+ local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j];
571
+ }
572
+ __syncwarp();
573
+
574
+ // Store weights and token data
575
+ const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert;
576
+ if (cute::elect_one_sync()) {
577
+ // Load weights
578
+ const auto weight = *sym_buffer.map(
579
+ input_topk_weights_buffer.get_base_ptr<float>() + src_token_topk_idx,
580
+ current_rank_in_expert_idx);
581
+ *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr<float>() = weight;
582
+
583
+ // Wait for TMA token load to complete
584
+ ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden);
585
+ ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
586
+
587
+ // Store token to local L1 buffer via TMA
588
+ ptx::tma_store_1d(
589
+ l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(),
590
+ pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes());
591
+
592
+ // Write source metadata for combine write-back
593
+ *workspace.get_token_src_metadata_ptr(pool_token_idx) =
594
+ {current_rank_in_expert_idx, src_token_idx, src_topk_idx};
595
+
596
+ // Wait for token TMA store to complete
597
+ cute::tma_store_arrive();
598
+ ptx::tma_store_wait<0>();
599
+ ptx::red_add_rel(
600
+ workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1);
601
+ }
602
+ __syncwarp();
603
+ }
604
+
605
+ // Clean workspace for the next usage, and also do cumulative stats
606
+ // NOTES: it is overlapped with combine reduction epilogue
607
+ ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
608
+
609
+ DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count");
610
+ if (sm_idx == 0) {
611
+ // SM 0: clear expert send count
612
+ #pragma unroll
613
+ for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads)
614
+ *workspace.get_expert_send_count_ptr(i) = 0;
615
+ } else {
616
+ // Other SMs: clean blocks
617
+ for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) {
618
+ // Read expert token count before clearing
619
+ const auto num_recv_tokens = static_cast<uint32_t>(
620
+ *workspace.get_expert_recv_count_sum_ptr(i));
621
+ const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M);
622
+
623
+ // Compute expert pool block offset
624
+ expert_pool_block_offset = scheduler.get_pool_block_offset(i);
625
+
626
+ // Wait until all dispatch threads have read the count, before it is cleared below
627
+ ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
628
+
629
+ // Clean expert token count, and add cumulative results
630
+ DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
631
+ if (warp_idx == 0) {
632
+ *workspace.get_expert_recv_count_sum_ptr(i) = 0;
633
+ } else if (warp_idx == 1) {
634
+ if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
635
+ ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
636
+ __syncwarp();
637
+ }
638
+
639
+ // Clean per-rank token count
640
+ for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
641
+ *workspace.get_expert_recv_count_ptr(j, i) = 0;
642
+ __syncwarp();
643
+
644
+ // Clean the L1 arrival counts and L2 arrival masks
645
+ for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
646
+ *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
647
+ *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
648
+ }
649
+ __syncwarp();
650
+ }
651
+ }
652
+
653
+ // Wait for all ranks to finish cleaning
654
+ comm::nvlink_barrier<kNumRanks, kNumSMs, kNumDispatchThreads,
655
+ kDispatchGridSyncIndex, kAfterWorkspaceCleanBarrierTag>(
656
+ workspace, sym_buffer, sm_idx, thread_idx,
657
+ [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); },
658
+ /* Before the NVLink barrier, there is a grid sync */ true,
659
+ /* At the end of kernel does not need to sync */ false
660
+ );
661
+ } else if (warp_idx == kNumDispatchWarps) {
662
+ // Adjust registers
663
+ cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
664
+
665
+ // GEMM TMA load warp for tokens with SFA
666
+ scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
667
+ const uint32_t& local_expert_idx,
668
+ const uint32_t& num_k_blocks,
669
+ const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
670
+ const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2
671
+ ? &tensor_map_l2_acts : &tensor_map_l1_acts;
672
+ const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2
673
+ ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf;
674
+
675
+ const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K;
676
+ const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u);
677
+
678
+ // Compute pool block offset for this expert
679
+ const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx;
680
+
681
+ // Wait for all tokens of this block to arrive for linear 1
682
+ if (block_phase == sched::BlockPhase::Linear1) {
683
+ const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx);
684
+ const auto expected = scheduler.template get_valid_m<false>();
685
+ while (ptx::ld_acq(ptr) != expected);
686
+ } else {
687
+ // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival
688
+ // NOTES: Originally we wait blocks on-demand to overlap L1 calculation
689
+ // with L2, but this optimization is negative when `num_experts_per_wave`
690
+ // guarantees L1's completion when L2 starts. So we remove it.
691
+ // In the future, if `num_experts_per_wave` is not large enough
692
+ // due to small `num_experts_per_rank`, we may need to add it back or add a switch
693
+ DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
694
+ const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
695
+ // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts
696
+ // to avoid undefined behavior when `num_k_blocks == 32`
697
+ const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;
698
+ while (ptx::ld_acq_gpu(ptr) != expected);
699
+ }
700
+
701
+ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
702
+ // Wait for the consumer to release this stage
703
+ empty_barriers[stage_idx]->wait(phase ^ 1);
704
+
705
+ // Compute token offset from pool block index
706
+ uint32_t m_idx = pool_block_idx * BLOCK_M;
707
+ uint32_t k_idx = k_block_idx * BLOCK_K;
708
+ uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M;
709
+ uint32_t sfa_k_idx = k_block_idx;
710
+
711
+ // The non-leader CTA of the 2-CTA pair starts at an offset of half the valid M
712
+ if (not is_leader_cta)
713
+ m_idx += scheduler.template get_valid_m<true>() / 2;
714
+
715
+ // TMA copy tokens and SFA, then arrive at full barrier
716
+ if (cute::elect_one_sync()) {
717
+ tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(
718
+ tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2);
719
+ tma::copy<SF_BLOCK_M, 1, 0>(
720
+ tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
721
+ if (is_leader_cta) {
722
+ full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2);
723
+ } else {
724
+ full_barriers[stage_idx]->arrive(0u);
725
+ }
726
+ }
727
+ __syncwarp();
728
+ }
729
+ });
730
+ } else if (warp_idx == kNumDispatchWarps + 1) {
731
+ // Adjust registers
732
+ cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
733
+
734
+ // GEMM TMA load warp for weights with SF
735
+ scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
736
+ const uint32_t& local_expert_idx,
737
+ const uint32_t& num_k_blocks,
738
+ const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
739
+ const auto tensor_map_b_ptr =
740
+ block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights;
741
+ const auto tensor_map_sfb_ptr =
742
+ block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf;
743
+
744
+ const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K;
745
+ const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N;
746
+ const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u);
747
+
748
+ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
749
+ // Wait for the consumer to release this stage
750
+ empty_barriers[stage_idx]->wait(phase ^ 1);
751
+
752
+ // Compute weight offset
753
+ uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N;
754
+ uint32_t k_idx = k_block_idx * BLOCK_K;
755
+ uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
756
+ uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx;
757
+
758
+ // TMA copy weights with SF
759
+ if (cute::elect_one_sync()) {
760
+ tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(
761
+ tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2);
762
+ tma::copy<BLOCK_N, 1, 0>(
763
+ tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
764
+ if (is_leader_cta) {
765
+ full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2);
766
+ } else {
767
+ full_barriers[stage_idx]->arrive(0u);
768
+ }
769
+ }
770
+ __syncwarp();
771
+ }
772
+ });
773
+ } else if (warp_idx == kNumDispatchWarps + 2) {
774
+ // Adjust registers
775
+ cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
776
+
777
+ // GEMM MMA issue warp (only the leader CTA will run)
778
+ if (is_leader_cta) {
779
+ // Make instruction descriptor with block scaling
780
+ // NOTES: always swap A/B
781
+ auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<
782
+ b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
783
+ UMMA_M, UMMA_N,
784
+ cute::UMMA::Major::K, cute::UMMA::Major::K
785
+ >();
786
+ auto sf_desc = mma::sm100::make_sf_desc(nullptr);
787
+
788
+ DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
789
+ auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
790
+ auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
791
+ uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
792
+ uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
793
+
794
+ // Checks for MMA instructions
795
+ DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
796
+ (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
797
+ (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
798
+ "Invalid MMA instruction shape");
799
+
800
+ // Persistently schedule over blocks
801
+ uint32_t current_iter_idx = 0;
802
+ scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
803
+ const uint32_t& local_expert_idx,
804
+ const uint32_t& num_k_blocks,
805
+ const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
806
+ // Dynamic update of UMMA N based on effective M
807
+ mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m<true>());
808
+
809
+ // Wait for the tensor memory empty barrier of this accumulator stage
810
+ const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages;
811
+ const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1;
812
+ tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1);
813
+ ptx::tcgen05_after_thread_sync();
814
+
815
+ // Empty barrier arrival
816
+ auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
817
+ auto umma_arrive = [](const uint64_t* barrier) {
818
+ constexpr uint16_t kCTAMask = (1 << 2) - 1;
819
+ cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
820
+ };
821
+ umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
822
+
823
+ // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
824
+ if (do_tmem_full_arrive)
825
+ umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
826
+ __syncwarp();
827
+ };
828
+
829
+ // Launch MMAs
830
+ #pragma unroll 2
831
+ for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
832
+ // Wait TMA load completion
833
+ full_barriers[stage_idx]->wait(phase);
834
+ ptx::tcgen05_after_thread_sync();
835
+
836
+ const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx);
837
+ const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
838
+ if (cute::elect_one_sync()) {
839
+ // UTCCP copy SFA and SFB to TMEM
840
+ using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
841
+ #pragma unroll
842
+ for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
843
+ auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
844
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
845
+ cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
846
+ }
847
+ #pragma unroll
848
+ for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
849
+ auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
850
+ mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
851
+ cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
852
+ }
853
+
854
+ // Issue UMMA
855
+ #pragma unroll
856
+ for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
857
+ const auto runtime_instr_desc =
858
+ mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k);
859
+ a_desc.lo = mma::sm100::advance_umma_desc_lo<
860
+ cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
861
+ b_desc.lo = mma::sm100::advance_umma_desc_lo<
862
+ cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
863
+ ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma(
864
+ b_desc, a_desc, accum_stage_idx * UMMA_N,
865
+ k_block_idx > 0 or k > 0, runtime_instr_desc,
866
+ kTmemStartColOfSFB, kTmemStartColOfSFA);
867
+ }
868
+ }
869
+ __syncwarp();
870
+
871
+ // Commit to the mbarrier object
872
+ // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
873
+ empty_barrier_arrive(k_block_idx == num_k_blocks - 1);
874
+ }
875
+ });
876
+
877
+ // To safely deconstruct barriers, we need another round of waits
878
+ if (current_iter_idx > 0) {
879
+ const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1;
880
+ tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx);
881
+ }
882
+ }
883
+ } else if (warp_idx == kNumDispatchWarps + 3) {
884
+ // Adjust registers
885
+ cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
886
+
887
+ } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) {
888
+ // Adjust registers
889
+ cutlass::arch::warpgroup_reg_alloc<kNumEpilogueRegisters>();
890
+
891
+ // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
892
+ // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
893
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory
894
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
895
+
896
+ // GEMM epilogue warps
897
+ const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps);
898
+ const auto epilogue_wg_idx = epilogue_warp_idx / 4;
899
+ const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx;
900
+ const auto warp_idx_in_wg = epilogue_warp_idx % 4;
901
+ DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and
902
+ kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps");
903
+
904
+ // TODO: support effective block M
905
+ // NOTES:
906
+ // - 2 warpgroups divide the whole BM into BM / 2
907
+ // - 4 warps divide the whole BN into BN / 4
908
+ // - BM / 2 is further divided into stored blocks, i.e. with `STORE_BLOCK_M` size
909
+ // - `STORE_BLOCK_M` is further divided into `ATOM_M`
910
+ constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups;
911
+ constexpr uint32_t ATOM_M = 8;
912
+ constexpr uint32_t kNumBankGroupBytes = 16u;
913
+ constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M;
914
+ DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M");
915
+ DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M");
916
+ DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M");
917
+ DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N");
918
+
919
+ // Ensure the epilogue barrier cannot run with the pull barrier
920
+ ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
921
+
922
+ // Persistently schedule over blocks
923
+ uint32_t current_iter_idx = 0;
924
+ scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
925
+ const uint32_t& local_expert_idx,
926
+ const uint32_t& num_k_blocks,
927
+ const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
928
+ // Wait UMMA arrival
929
+ const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages;
930
+ const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1;
931
+ tmem_full_barriers[accum_stage_idx]->wait(accum_phase);
932
+ ptx::tcgen05_after_thread_sync();
933
+
934
+ // Compute offsets
935
+ // NOTES: use shuffle here to let NVCC know warp divergence won't happen
936
+ const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m<false>(), 0);
937
+ const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx;
938
+ uint32_t m_idx = pool_block_idx * BLOCK_M;
939
+ uint32_t n_idx = n_block_idx * BLOCK_N;
940
+
941
+ if (block_phase == sched::BlockPhase::Linear1) {
942
+ // Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights
943
+ // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are:
944
+ // (values[0], values[2]), (values[1], values[3]),
945
+ // (values[4], values[6]), (values[5], values[7])
946
+ float stored_cached_weight = 0;
947
+
948
+ #pragma unroll
949
+ for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) {
950
+ // Early break if the entire store block is beyond the valid token range
951
+ if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) {
952
+ ptx::tcgen05_before_thread_sync();
953
+ tmem_empty_barriers[accum_stage_idx]->arrive(0u);
954
+ break;
955
+ }
956
+
957
+ // Iterate all atoms in the store block
958
+ float2 swiglu_values[kNumAtomsPerStore * 2];
959
+ float2 amax_values[kNumAtomsPerStore];
960
+ #pragma unroll
961
+ for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
962
+ const uint32_t j = s * kNumAtomsPerStore + i;
963
+
964
+ // Load weights from global into register cache per 32 tokens
965
+ DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size");
966
+ if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) {
967
+ stored_cached_weight = *l1_topk_weights_buffer
968
+ .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx)
969
+ .get_base_ptr<float>();
970
+ }
971
+
972
+ // Load weights from register cache
973
+ const float2 weights = {
974
+ ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0),
975
+ ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1)
976
+ };
977
+
978
+ // Load from TMEM
979
+ uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M;
980
+ uint32_t values[ATOM_M];
981
+ cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
982
+ values[0], values[1], values[2], values[3]);
983
+ cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
984
+ values[4], values[5], values[6], values[7]);
985
+ cutlass::arch::fence_view_async_tmem_load();
986
+
987
+ // Signal tensor memory consumed on the last atom
988
+ if (j == WG_BLOCK_M / ATOM_M - 1) {
989
+ ptx::tcgen05_before_thread_sync();
990
+ tmem_empty_barriers[accum_stage_idx]->arrive(0u);
991
+ }
992
+
993
+ // Apply SwiGLU: silu(gate) * up
994
+ // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7)
995
+ auto fp32_values = reinterpret_cast<float*>(values);
996
+ #pragma unroll
997
+ for (uint32_t k = 0; k < 2; ++ k) {
998
+ auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1]));
999
+ auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3]));
1000
+
1001
+ // Clamp
1002
+ if constexpr (kActivationClamp != cute::numeric_limits<float>::infinity()) {
1003
+ bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp});
1004
+ bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp});
1005
+ bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp});
1006
+ }
1007
+
1008
+ // SwiGLU
1009
+ auto gate = __bfloat1622float2(bf16_gate);
1010
+ auto neg_gate_exp = make_float2(
1011
+ kFastMath ? __expf(-gate.x) : expf(-gate.x),
1012
+ kFastMath ? __expf(-gate.y) : expf(-gate.y));
1013
+ const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp);
1014
+ if constexpr (kFastMath) {
1015
+ gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)});
1016
+ } else {
1017
+ gate = {gate.x / denom.x, gate.y / denom.y};
1018
+ }
1019
+ const auto up = __bfloat1622float2(bf16_up);
1020
+ swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights);
1021
+ }
1022
+
1023
+ // Amax reduction
1024
+ amax_values[i].x = math::warp_reduce<4, true>(
1025
+ cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)),
1026
+ math::ReduceMax<float>());
1027
+ amax_values[i].y = math::warp_reduce<4, true>(
1028
+ cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)),
1029
+ math::ReduceMax<float>());
1030
+ if (lane_idx < 4)
1031
+ smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i];
1032
+ __syncwarp();
1033
+ }
1034
+
1035
+ // Wait shared memory release from previous TMA store
1036
+ // And fence `smem_amax_reduction`
1037
+ const uint32_t tma_stage_idx = s % kNumTMAStoreStages;
1038
+ ptx::tma_store_wait<kNumTMAStoreStages - 1>();
1039
+ ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
1040
+
1041
+ // Cast to FP8 E4M3 and store into shared memory
1042
+ #pragma unroll
1043
+ for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
1044
+ // Reduce amax
1045
+ const float2 wp_amax =
1046
+ smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4];
1047
+ amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
1048
+ amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
1049
+
1050
+ // Calculate SF
1051
+ float2 sf, sf_inv;
1052
+ math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv);
1053
+
1054
+ // Cast
1055
+ const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
1056
+ const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
1057
+ const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y));
1058
+
1059
+ // STSM
1060
+ uint32_t row = lane_idx;
1061
+ uint32_t col = warp_idx_in_wg;
1062
+ const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N
1063
+ + i * ATOM_M * L1_OUT_BLOCK_N
1064
+ + row * L1_OUT_BLOCK_N
1065
+ + (col ^ (row / 2)) * kNumBankGroupBytes;
1066
+ ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr);
1067
+
1068
+ // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout)
1069
+ // Only one warp per pair writes (both hold the same SF after cross-warp reduce)
1070
+ // Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
1071
+ if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
1072
+ const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
1073
+ const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
1074
+ const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
1075
+ const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
1076
+ // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
1077
+ // NOTES: originally there was:
1078
+ // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2
1079
+ // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
1080
+ // We find out that
1081
+ // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
1082
+ // 2. `lane_idx * 2` controls the lowest 3 bits of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
1083
+ // This reduces the number of computation instructions.
1084
+ const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
1085
+ __builtin_assume(token_base_idx < BLOCK_M);
1086
+ const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
1087
+ + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
1088
+ const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
1089
+ sf_base_ptr[sf_addr] =
1090
+ (*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
1091
+ sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
1092
+ (*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
1093
+ }
1094
+ __syncwarp();
1095
+ }
1096
+ ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
1097
+
1098
+ // Issue TMA store after all atoms in this store block
1099
+ if (warp_idx_in_wg == 0 and cute::elect_one_sync()) {
1100
+ uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N;
1101
+ cute::tma_store_fence();
1102
+ cute::SM90_TMA_STORE_2D::copy(
1103
+ &tensor_map_l1_output,
1104
+ smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N,
1105
+ out_n_idx,
1106
+ m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M);
1107
+ cute::tma_store_arrive();
1108
+ }
1109
+ __syncwarp();
1110
+ }
1111
+
1112
+ // Notify L2
1113
+ // TODO: less epilogue sync scope
1114
+ ptx::tma_store_wait<0>();
1115
+ ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx);
1116
+ if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
1117
+ DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large");
1118
+ ptx::red_or_rel_gpu(
1119
+ workspace.get_l2_arrival_mask_ptr(pool_block_idx),
1120
+ 1ull << n_block_idx
1121
+ );
1122
+ }
1123
+ __syncwarp();
1124
+ } else {
1125
+ DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M");
1126
+ constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8;
1127
+
1128
+ // L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink
1129
+ #pragma unroll
1130
+ for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) {
1131
+ // Early break if the entire store block is beyond the valid token range
1132
+ // TODO: check performance
1133
+ if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) {
1134
+ ptx::tcgen05_before_thread_sync();
1135
+ tmem_empty_barriers[accum_stage_idx]->arrive(0u);
1136
+ break;
1137
+ }
1138
+
1139
+ #pragma unroll
1140
+ for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) {
1141
+ // Load from TMEM using .16x256b shape to satisfy STSM layout requirements
1142
+ // Start from lane index 0 and 16
1143
+ uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
1144
+ uint32_t values[ATOM_M];
1145
+ cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
1146
+ values[0], values[1], values[2], values[3]);
1147
+ cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
1148
+ values[4], values[5], values[6], values[7]);
1149
+ cutlass::arch::fence_view_async_tmem_load();
1150
+
1151
+ // Wait shared memory release from previous NVLink store
1152
+ // NOTES: skip for the first store block since the prior full barrier already ensures completion
1153
+ if (i == 0 and s > 0)
1154
+ ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
1155
+
1156
+ // Signal tensor memory consumed
1157
+ if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) {
1158
+ ptx::tcgen05_before_thread_sync();
1159
+ tmem_empty_barriers[accum_stage_idx]->arrive(0u);
1160
+ }
1161
+
1162
+ // Store into shared memory
1163
+ // NOTES: only use first 16 lanes for address
1164
+ // NOTES: 2 warps share a BF16 swizzle atom
1165
+ uint32_t row = lane_idx % 8;
1166
+ uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8;
1167
+ const auto smem_ptr = smem_cd_l2 +
1168
+ epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(nv_bfloat16)) +
1169
+ (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode +
1170
+ i * ATOM_M * kSwizzleCDMode +
1171
+ row * (kNumBankGroupBytes * 8) +
1172
+ (col ^ row) * kNumBankGroupBytes;
1173
+ ptx::SM90_U32x4_STSM_T<uint32_t>::copy(
1174
+ math::cast_into_bf16_and_pack(values[0], values[1]),
1175
+ math::cast_into_bf16_and_pack(values[2], values[3]),
1176
+ math::cast_into_bf16_and_pack(values[4], values[5]),
1177
+ math::cast_into_bf16_and_pack(values[6], values[7]),
1178
+ smem_ptr
1179
+ );
1180
+ }
1181
+
1182
+ // Wait shared memory ready
1183
+ ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
1184
+
1185
+ // Write into remote buffers
1186
+ // One warp per row, now the layout is different from shared memory storing
1187
+ const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M;
1188
+ const uint32_t bank_group_idx = lane_idx % 8;
1189
+
1190
+ #pragma unroll
1191
+ for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) {
1192
+ const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16;
1193
+ const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store;
1194
+
1195
+ // Skip padding rows beyond the actual token count for this expert
1196
+ if (m_idx_in_block >= valid_m)
1197
+ break;
1198
+
1199
+ const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block);
1200
+ const uint32_t dst_rank_idx = src_metadata.rank_idx;
1201
+ const uint32_t dst_token_idx = src_metadata.token_idx;
1202
+ const uint32_t dst_topk_idx = src_metadata.topk_idx;
1203
+
1204
+ // Read from shared memory
1205
+ const auto smem_ptr = smem_cd_l2 +
1206
+ epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(nv_bfloat16)) +
1207
+ (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode +
1208
+ row_in_store * kSwizzleCDMode +
1209
+ (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes;
1210
+ const auto packed = ptx::ld_shared(reinterpret_cast<float4*>(smem_ptr));
1211
+
1212
+ // Write into remote
1213
+ const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx)
1214
+ .get_data_buffer(dst_token_idx);
1215
+ const auto dst_ptr = math::advance_ptr<float4>(
1216
+ dst_token.get_base_ptr(),
1217
+ n_idx * static_cast<uint32_t>(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast<uint32_t>(sizeof(float4)));
1218
+ *sym_buffer.map(dst_ptr, dst_rank_idx) = packed;
1219
+ }
1220
+ }
1221
+
1222
+ // Ensure the next epilogue can safely use shared memory
1223
+ ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx);
1224
+ }
1225
+ });
1226
+
1227
+ // Deallocate tensor memory
1228
+ // NOTES: must be called by the same logical warp ID on both CTAs
1229
+ if (epilogue_warp_idx == 0)
1230
+ Allocator().free(0, kNumTmemCols);
1231
+
1232
+ // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us
1233
+ comm::nvlink_barrier<kNumRanks, kNumSMs, kNumEpilogueThreads,
1234
+ kEpilogueGridSyncIndex, kBeforeCombineReduceBarrierTag>(
1235
+ workspace, sym_buffer, sm_idx, epilogue_thread_idx,
1236
+ [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); }
1237
+ );
1238
+
1239
+ // Barrier with dispatch warps, so that they can clean the workspace
1240
+ ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
1241
+
1242
+ // Combine: reduce top-k results and write back
1243
+ // NOTES: reuse shared memory from start up to the barriers
1244
+ // 1 token, 1 topk latency: ~3 us
1245
+ constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16);
1246
+ constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162);
1247
+
1248
+ // 3 chunk slots are needed: 2 load stages and 1 store
1249
+ constexpr uint32_t kNumChunkSlots = 3;
1250
+ constexpr uint32_t kNumMaxRegistersForBuffer = 128;
1251
+
1252
+ // NOTES: either 1 or 2 chunks for simplicity
1253
+ // NOTES: Restrict on both smem and register
1254
+ constexpr uint32_t kNumChunks =
1255
+ kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2;
1256
+ constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks;
1257
+ constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4);
1258
+ constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32;
1259
+ DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks");
1260
+ DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large");
1261
+ DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)");
1262
+ DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes");
1263
+ DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)");
1264
+ DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp");
1265
+
1266
+ // Verify combined shared memory budget at runtime
1267
+ DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast<uint32_t>(
1268
+ reinterpret_cast<uint8_t*>(barrier_start_ptr) - smem_buffer));
1269
+
1270
+ // Per-warp buffer: 2 stage load buffers + 1 store buffer
1271
+ const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) {
1272
+ return math::advance_ptr<uint4>(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes);
1273
+ });
1274
+ const auto combine_store_buffer = math::advance_ptr<uint4>(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes);
1275
+
1276
+ // Per-warp barriers
1277
+ auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) {
1278
+ return combine_barriers[i + epilogue_warp_idx * 2];
1279
+ });
1280
+
1281
+ // Iterate over all tokens
1282
+ uint32_t combine_phase = 0;
1283
+ uint32_t load_stage_idx = 0;
1284
+ for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx;
1285
+ token_idx < num_tokens;
1286
+ token_idx += kNumSMs * kNumEpilogueWarps) {
1287
+ // Read top-k slot indices: each lane reads one slot, then broadcast via exchange
1288
+ DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk");
1289
+ const int stored_topk_slot_idx = lane_idx < kNumTopk ?
1290
+ static_cast<int>(__ldg(input_topk_idx_buffer.get_base_ptr<int64_t>() + token_idx * kNumTopk + lane_idx)) : -1;
1291
+ const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0);
1292
+
1293
+ // Iterate all chunks
1294
+ for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) {
1295
+ const uint32_t chunk_byte_offset = chunk * kNumChunkBytes;
1296
+
1297
+ // Move mask and load
1298
+ uint32_t mask = total_mask;
1299
+ const auto move_mask_and_load = [&](const uint32_t& i) {
1300
+ if (mask) {
1301
+ // Move
1302
+ const uint32_t slot_idx = __ffs(mask) - 1;
1303
+ mask ^= 1 << slot_idx;
1304
+
1305
+ // Load
1306
+ if (cute::elect_one_sync()) {
1307
+ const auto src_ptr = math::advance_ptr<uint8_t>(
1308
+ combine_token_buffer.get_rank_buffer(slot_idx)
1309
+ .get_data_buffer(token_idx).get_base_ptr(),
1310
+ chunk_byte_offset);
1311
+ ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes);
1312
+ ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes);
1313
+ }
1314
+ __syncwarp();
1315
+ return true;
1316
+ }
1317
+ return false;
1318
+ };
1319
+
1320
+ // Load the first selection
1321
+ bool do_reduce = move_mask_and_load(load_stage_idx);
1322
+
1323
+ // Accumulate all top-k contributions for this chunk in float registers
1324
+ float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {};
1325
+ while (do_reduce) {
1326
+ // Prefetch next top-k into the buffer while current is being accumulated
1327
+ do_reduce = move_mask_and_load(load_stage_idx ^ 1);
1328
+
1329
+ // Accumulate
1330
+ combine_load_barriers[load_stage_idx]->wait(combine_phase);
1331
+ #pragma unroll
1332
+ for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) {
1333
+ const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx];
1334
+ const auto bf16_values = reinterpret_cast<const nv_bfloat162*>(&uint4_values);
1335
+ #pragma unroll
1336
+ for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
1337
+ ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]);
1338
+ }
1339
+ combine_phase ^= load_stage_idx;
1340
+ load_stage_idx ^= 1;
1341
+ }
1342
+
1343
+ // Cast
1344
+ #pragma unroll
1345
+ for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) {
1346
+ uint4 casted;
1347
+ auto casted_bf16 = reinterpret_cast<nv_bfloat162*>(&casted);
1348
+ #pragma unroll
1349
+ for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
1350
+ casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]);
1351
+
1352
+ // Wait for shared memory release, then write
1353
+ if (j == 0) {
1354
+ ptx::tma_store_wait<0>();
1355
+ __syncwarp();
1356
+ }
1357
+ ptx::st_shared(combine_store_buffer + j * 32 + lane_idx,
1358
+ casted.x, casted.y, casted.z, casted.w);
1359
+ }
1360
+ __syncwarp();
1361
+
1362
+ // TMA store the token chunk
1363
+ if (cute::elect_one_sync()) {
1364
+ cute::tma_store_fence();
1365
+ ptx::tma_store_1d(
1366
+ math::advance_ptr(y, static_cast<uint64_t>(token_idx) * kNumHiddenBytes + chunk_byte_offset),
1367
+ combine_store_buffer, kNumChunkBytes);
1368
+ cute::tma_store_arrive();
1369
+ }
1370
+ __syncwarp();
1371
+ }
1372
+ }
1373
+ }
1374
+ #else
1375
+ if (blockIdx.x == 0 and threadIdx.x == 0)
1376
+ DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
1377
+ #endif
1378
+ }
1379
+
1380
+ } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh CHANGED
@@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
155
  auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
156
  DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
157
 
 
 
 
158
  // Initialize barriers
159
  if (warp_idx == 1 and cute::elect_one_sync()) {
160
  #pragma unroll
@@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
546
  }
547
  }
548
  }
549
-
550
- // Deallocate tensor memory by the last UMMA store warp
551
- // NOTES: warp 0 is waiting TMA store
552
- if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
553
- Allocator().free(0, kNumTmemCols);
554
  }
 
 
 
 
 
 
555
  #else
556
  if (blockIdx.x == 0 and threadIdx.x == 0)
557
  DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
 
155
  auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
156
  DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
157
 
158
+ if (kNumMulticast > 1)
159
+ cute::cluster_sync();
160
+
161
  // Initialize barriers
162
  if (warp_idx == 1 and cute::elect_one_sync()) {
163
  #pragma unroll
 
549
  }
550
  }
551
  }
 
 
 
 
 
552
  }
553
+
554
+ // Deallocate tensor memory
555
+ kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
556
+ if (warp_idx == 0)
557
+ Allocator().free(0, kNumTmemCols);
558
+
559
  #else
560
  if (blockIdx.x == 0 and threadIdx.x == 0)
561
  DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh CHANGED
@@ -6,27 +6,31 @@
6
  #include <cute/arch/cluster_sm90.hpp>
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
 
 
 
 
9
  #include <deep_gemm/common/utils.cuh>
10
- #include <deep_gemm/common/sm90_utils.cuh>
11
- #include <deep_gemm/common/sm100_utils.cuh>
 
 
12
 
13
  namespace deep_gemm {
14
 
15
- using namespace deep_gemm::sm90;
16
- using namespace deep_gemm::sm100;
17
-
18
  template <uint32_t kNumHeads, uint32_t kHeadDim,
19
  bool kIsCompressedLogits,
20
  uint32_t BLOCK_Q, uint32_t BLOCK_KV,
21
  uint32_t kNumQStages, uint32_t kNumKVStages,
 
22
  uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
 
23
  uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
24
- __global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
25
  void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
26
- const uint32_t max_seqlen_k, const uint64_t stride_logits,
27
  uint32_t* cu_seq_len_k_start,
28
  uint32_t* cu_seq_len_k_end,
29
- float* logits,
30
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
31
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
32
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
35
  // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
36
  // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
37
  // Q should be load only at once for a block
38
- const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
39
 
40
  // Types
41
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
42
 
43
- // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
44
- const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
45
- const auto& warp_in_group_idx = warp_idx % 4;
46
- const auto& warpgroup_idx = warp_idx / 4;
47
- const auto& lane_idx = get_lane_idx();
 
48
 
49
  // Prefetch TMA descriptors
50
  DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
51
- if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
52
  cute::prefetch_tma_descriptor(&tensor_map_q);
53
  cute::prefetch_tma_descriptor(&tensor_map_kv);
54
  cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
55
  cute::prefetch_tma_descriptor(&tensor_map_weights);
56
  }
57
- __syncwarp();
58
 
59
  // Shared memory configs
60
  // NOTES: weight may be unaligned
@@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
62
  static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
63
  static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
64
  static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
65
- static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
66
 
67
  // Align to 512 bytes for swizzle-64B
68
  extern __shared__ __align__(512) uint8_t smem_buffer[];
@@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
75
  DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
76
 
77
  // Data on shared memory
78
- auto smem_q = PatternVisitor([&](const uint32_t& i) {
79
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
80
  SMEM_Q_SIZE_PER_STAGE * i);
81
  });
82
- auto smem_weights = PatternVisitor([&](const uint32_t& i) {
83
  return reinterpret_cast<float*>(smem_buffer +
84
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
85
  });
86
- auto smem_kv = PatternVisitor([&](const uint32_t& i) {
87
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
88
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
89
  });
90
- auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
91
  return reinterpret_cast<float*>(smem_buffer +
92
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
93
  SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
@@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
95
 
96
  // TMA barriers
97
  auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
98
- auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
99
- auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
100
- auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
101
- auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
102
- auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
103
- auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
104
 
105
  // Tensor memory allocation
106
  auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
107
 
108
  // Initialize barriers
109
  DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
110
- const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32));
111
- const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1));
112
- if (is_tma_load_warp and cute::elect_one_sync()) {
113
  #pragma unroll
114
  for (uint32_t i = 0; i < kNumQStages; ++ i) {
115
  full_q_barriers[i]->init(1);
116
- empty_q_barriers[i]->init(kNumMathThreads);
117
  }
118
  #pragma unroll
119
  for (uint32_t i = 0; i < kNumKVStages; ++ i) {
120
  full_kv_barriers[i]->init(1);
121
  empty_kv_barriers[i]->init(kNumMathThreads);
122
  }
123
- #pragma unroll
124
- for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
125
- full_umma_barriers[i]->init(1);
126
- empty_umma_barriers[i]->init(128);
127
- }
128
-
129
- // Make initialized barrier visible in async proxy
130
  cutlass::arch::fence_barrier_init();
131
- } else if (is_umma_warp) {
 
 
 
 
 
 
 
 
 
132
  // Allocate tensor memory
133
  cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
134
  }
135
  __syncthreads();
136
 
137
  // Register reconfigurations
138
- constexpr uint32_t kNumSpecializedRegisters = 24;
139
- constexpr uint32_t kNumMathRegisters = 240;
140
 
141
  // Block scheduler
142
- uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
143
- const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
144
- return {block_q_idx + gridDim.x, q_iter_idx + 1};
145
  };
146
  uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
147
- const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
148
  uint32_t start = cute::numeric_limits<uint32_t>::max();
149
  uint32_t end = cute::numeric_limits<uint32_t>::min();
150
 
151
  #pragma unroll
152
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
153
- const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
154
- seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
155
- seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
156
  start = min(start, min(seq_k_start[i], seq_len_kv));
157
  end = max(end, min(seq_k_end[i], seq_len_kv));
158
  }
 
159
  start = start / 4 * 4;
160
  return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
161
  ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
162
- start, ceil_div(end - start, BLOCK_KV)}; // Task info
163
  };
164
 
165
  // KV pipeline
166
  uint32_t num_total_kv_blocks = 0;
167
- const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
168
  return {
169
  (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
170
  ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
@@ -177,13 +182,16 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
177
  constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
178
  constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
179
 
180
- if (is_tma_load_warp) {
 
 
 
181
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
182
 
183
  // Prefetch
184
- const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
185
- tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
186
- tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
187
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
188
  };
189
  if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
@@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
209
  empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
210
 
211
  // Issue TMA KV
212
- tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
213
- smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
214
- tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
215
- smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
216
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
217
  }
218
  num_total_kv_blocks += num_kv_blocks;
@@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
221
  CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
222
  }
223
  }
224
- } else if (is_umma_warp) {
225
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
226
 
227
  // Require full allocation
228
- DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
229
 
230
  // Make UMMA desc
231
  auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
@@ -252,12 +260,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
252
  #pragma unroll
253
  for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
254
  empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
255
- tcgen05_after_thread_sync();
256
  #pragma unroll
257
  for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
258
- auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
259
  smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
260
- auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
261
  smem_q[q_stage_idx], 0, k * UMMA_K);
262
  cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
263
  }
@@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
266
  }
267
  num_total_kv_blocks += num_kv_blocks;
268
 
 
 
 
 
269
  // Jump to the next block
270
  CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
271
  }
272
- } else if (warp_idx >= kNumMathThreads / 32) {
273
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
274
- } else if (warp_idx < kNumMathThreads / 32) {
275
  cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
276
 
277
  // Offsets
278
- const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
279
- const auto& warp_offset = warp_idx * 32;
280
- const auto& v_offset = lane_idx;
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- // Preload weights
283
- constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
284
- float weights[BLOCK_Q][kNumWeightsInReg];
285
- DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
286
 
287
  while (block_q_idx < num_q_blocks) {
288
  CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
@@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
293
  // Read weights
294
  #pragma unroll
295
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
296
- for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) {
297
- weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
298
- }
299
  }
300
 
301
  // Compute over KV blocks
@@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
307
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
308
 
309
  // Read per-KV scales
310
- float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset);
311
 
312
  // Wait UMMA arrival
313
  full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
314
- tcgen05_after_thread_sync();
315
 
316
  // Release KV empty
317
  empty_kv_barriers[kv_stage_idx]->arrive();
318
 
319
  // Reduce over the head dim and store
320
- const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
321
- static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
322
  DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
323
 
324
- constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
325
- DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
326
- uint32_t shifted_accum[kNumLDTMElems];
327
- auto tmem_load = [&](auto... Is) {
328
- if constexpr (kNumLDTMElems == 32) {
329
- cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
330
- } else if constexpr (kNumLDTMElems == 64) {
331
- cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
332
- } else if constexpr (kNumLDTMElems == 128) {
333
- cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
334
- }
335
- };
336
- [&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
337
- cutlass::arch::fence_view_async_tmem_load();
338
-
339
- tcgen05_before_thread_sync();
340
- empty_umma_barriers[warpgroup_idx]->arrive();
341
-
342
  #pragma unroll
343
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
344
- auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
 
 
 
 
 
 
 
 
345
 
 
346
  auto sum_0 = make_float2(0, 0);
347
  auto sum_1 = make_float2(0, 0);
348
 
349
- const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
350
  auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
351
  auto b = make_float2(weights[i][j], weights[i][j + 1]);
352
  return __ffma2_rn(a, b, sum);
353
  };
354
 
355
  #pragma unroll
356
- for (int j = 0; j < kNumWeightsInReg; j += 4) {
357
- sum_0 = transform_reg(j, sum_0);
358
- sum_1 = transform_reg(j + 2, sum_1);
359
- }
360
-
361
- const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
362
- auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
363
- auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
364
- ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
365
- return __ffma2_rn(a, b, sum);
366
- };
367
-
368
- #pragma unroll
369
- for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
370
- sum_0 = transform_smem(j, sum_0);
371
- sum_1 = transform_smem(j + 2, sum_1);
372
  }
373
 
374
  auto sum = __fadd2_rn(sum_0, sum_1);
375
- float result = scale_kv * (sum.x + sum.y);
376
 
377
  // Store into the global memory
378
- // NOTES: we have redundant writes here, consider more carefully
379
- const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
380
  if constexpr (kIsCompressedLogits) {
381
- if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
382
- logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
383
  } else {
384
- logits[q_idx * stride_logits + kv_offset + v_offset] = result;
385
  }
 
386
  }
387
  }
388
  num_total_kv_blocks += num_kv_blocks;
@@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
393
  // Jump to the next block
394
  CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
395
  }
396
- }
397
 
398
- // Free tensor memory
399
- __syncthreads();
400
- if (is_tma_load_warp)
401
- cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
 
402
  }
403
 
404
  } // namespace deep_gemm
 
6
  #include <cute/arch/cluster_sm90.hpp>
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
 
9
+ #include <deep_gemm/common/cute_tie.cuh>
10
+ #include <deep_gemm/common/math.cuh>
11
+ #include <deep_gemm/common/tma_copy.cuh>
12
  #include <deep_gemm/common/utils.cuh>
13
+ #include <deep_gemm/mma/sm100.cuh>
14
+ #include <deep_gemm/ptx/ld_st.cuh>
15
+ #include <deep_gemm/ptx/tcgen05.cuh>
16
+ #include <deep_gemm/ptx/utils.cuh>
17
 
18
  namespace deep_gemm {
19
 
 
 
 
20
  template <uint32_t kNumHeads, uint32_t kHeadDim,
21
  bool kIsCompressedLogits,
22
  uint32_t BLOCK_Q, uint32_t BLOCK_KV,
23
  uint32_t kNumQStages, uint32_t kNumKVStages,
24
+ uint32_t kNumSMs,
25
  uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
26
+ typename logits_dtype_t,
27
  uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
28
+ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
29
  void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
30
+ const uint32_t max_seqlen_k, const uint32_t stride_logits,
31
  uint32_t* cu_seq_len_k_start,
32
  uint32_t* cu_seq_len_k_end,
33
+ logits_dtype_t* logits,
34
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
35
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
36
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
 
39
  // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
40
  // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
41
  // Q should be load only at once for a block
42
+ const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
43
 
44
  // Types
45
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
46
 
47
+ // Utils
48
+ const auto sm_idx = blockIdx.x;
49
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
50
+ const auto warpgroup_idx = warp_idx / 4;
51
+ const auto lane_idx = ptx::get_lane_idx();
52
+ constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
53
 
54
  // Prefetch TMA descriptors
55
  DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
56
+ if (warp_idx == kSpecWarpStart) {
57
  cute::prefetch_tma_descriptor(&tensor_map_q);
58
  cute::prefetch_tma_descriptor(&tensor_map_kv);
59
  cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
60
  cute::prefetch_tma_descriptor(&tensor_map_weights);
61
  }
 
62
 
63
  // Shared memory configs
64
  // NOTES: weight may be unaligned
 
66
  static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
67
  static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
68
  static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
69
+ static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
70
 
71
  // Align to 512 bytes for swizzle-64B
72
  extern __shared__ __align__(512) uint8_t smem_buffer[];
 
79
  DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
80
 
81
  // Data on shared memory
82
+ auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
83
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
84
  SMEM_Q_SIZE_PER_STAGE * i);
85
  });
86
+ auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
87
  return reinterpret_cast<float*>(smem_buffer +
88
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
89
  });
90
+ auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
91
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
92
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
93
  });
94
+ auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
95
  return reinterpret_cast<float*>(smem_buffer +
96
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
97
  SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
 
99
 
100
  // TMA barriers
101
  auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
102
+ auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
103
+ auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
104
+ auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
105
+ auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
106
+ auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
107
+ auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
108
 
109
  // Tensor memory allocation
110
  auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
111
 
112
  // Initialize barriers
113
  DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
114
+ if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
 
 
115
  #pragma unroll
116
  for (uint32_t i = 0; i < kNumQStages; ++ i) {
117
  full_q_barriers[i]->init(1);
118
+ empty_q_barriers[i]->init(kNumMathThreads + 32);
119
  }
120
  #pragma unroll
121
  for (uint32_t i = 0; i < kNumKVStages; ++ i) {
122
  full_kv_barriers[i]->init(1);
123
  empty_kv_barriers[i]->init(kNumMathThreads);
124
  }
 
 
 
 
 
 
 
125
  cutlass::arch::fence_barrier_init();
126
+ }
127
+ if (warp_idx == kSpecWarpStart + 1) {
128
+ if (cute::elect_one_sync()) {
129
+ #pragma unroll
130
+ for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
131
+ full_umma_barriers[i]->init(1);
132
+ empty_umma_barriers[i]->init(128);
133
+ }
134
+ cutlass::arch::fence_barrier_init();
135
+ }
136
  // Allocate tensor memory
137
  cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
138
  }
139
  __syncthreads();
140
 
141
  // Register reconfigurations
142
+ constexpr uint32_t kNumSpecializedRegisters = 40;
143
+ constexpr uint32_t kNumMathRegisters = 232;
144
 
145
  // Block scheduler
146
+ uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
147
+ const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
148
+ return {block_q_idx + kNumSMs, q_iter_idx + 1};
149
  };
150
  uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
151
+ const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
152
  uint32_t start = cute::numeric_limits<uint32_t>::max();
153
  uint32_t end = cute::numeric_limits<uint32_t>::min();
154
 
155
  #pragma unroll
156
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
157
+ const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
158
+ seq_k_start[i] = cu_seq_len_k_start[q_idx];
159
+ seq_k_end[i] = cu_seq_len_k_end[q_idx];
160
  start = min(start, min(seq_k_start[i], seq_len_kv));
161
  end = max(end, min(seq_k_end[i], seq_len_kv));
162
  }
163
+ // TMA alignment requirements for SF KV
164
  start = start / 4 * 4;
165
  return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
166
  ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
167
+ start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
168
  };
169
 
170
  // KV pipeline
171
  uint32_t num_total_kv_blocks = 0;
172
+ const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
173
  return {
174
  (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
175
  ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
 
182
  constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
183
  constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
184
 
185
+ // Wait for primary kernel completion
186
+ cudaGridDependencySynchronize();
187
+
188
+ if (warp_idx == kSpecWarpStart) {
189
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
190
 
191
  // Prefetch
192
+ const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
193
+ tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
194
+ tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
195
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
196
  };
197
  if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
 
217
  empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
218
 
219
  // Issue TMA KV
220
+ tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
221
+ smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
222
+ tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
223
+ smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
224
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
225
  }
226
  num_total_kv_blocks += num_kv_blocks;
 
229
  CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
230
  }
231
  }
232
+ } else if (warp_idx == kSpecWarpStart + 1) {
233
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
234
 
235
  // Require full allocation
236
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
237
 
238
  // Make UMMA desc
239
  auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
 
260
  #pragma unroll
261
  for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
262
  empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
263
+ ptx::tcgen05_after_thread_sync();
264
  #pragma unroll
265
  for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
266
+ auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
267
  smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
268
+ auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
269
  smem_q[q_stage_idx], 0, k * UMMA_K);
270
  cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
271
  }
 
274
  }
275
  num_total_kv_blocks += num_kv_blocks;
276
 
277
+ // UMMA warp must also arrive on empty_q to prevent running ahead
278
+ // of math warps in the Q pipeline
279
+ empty_q_barriers[q_stage_idx]->arrive();
280
+
281
  // Jump to the next block
282
  CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
283
  }
284
+ } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
285
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
286
+ } else if (warp_idx < kSpecWarpStart) {
287
  cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
288
 
289
  // Offsets
290
+ const auto tmem_start = warpgroup_idx * UMMA_N;
291
+ const auto math_thread_idx = warp_idx * 32 + lane_idx;
292
+
293
+ // Helper lambda for loading tensor memory
294
+ auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
295
+ constexpr int N = decltype(num_elems_c)::value;
296
+ DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
297
+ using Loader = cute::conditional_t<N == 32,
298
+ cute::SM100_TMEM_LOAD_32dp32b32x,
299
+ cute::SM100_TMEM_LOAD_32dp32b64x>;
300
+ [&]<size_t... Is>(cute::index_sequence<Is...>) {
301
+ Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
302
+ }(cute::make_index_sequence<N>{});
303
+ cutlass::arch::fence_view_async_tmem_load();
304
+ };
305
 
306
+ // Local register buffers
307
+ float weights[BLOCK_Q][kNumHeads];
 
 
308
 
309
  while (block_q_idx < num_q_blocks) {
310
  CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
 
315
  // Read weights
316
  #pragma unroll
317
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
318
+ #pragma unroll
319
+ for (uint32_t j = 0; j < kNumHeads; ++ j)
320
+ weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
321
  }
322
 
323
  // Compute over KV blocks
 
329
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
330
 
331
  // Read per-KV scales
332
+ float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
333
 
334
  // Wait UMMA arrival
335
  full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
336
+ ptx::tcgen05_after_thread_sync();
337
 
338
  // Release KV empty
339
  empty_kv_barriers[kv_stage_idx]->arrive();
340
 
341
  // Reduce over the head dim and store
342
+ const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx;
 
343
  DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  #pragma unroll
346
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
347
+ // Load accumulator from TMEM
348
+ float accum[kNumHeads];
349
+ tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
350
+
351
+ // Release TMEM empty
352
+ if (i == BLOCK_Q - 1) {
353
+ ptx::tcgen05_before_thread_sync();
354
+ empty_umma_barriers[warpgroup_idx]->arrive();
355
+ }
356
 
357
+ // Accumulate weighted ReLU in parallel
358
  auto sum_0 = make_float2(0, 0);
359
  auto sum_1 = make_float2(0, 0);
360
 
361
+ const auto transform = [&](const uint32_t& j, const float2& sum) {
362
  auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
363
  auto b = make_float2(weights[i][j], weights[i][j + 1]);
364
  return __ffma2_rn(a, b, sum);
365
  };
366
 
367
  #pragma unroll
368
+ for (uint32_t j = 0; j < kNumHeads; j += 4) {
369
+ sum_0 = transform(j, sum_0);
370
+ sum_1 = transform(j + 2, sum_1);
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  }
372
 
373
  auto sum = __fadd2_rn(sum_0, sum_1);
374
+ auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
375
 
376
  // Store into the global memory
377
+ const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
 
378
  if constexpr (kIsCompressedLogits) {
379
+ if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
380
+ logits[q_offset + kv_offset - seq_k_start[i]] = result;
381
  } else {
382
+ logits[q_offset + kv_offset] = result;
383
  }
384
+ __syncwarp();
385
  }
386
  }
387
  num_total_kv_blocks += num_kv_blocks;
 
392
  // Jump to the next block
393
  CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
394
  }
 
395
 
396
+ // Free tensor memory
397
+ cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
398
+ if (warp_idx == 0)
399
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
400
+ }
401
  }
402
 
403
  } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh CHANGED
@@ -6,56 +6,65 @@
6
  #include <cute/arch/cluster_sm90.hpp>
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
 
 
 
 
9
  #include <deep_gemm/common/utils.cuh>
10
- #include <deep_gemm/common/sm90_utils.cuh>
11
- #include <deep_gemm/common/sm100_utils.cuh>
12
-
13
- #include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
 
14
 
15
  namespace deep_gemm {
16
 
17
- using namespace deep_gemm::sm90;
18
- using namespace deep_gemm::sm100;
19
-
20
  template <uint32_t kNextN, uint32_t kNumHeads,
21
  uint32_t kHeadDim, uint32_t BLOCK_KV,
22
- bool kIsContextLens2D,
23
  uint32_t kNumQStages, uint32_t kNumKVStages,
24
  uint32_t SPLIT_KV,
25
  uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
 
26
  uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
27
- __global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
28
  void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
29
- const uint64_t logits_stride, const uint64_t block_table_stride,
30
- const uint32_t* context_lens, float* logits,
31
- const uint32_t* block_table, const uint32_t* schedule_meta,
 
32
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
33
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
34
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
35
  const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
36
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
37
 
38
- // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
39
- const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
40
- const auto& warpgroup_idx = warp_idx / 4;
41
- const auto& lane_idx = get_lane_idx();
 
 
42
 
43
  // Prefetch TMA descriptors
44
  DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
45
- if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
46
  cute::prefetch_tma_descriptor(&tensor_map_q);
47
  cute::prefetch_tma_descriptor(&tensor_map_kv);
48
  cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
49
  cute::prefetch_tma_descriptor(&tensor_map_weights);
50
  }
51
- __syncwarp();
 
 
 
 
52
 
53
  // Shared memory configs
54
  static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
55
- static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
56
  static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
57
  static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
58
- static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
59
 
60
  // Align to swizzling alignment bytes
61
  extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
@@ -63,43 +72,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
63
  DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
64
 
65
  // Q and KV data on shared memory
66
- auto smem_q = PatternVisitor([&](const uint32_t& i) {
67
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
68
  });
69
- auto smem_kv = PatternVisitor([&](const uint32_t& i) {
70
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
71
  });
72
  constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
73
- auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
74
  return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
75
  });
76
- auto smem_weights = PatternVisitor([&](const uint32_t& i) {
77
  return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
78
  });
79
 
80
  // Barriers and TMEM pointer on shared memory
81
  const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
82
- auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
83
- auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
84
- auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
85
- auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
86
  const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
87
- auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
88
- auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
89
  auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
90
 
91
- constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
92
  DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
93
- const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
94
- const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
95
- const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
96
 
97
  // Initialize barriers
98
- if (is_tma_load_warp and cute::elect_one_sync()) {
99
  #pragma unroll
100
  for (uint32_t i = 0; i < kNumQStages; ++ i) {
101
  full_q_barriers[i]->init(1);
102
- empty_q_barriers[i]->init(kNumMathThreads);
103
  }
104
  #pragma unroll
105
  for (uint32_t i = 0; i < kNumKVStages; ++ i) {
@@ -108,7 +114,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
108
  }
109
  cutlass::arch::fence_barrier_init();
110
  }
111
- if (is_umma_warp) {
112
  if (cute::elect_one_sync()) {
113
  #pragma unroll
114
  for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
@@ -123,79 +129,92 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
123
  __syncthreads();
124
 
125
  // Register reconfigurations
126
- constexpr uint32_t kNumSpecializedRegisters = 40;
127
- constexpr uint32_t kNumMathRegisters = 232;
 
 
 
128
 
129
  // Scheduler
130
  constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
131
- auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
132
  DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
133
 
134
  // Q and KV pipeline
135
- const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
136
  return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
137
  };
138
- const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
139
  return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
140
  };
141
- uint32_t q_iter_idx = 0, kv_iter_idx = 0;
142
 
143
  // UMMA settings
144
  // Construct instruction with layout D
145
  constexpr uint32_t UMMA_M = 128;
146
  constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
147
- constexpr uint32_t UMMA_N = kNextN * kNumHeads;
148
  DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
149
 
150
- if (is_tma_load_warp) {
151
- // TMA warp-group for loading data
152
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
 
 
153
 
154
- const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
155
  if (cute::elect_one_sync()) {
156
- tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
157
- tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
 
158
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
159
  }
160
  };
161
 
162
- // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
163
- uint32_t q_idx = batch_size, kv_idx, num_kv;
164
- uint32_t next_q_idx, next_kv_idx, next_num_kv;
165
  bool fetched_next_task;
166
 
167
  // Prefetch the first Q
168
- if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
169
- issue_tma_q(0, next_q_idx), q_iter_idx = 1;
170
 
171
- int kv_block_idx_ptr = 32;
172
  uint32_t kv_block_idx_storage;
173
 
174
  while (fetched_next_task) {
175
- // Prefetch next Q when current Q changes
176
- bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
177
- q_idx = next_q_idx;
 
 
 
 
 
178
  kv_idx = next_kv_idx;
179
  num_kv = next_num_kv;
180
 
181
  // Read KV block index
182
- // TODO: deal with `-1`?
183
- if (kv_idx == 0 or kv_block_idx_ptr == 32) {
184
  kv_block_idx_ptr = 0;
185
- kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
 
 
186
  }
 
187
  DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
188
 
189
  // Wait Q consumer release and issue TMA Q
190
  if (prefetch_q) {
191
  CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
192
  empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
193
- issue_tma_q(q_stage_idx, q_idx + 1);
194
  }
195
 
196
- int kv_block_idx[kNumBlocksPerSplit];
197
  #pragma unroll
198
- for (int i = 0; i < kNumBlocksPerSplit; ++ i)
199
  kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
200
  kv_block_idx_ptr += kNumBlocksPerSplit;
201
 
@@ -205,45 +224,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
205
 
206
  if (cute::elect_one_sync()) {
207
  #pragma unroll
208
- for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
209
- tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
210
- smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
211
- 0, 0, 1, kv_block_idx[i]);
212
- tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
213
- smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
214
- 0, kv_block_idx[i]);
215
  }
216
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
217
  }
218
 
219
  // Fetch next task
220
- fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
221
  }
222
- } else if (is_umma_warp) {
223
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
 
 
224
 
225
  // Require full allocation
226
- DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
227
 
228
  // Make UMMA desc
229
  auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
230
  UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
231
  auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
232
 
233
- uint32_t q_idx = batch_size, kv_idx;
234
- uint32_t next_q_idx, next_kv_idx, next_num_kv;
235
  uint32_t q_stage_idx, q_phase;
236
  uint32_t umma_phase = 1;
237
 
238
- while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
239
- if (q_idx != next_q_idx) {
 
 
 
 
 
240
  CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
241
  full_q_barriers[q_stage_idx]->wait(q_phase);
242
  }
243
 
244
- q_idx = next_q_idx;
245
  kv_idx = next_kv_idx;
246
 
 
247
  CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
248
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
249
 
@@ -251,12 +278,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
251
  #pragma unroll
252
  for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
253
  empty_umma_barriers[i]->wait(umma_phase);
254
- tcgen05_after_thread_sync();
255
  #pragma unroll
256
  for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
257
- auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
258
  smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
259
- auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
260
  smem_q[q_stage_idx], 0, k * UMMA_K);
261
  cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
262
  }
@@ -264,29 +291,46 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
264
  }
265
  umma_phase ^= 1;
266
  }
267
- } else if (is_math_warp) {
268
- // Math warp-groups for WGMMA
 
 
269
  cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
 
 
270
 
271
  // Offsets
272
- const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
273
- const uint32_t thread_idx = threadIdx.x;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- // Weights
276
- constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
277
- float weights[kNextN][kNumWeightsInReg];
278
- DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
279
 
280
- // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
281
- uint32_t q_idx = batch_size, kv_idx;
282
- uint32_t next_q_idx, next_kv_idx, next_num_kv;
283
  uint32_t q_stage_idx, q_phase;
284
  uint32_t umma_phase = 0;
 
285
 
286
- while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
287
- // Current Q changes
288
- if (q_idx != next_q_idx) {
289
- // Release Last Q empty
290
  if (q_iter_idx > 0)
291
  empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
292
 
@@ -296,30 +340,34 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
296
 
297
  // Read weights
298
  #pragma unroll
299
- for (uint32_t i = 0; i < kNextN; ++ i) {
300
- for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
301
- weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
 
 
 
 
 
302
  }
303
  }
304
 
305
- // Get current Q and KV index
306
- q_idx = next_q_idx;
307
  kv_idx = next_kv_idx;
308
 
309
  // Calculate KV offset in advance
310
- auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
311
 
312
- // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
313
  // Wait TMA KV arrival
314
  CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
315
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
316
 
317
  // Read per-KV scales
318
- float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
319
 
320
  // Wait UMMA arrival
321
- full_umma_barriers[warpgroup_idx]->wait(umma_phase);
322
- tcgen05_after_thread_sync();
323
  umma_phase ^= 1;
324
 
325
  // Release KV empty
@@ -327,72 +375,65 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
327
 
328
  // Reduce over the head dim and store
329
  DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
330
- constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
331
- uint32_t shifted_accum[kNumLDTMElems];
332
- DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
333
- auto tmem_load = [&](auto... Is) {
334
- if constexpr (kNumLDTMElems == 32) {
335
- cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
336
- } else if constexpr (kNumLDTMElems == 64) {
337
- cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
338
- } else if constexpr (kNumLDTMElems == 128) {
339
- cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
340
- }
341
- };
342
- [&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
343
- cutlass::arch::fence_view_async_tmem_load();
344
-
345
- tcgen05_before_thread_sync();
346
- empty_umma_barriers[warpgroup_idx]->arrive();
347
-
348
- #pragma unroll
349
- for (uint32_t i = 0; i < kNextN; ++ i) {
350
- auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
351
-
352
- auto sum_0 = make_float2(0, 0);
353
- auto sum_1 = make_float2(0, 0);
354
 
355
- const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
356
- auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
357
- auto b = make_float2(weights[i][j], weights[i][j + 1]);
358
- return __ffma2_rn(a, b, sum);
359
- };
360
 
361
  #pragma unroll
362
- for (int j = 0; j < kNumWeightsInReg; j += 4) {
363
- sum_0 = transform_reg(j, sum_0);
364
- sum_1 = transform_reg(j + 2, sum_1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  }
366
 
367
- const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
368
- auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
369
- auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
370
- ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
371
- return __ffma2_rn(a, b, sum);
372
- };
373
-
374
- #pragma unroll
375
- for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
376
- sum_0 = transform_smem(j, sum_0);
377
- sum_1 = transform_smem(j + 2, sum_1);
378
- }
379
-
380
- auto sum = __fadd2_rn(sum_0, sum_1);
381
- float result = scale_kv * (sum.x + sum.y);
382
 
383
- // Store into the global memory
384
- // NOTES: we have redundant writes here, consider more carefully
385
- logits[kv_offset + i * logits_stride + thread_idx] = result;
 
 
 
 
 
 
 
 
 
386
  }
387
  }
388
- } else {
389
- cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
390
- }
391
 
392
- // Free tensor memory
393
- __syncthreads();
394
- if (is_umma_warp)
395
- cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
 
396
  }
397
 
398
  } // namespace deep_gemm
 
6
  #include <cute/arch/cluster_sm90.hpp>
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
 
9
+ #include <deep_gemm/common/cute_tie.cuh>
10
+ #include <deep_gemm/common/math.cuh>
11
+ #include <deep_gemm/common/tma_copy.cuh>
12
  #include <deep_gemm/common/utils.cuh>
13
+ #include <deep_gemm/mma/sm100.cuh>
14
+ #include <deep_gemm/ptx/ld_st.cuh>
15
+ #include <deep_gemm/ptx/tcgen05.cuh>
16
+ #include <deep_gemm/ptx/utils.cuh>
17
+ #include <deep_gemm/scheduler/paged_mqa_logits.cuh>
18
 
19
  namespace deep_gemm {
20
 
 
 
 
21
  template <uint32_t kNextN, uint32_t kNumHeads,
22
  uint32_t kHeadDim, uint32_t BLOCK_KV,
23
+ bool kIsContextLens2D, bool kIsVarlen,
24
  uint32_t kNumQStages, uint32_t kNumKVStages,
25
  uint32_t SPLIT_KV,
26
  uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
27
+ typename logits_dtype_t,
28
  uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
29
+ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
30
  void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
31
+ const uint32_t logits_stride, const uint32_t block_table_stride,
32
+ const uint32_t* context_lens, logits_dtype_t* logits,
33
+ const uint32_t* block_table, const uint32_t* indices,
34
+ const uint32_t* schedule_meta,
35
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
36
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
37
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
38
  const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
39
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
40
 
41
+ // Utils
42
+ const auto sm_idx = blockIdx.x;
43
+ const auto warp_idx = cutlass::canonical_warp_idx_sync();
44
+ const auto warpgroup_idx = warp_idx / 4;
45
+ const auto lane_idx = ptx::get_lane_idx();
46
+ constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
47
 
48
  // Prefetch TMA descriptors
49
  DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
50
+ if (warp_idx == kSpecWarpStart) {
51
  cute::prefetch_tma_descriptor(&tensor_map_q);
52
  cute::prefetch_tma_descriptor(&tensor_map_kv);
53
  cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
54
  cute::prefetch_tma_descriptor(&tensor_map_weights);
55
  }
56
+
57
+ // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
58
+ static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
59
+ static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
60
+ static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
61
 
62
  // Shared memory configs
63
  static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
64
+ static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
65
  static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
66
  static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
67
+ static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
68
 
69
  // Align to swizzling alignment bytes
70
  extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
 
72
  DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
73
 
74
  // Q and KV data on shared memory
75
+ auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
76
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
77
  });
78
+ auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
79
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
80
  });
81
  constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
82
+ auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
83
  return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
84
  });
85
+ auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
86
  return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
87
  });
88
 
89
  // Barriers and TMEM pointer on shared memory
90
  const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
91
+ auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
92
+ auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
93
+ auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
94
+ auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
95
  const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
96
+ auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
97
+ auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
98
  auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
99
 
100
+ constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups;
101
  DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
 
 
 
102
 
103
  // Initialize barriers
104
+ if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
105
  #pragma unroll
106
  for (uint32_t i = 0; i < kNumQStages; ++ i) {
107
  full_q_barriers[i]->init(1);
108
+ empty_q_barriers[i]->init(kNumMathThreads + 32);
109
  }
110
  #pragma unroll
111
  for (uint32_t i = 0; i < kNumKVStages; ++ i) {
 
114
  }
115
  cutlass::arch::fence_barrier_init();
116
  }
117
+ if (warp_idx == kSpecWarpStart + 1) {
118
  if (cute::elect_one_sync()) {
119
  #pragma unroll
120
  for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
 
129
  __syncthreads();
130
 
131
  // Register reconfigurations
132
+ constexpr uint32_t kNumSpecializedRegisters = 56;
133
+ constexpr uint32_t kNumMathRegisters = 224;
134
+
135
+ // Wait for primary kernel completion
136
+ cudaGridDependencySynchronize();
137
 
138
  // Scheduler
139
  constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
140
+ using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
141
  DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
142
 
143
  // Q and KV pipeline
144
+ const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
145
  return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
146
  };
147
+ const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
148
  return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
149
  };
 
150
 
151
  // UMMA settings
152
  // Construct instruction with layout D
153
  constexpr uint32_t UMMA_M = 128;
154
  constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
155
+ constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
156
  DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
157
 
158
+ if (warp_idx == kSpecWarpStart) {
159
+ // TMA warp for loading data
160
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
161
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
162
+ uint32_t q_iter_idx = 0, kv_iter_idx = 0;
163
 
164
+ const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
165
  if (cute::elect_one_sync()) {
166
+ const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
167
+ tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
168
+ tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
169
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
170
  }
171
  };
172
 
173
+ // Initialize outside valid range to indicate no previous task
174
+ uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv;
175
+ uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
176
  bool fetched_next_task;
177
 
178
  // Prefetch the first Q
179
+ if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)))
180
+ issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1;
181
 
182
+ uint32_t kv_block_idx_ptr = 32;
183
  uint32_t kv_block_idx_storage;
184
 
185
  while (fetched_next_task) {
186
+ // Prefetch next Q when (q, atom) changes
187
+ const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
188
+ bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
189
+
190
+ if (q_atom_idx != next_q_atom_idx)
191
+ kv_block_idx_ptr = 32;
192
+
193
+ q_atom_idx = next_q_atom_idx;
194
  kv_idx = next_kv_idx;
195
  num_kv = next_num_kv;
196
 
197
  // Read KV block index
198
+ // TODO(xuzhean): consider -1
199
+ if (kv_block_idx_ptr == 32) {
200
  kv_block_idx_ptr = 0;
201
+ const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
202
+ kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
203
+ ? block_table[block_table_offset + kv_idx + lane_idx] : 0;
204
  }
205
+ __syncwarp();
206
  DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
207
 
208
  // Wait Q consumer release and issue TMA Q
209
  if (prefetch_q) {
210
  CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
211
  empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
212
+ issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
213
  }
214
 
215
+ uint32_t kv_block_idx[kNumBlocksPerSplit];
216
  #pragma unroll
217
+ for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i)
218
  kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
219
  kv_block_idx_ptr += kNumBlocksPerSplit;
220
 
 
224
 
225
  if (cute::elect_one_sync()) {
226
  #pragma unroll
227
+ for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) {
228
+ tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
229
+ smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
230
+ 0, 0, 1, kv_block_idx[i]);
231
+ tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
232
+ smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
233
+ 0, kv_block_idx[i]);
234
  }
235
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
236
  }
237
 
238
  // Fetch next task
239
+ fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv);
240
  }
241
+ } else if (warp_idx == kSpecWarpStart + 1) {
242
  cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
243
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
244
+ uint32_t q_iter_idx = 0, kv_iter_idx = 0;
245
 
246
  // Require full allocation
247
+ DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
248
 
249
  // Make UMMA desc
250
  auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
251
  UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
252
  auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
253
 
254
+ uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
255
+ uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
256
  uint32_t q_stage_idx, q_phase;
257
  uint32_t umma_phase = 1;
258
 
259
+ while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
260
+ if (q_atom_idx != next_q_atom_idx) {
261
+ // Release previous Q empty (UMMA warp must participate to prevent
262
+ // running ahead of math warps in the Q pipeline)
263
+ if (q_iter_idx > 0)
264
+ empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
265
+
266
  CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
267
  full_q_barriers[q_stage_idx]->wait(q_phase);
268
  }
269
 
270
+ q_atom_idx = next_q_atom_idx;
271
  kv_idx = next_kv_idx;
272
 
273
+ // Wait KV arrival
274
  CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
275
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
276
 
 
278
  #pragma unroll
279
  for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
280
  empty_umma_barriers[i]->wait(umma_phase);
281
+ ptx::tcgen05_after_thread_sync();
282
  #pragma unroll
283
  for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
284
+ auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
285
  smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
286
+ auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
287
  smem_q[q_stage_idx], 0, k * UMMA_K);
288
  cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
289
  }
 
291
  }
292
  umma_phase ^= 1;
293
  }
294
+ } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
295
+ cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
296
+ } else if (warp_idx < kSpecWarpStart) {
297
+ // Math warpgroups for reduce
298
  cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
299
+ auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
300
+ uint32_t q_iter_idx = 0, kv_iter_idx = 0;
301
 
302
  // Offsets
303
+ const auto math_warpgroup_idx = warpgroup_idx;
304
+ const auto tmem_start = math_warpgroup_idx * UMMA_N;
305
+ const auto math_thread_idx = warp_idx * 32 + lane_idx;
306
+
307
+ // Helper lambda for loading tensor memory
308
+ auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
309
+ constexpr int N = decltype(num_elems_c)::value;
310
+ DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
311
+ using Loader = cute::conditional_t<N == 32,
312
+ cute::SM100_TMEM_LOAD_32dp32b32x,
313
+ cute::SM100_TMEM_LOAD_32dp32b64x>;
314
+ [&]<size_t... Is>(cute::index_sequence<Is...>) {
315
+ Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
316
+ }(cute::make_index_sequence<N>{});
317
+ cutlass::arch::fence_view_async_tmem_load();
318
+ };
319
 
320
+ // Local register buffers
321
+ float weights[kNextNAtom][kNumHeads];
 
 
322
 
323
+ // Initialize outside valid range to indicate no previous task
324
+ uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
325
+ uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
326
  uint32_t q_stage_idx, q_phase;
327
  uint32_t umma_phase = 0;
328
+ bool is_paired_atom = false;
329
 
330
+ while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
331
+ // Q or atom changes
332
+ if (q_atom_idx != next_q_atom_idx) {
333
+ // Release last Q empty
334
  if (q_iter_idx > 0)
335
  empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
336
 
 
340
 
341
  // Read weights
342
  #pragma unroll
343
+ for (uint32_t i = 0; i < kNextNAtom; ++ i) {
344
+ #pragma unroll
345
+ for (uint32_t j = 0; j < kNumHeads; ++ j)
346
+ weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
347
+ }
348
+
349
+ if constexpr (kIsVarlen) {
350
+ is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
351
  }
352
  }
353
 
354
+ // Get current task indices
355
+ q_atom_idx = next_q_atom_idx;
356
  kv_idx = next_kv_idx;
357
 
358
  // Calculate KV offset in advance
359
+ auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
360
 
 
361
  // Wait TMA KV arrival
362
  CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
363
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
364
 
365
  // Read per-KV scales
366
+ float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
367
 
368
  // Wait UMMA arrival
369
+ full_umma_barriers[math_warpgroup_idx]->wait(umma_phase);
370
+ ptx::tcgen05_after_thread_sync();
371
  umma_phase ^= 1;
372
 
373
  // Release KV empty
 
375
 
376
  // Reduce over the head dim and store
377
  DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
379
+ const auto reduce_and_store = [&](auto num_iters_c) {
380
+ constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
381
+ float accum[kNumHeads];
 
 
382
 
383
  #pragma unroll
384
+ for (uint32_t i = 0; i < kNumIters; ++ i) {
385
+ // Load accumulator from TMEM
386
+ tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
387
+
388
+ // Accumulate weighted ReLU in parallel
389
+ auto sum_0 = make_float2(0, 0);
390
+ auto sum_1 = make_float2(0, 0);
391
+
392
+ const auto transform = [&](const uint32_t& j, const float2& sum) {
393
+ auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
394
+ auto b = make_float2(weights[i][j], weights[i][j + 1]);
395
+ return __ffma2_rn(a, b, sum);
396
+ };
397
+
398
+ #pragma unroll
399
+ for (uint32_t j = 0; j < kNumHeads; j += 4) {
400
+ sum_0 = transform(j, sum_0);
401
+ sum_1 = transform(j + 2, sum_1);
402
+ }
403
+
404
+ auto sum = __fadd2_rn(sum_0, sum_1);
405
+ auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
406
+
407
+ // Store into the global memory
408
+ logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
409
+ __syncwarp();
410
  }
411
 
412
+ // Release TMEM empty
413
+ ptx::tcgen05_before_thread_sync();
414
+ empty_umma_barriers[math_warpgroup_idx]->arrive();
415
+ };
 
 
 
 
 
 
 
 
 
 
 
416
 
417
+ if constexpr (kIsVarlen) {
418
+ if (is_paired_atom)
419
+ reduce_and_store(cute::Int<kNextNAtom>{});
420
+ else
421
+ reduce_and_store(cute::Int<1>{});
422
+ } else if constexpr (kPadOddN) {
423
+ if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
424
+ reduce_and_store(cute::Int<1>{});
425
+ else
426
+ reduce_and_store(cute::Int<kNextNAtom>{});
427
+ } else {
428
+ reduce_and_store(cute::Int<kNextNAtom>{});
429
  }
430
  }
 
 
 
431
 
432
+ // Free tensor memory
433
+ cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
434
+ if (warp_idx == 0)
435
+ cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
436
+ }
437
  }
438
 
439
  } // namespace deep_gemm
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh CHANGED
@@ -4,20 +4,22 @@
4
 
5
  #include <cutlass/arch/barrier.h>
6
 
7
- #include <deep_gemm/common/reduction.cuh>
 
 
8
  #include <deep_gemm/common/utils.cuh>
9
- #include <deep_gemm/common/sm90_utils.cuh>
10
- #include <deep_gemm/common/sm100_utils.cuh>
 
 
11
 
12
  namespace deep_gemm {
13
 
14
- using namespace deep_gemm::sm100;
15
-
16
  template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
17
- __device__ __forceinline__
18
  uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
19
  // Calculate the index of the bank group to be written in the atom
20
- const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
21
 
22
  // Reshape the atom in another view and swizzle
23
  // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
@@ -37,7 +39,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
37
  uint32_t kSwizzleCDMode,
38
  uint32_t kNumStages,
39
  uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
40
- __global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
41
  sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
42
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
43
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
58
 
59
  // Utils
60
  const auto warp_idx = cutlass::canonical_warp_idx_sync();
61
- const auto lane_idx = get_lane_idx();
62
 
63
  // Align to 1024 bytes for swizzle-128B
64
  extern __shared__ __align__(1024) uint8_t smem_buffer[];
@@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
70
  DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
71
 
72
  // Real tensor memory size and offsets
73
- constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
74
 
75
  // Prefetch TMA descriptors at the very beginning
76
  if (warp_idx == 0 and cute::elect_one_sync()) {
@@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
82
  // Data on shared memory (layout as ordered below)
83
  // Fill D/A/B pointers
84
  auto smem_cd = reinterpret_cast<float*>(smem_buffer);
85
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
86
  return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
87
  });
88
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
89
  return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
90
  });
91
 
92
  // Fill barriers
93
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
94
  kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
95
- auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
96
- auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
97
- auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
98
- auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
99
  auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
100
 
101
  // Fill the tensor memory pointer
@@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
121
  }
122
  __syncthreads();
123
 
124
- constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
125
  constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
126
  constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
127
  const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
@@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
131
  const uint32_t m_offset = shape_m * k_split_idx;
132
  const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
133
 
 
 
 
134
  // Dispatch warps into different roles
135
  if (warp_idx < kNumMMAThreads / 32) {
136
  // TMA load warp
@@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
145
  uint32_t k_idx = k_offset + s * BLOCK_K;
146
 
147
  // Issue TMAs
148
- tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
149
- tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
150
 
151
  // Arrive at full barriers
152
  constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
@@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
168
  const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
169
 
170
  DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
171
- auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
172
  const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
173
 
174
  // Checks for MMA instructions
@@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
185
  const auto& stage_idx = s % kNumStages;
186
  const auto& cast_stage_idx = s % kNumCastStages;
187
  full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
188
- tcgen05_after_thread_sync();
189
 
190
  // Issue UMMA
191
  const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
@@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
194
  const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
195
  const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
196
  const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
197
- b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
198
  umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
199
  }
200
 
@@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
218
 
219
  // Wait UMMA arrival
220
  tmem_full_barrier->wait(0);
221
- tcgen05_after_thread_sync();
222
 
223
  // Load from tensor memory into registers, and write shared memory with STSM
224
  DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
@@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
239
  values[0], values[1], values[2], values[3]);
240
  cutlass::arch::fence_view_async_tmem_load();
241
  if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
242
- st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
243
  if constexpr (BLOCK_M == 64)
244
  __syncwarp();
245
  }
@@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
290
  #pragma unroll
291
  for (uint32_t i = 0; i < kNumLoads; i += 2) {
292
  auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
293
- sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
294
- uint32_values[0][i + 1], uint32_values[1][i + 1],
295
- smem_ptr);
296
  }
297
 
298
  // Wait tensor memory empty
@@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
321
  cutlass::arch::fence_view_async_tmem_store();
322
 
323
  // Arrive for issuing MMAs
324
- tcgen05_before_thread_sync();
325
  full_cast_barriers[cast_stage_idx]->arrive();
326
  }
327
 
328
  // Intra-warp reduction and write back
329
  #pragma unroll
330
  for (uint32_t u = 0; u < 2; ++ u) {
331
- const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
332
- const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
333
  if (lane_idx % 4 == 0 and m_idx < shape_m)
334
  sqr_sum[m_offset + m_idx] = reduced_sum;
335
  }
 
4
 
5
  #include <cutlass/arch/barrier.h>
6
 
7
+ #include <deep_gemm/common/cute_tie.cuh>
8
+ #include <deep_gemm/common/math.cuh>
9
+ #include <deep_gemm/common/tma_copy.cuh>
10
  #include <deep_gemm/common/utils.cuh>
11
+ #include <deep_gemm/mma/sm100.cuh>
12
+ #include <deep_gemm/ptx/ld_st.cuh>
13
+ #include <deep_gemm/ptx/tcgen05.cuh>
14
+ #include <deep_gemm/ptx/utils.cuh>
15
 
16
  namespace deep_gemm {
17
 
 
 
18
  template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
19
+ CUTLASS_DEVICE
20
  uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
21
  // Calculate the index of the bank group to be written in the atom
22
+ const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
23
 
24
  // Reshape the atom in another view and swizzle
25
  // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
 
39
  uint32_t kSwizzleCDMode,
40
  uint32_t kNumStages,
41
  uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
42
+ CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
43
  sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
44
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
45
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
 
60
 
61
  // Utils
62
  const auto warp_idx = cutlass::canonical_warp_idx_sync();
63
+ const auto lane_idx = ptx::get_lane_idx();
64
 
65
  // Align to 1024 bytes for swizzle-128B
66
  extern __shared__ __align__(1024) uint8_t smem_buffer[];
 
72
  DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
73
 
74
  // Real tensor memory size and offsets
75
+ constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
76
 
77
  // Prefetch TMA descriptors at the very beginning
78
  if (warp_idx == 0 and cute::elect_one_sync()) {
 
84
  // Data on shared memory (layout as ordered below)
85
  // Fill D/A/B pointers
86
  auto smem_cd = reinterpret_cast<float*>(smem_buffer);
87
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
88
  return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
89
  });
90
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
91
  return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
92
  });
93
 
94
  // Fill barriers
95
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
96
  kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
97
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
98
+ auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
99
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
100
+ auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
101
  auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
102
 
103
  // Fill the tensor memory pointer
 
123
  }
124
  __syncthreads();
125
 
126
+ constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
127
  constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
128
  constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
129
  const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
 
133
  const uint32_t m_offset = shape_m * k_split_idx;
134
  const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
135
 
136
+ // Wait for primary kernel completion
137
+ cudaGridDependencySynchronize();
138
+
139
  // Dispatch warps into different roles
140
  if (warp_idx < kNumMMAThreads / 32) {
141
  // TMA load warp
 
150
  uint32_t k_idx = k_offset + s * BLOCK_K;
151
 
152
  // Issue TMAs
153
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
154
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
155
 
156
  // Arrive at full barriers
157
  constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
 
173
  const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
174
 
175
  DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
176
+ auto b_desc = mma::sm100::make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
177
  const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
178
 
179
  // Checks for MMA instructions
 
190
  const auto& stage_idx = s % kNumStages;
191
  const auto& cast_stage_idx = s % kNumCastStages;
192
  full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
193
+ ptx::tcgen05_after_thread_sync();
194
 
195
  // Issue UMMA
196
  const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
 
199
  const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
200
  const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
201
  const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
202
+ b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
203
  umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
204
  }
205
 
 
223
 
224
  // Wait UMMA arrival
225
  tmem_full_barrier->wait(0);
226
+ ptx::tcgen05_after_thread_sync();
227
 
228
  // Load from tensor memory into registers, and write shared memory with STSM
229
  DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
 
244
  values[0], values[1], values[2], values[3]);
245
  cutlass::arch::fence_view_async_tmem_load();
246
  if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
247
+ ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
248
  if constexpr (BLOCK_M == 64)
249
  __syncwarp();
250
  }
 
295
  #pragma unroll
296
  for (uint32_t i = 0; i < kNumLoads; i += 2) {
297
  auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
298
+ ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
299
+ uint32_values[0][i + 1], uint32_values[1][i + 1],
300
+ smem_ptr);
301
  }
302
 
303
  // Wait tensor memory empty
 
326
  cutlass::arch::fence_view_async_tmem_store();
327
 
328
  // Arrive for issuing MMAs
329
+ ptx::tcgen05_before_thread_sync();
330
  full_cast_barriers[cast_stage_idx]->arrive();
331
  }
332
 
333
  // Intra-warp reduction and write back
334
  #pragma unroll
335
  for (uint32_t u = 0; u < 2; ++ u) {
336
+ const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y);
337
+ const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
338
  if (lane_idx % 4 == 0 and m_idx < shape_m)
339
  sqr_sum[m_offset + m_idx] = reduced_sum;
340
  }
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh CHANGED
@@ -11,14 +11,19 @@
11
  #include <cute/arch/copy_sm90_tma.hpp>
12
  #include <cute/arch/mma_sm100_desc.hpp>
13
 
 
14
  #include <deep_gemm/common/utils.cuh>
15
- #include <deep_gemm/common/scheduler.cuh>
16
- #include <deep_gemm/common/sm90_utils.cuh>
 
 
 
 
 
 
17
 
18
  namespace deep_gemm {
19
 
20
- using namespace deep_gemm::sm90;
21
-
22
  template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
23
  uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
24
  uint32_t kNumGroups,
@@ -30,7 +35,7 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
30
  uint32_t kNumSMs,
31
  GemmType kGemmType, bool kWithAccumulation,
32
  typename cd_dtype_t>
33
- __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
34
  sm90_bf16_gemm_impl(int* grouped_layout,
35
  uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
36
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
51
  constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
52
 
53
  // Types
54
- using WGMMA = typename BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
55
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
56
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
57
 
@@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
61
  shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
62
 
63
  // Shared memory
64
- static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
65
  static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
66
  static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
67
 
@@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
71
 
72
  // Configs
73
  const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
74
- const uint32_t lane_idx = get_lane_idx();
75
 
76
  // Prefetch TMA descriptors at the very beginning
77
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout,
88
 
89
  // D/A/B shared memory
90
  auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
91
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
92
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
93
  });
94
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
95
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
96
  });
97
 
98
  // Fill barriers
99
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
100
- auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
101
- auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
102
 
103
  // Initialize barriers
104
  if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
@@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout,
119
  constexpr uint32_t kNumTMARegisters = 48;
120
  constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
121
 
 
 
 
122
  // Block scheduler
123
  uint32_t m_block_idx, n_block_idx;
124
- auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
125
 
126
  // Pipeline and TMA phases
127
  uint32_t stage_idx = 0, phase = 0;
@@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
151
  const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
152
  DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
153
 
154
- const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
155
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
156
  // Wait consumer release
157
  empty_barriers[stage_idx]->wait(phase ^ 1);
@@ -159,31 +167,30 @@ sm90_bf16_gemm_impl(int* grouped_layout,
159
  constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
160
  auto& full_barrier = *full_barriers[stage_idx];
161
 
162
- const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
163
- const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
164
 
165
  DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
166
- uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
167
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
168
- uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
169
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
170
 
171
  // Issue TMAs
172
  constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
173
  const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
174
  if constexpr (kMajorA == cute::UMMA::Major::K)
175
- tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
176
  &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
177
  if constexpr (kMajorA == cute::UMMA::Major::MN)
178
- tma_copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
179
  &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
180
  if constexpr (kMajorB == cute::UMMA::Major::K)
181
- tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
182
  &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
183
  if constexpr (kMajorB == cute::UMMA::Major::MN)
184
- tma_copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
185
  &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
186
-
187
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
188
  }
189
  }
@@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
203
 
204
  // Merged stages only happens in NT normal GEMM cases
205
  constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
206
- auto a_desc = make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
207
- auto b_desc = make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
208
  const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
209
  const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
210
 
@@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
229
  };
230
 
231
  // TODO: remove some useless computation for unaligned Ms
232
- const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
233
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
234
- const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
235
- const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
236
 
237
  // Wait TMA arrivals
238
  full_barriers[stage_idx]->wait(phase);
@@ -240,26 +247,26 @@ sm90_bf16_gemm_impl(int* grouped_layout,
240
  // Commit WGMMA instructions
241
  #pragma unroll
242
  for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
243
- warpgroup_fence_operand(accum[i]);
244
- warpgroup_arrive();
245
  #pragma unroll
246
  for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
247
  auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
248
  #pragma unroll
249
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
250
- const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
251
- a_desc.reg32_[0] = advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
252
  a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
253
- b_desc.reg32_[0] = advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
254
  b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
255
  WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
256
  }
257
  }
258
- warpgroup_commit_batch();
259
  #pragma unroll
260
  for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
261
- warpgroup_fence_operand(accum[i]);
262
- warpgroup_wait<0>();
263
 
264
  // Notify barrier arrival
265
  empty_barrier_arrive(stage_idx);
@@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
324
  }
325
 
326
  // NOTES: only 16 lanes' addresses are used
327
- SM90_U32x2_STSM_N<nv_bfloat162>::copy(
328
  __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
329
  __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
330
  smem_ptr
@@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
341
  auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
342
  #pragma unroll
343
  for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
344
- st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
345
- st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
346
  }
347
  }
348
  }
@@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
350
  cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
351
 
352
  // Use TMA store to write back to global memory
353
- const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
354
  DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
355
  if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
356
  auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
 
11
  #include <cute/arch/copy_sm90_tma.hpp>
12
  #include <cute/arch/mma_sm100_desc.hpp>
13
 
14
+ #include <deep_gemm/common/math.cuh>
15
  #include <deep_gemm/common/utils.cuh>
16
+ #include <deep_gemm/common/tma_copy.cuh>
17
+ #include <deep_gemm/common/types.cuh>
18
+ #include <deep_gemm/mma/sm90.cuh>
19
+ #include <deep_gemm/epilogue/transform.cuh>
20
+ #include <deep_gemm/ptx/ld_st.cuh>
21
+ #include <deep_gemm/ptx/utils.cuh>
22
+ #include <deep_gemm/ptx/wgmma.cuh>
23
+ #include <deep_gemm/scheduler/gemm.cuh>
24
 
25
  namespace deep_gemm {
26
 
 
 
27
  template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
28
  uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
29
  uint32_t kNumGroups,
 
35
  uint32_t kNumSMs,
36
  GemmType kGemmType, bool kWithAccumulation,
37
  typename cd_dtype_t>
38
+ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
39
  sm90_bf16_gemm_impl(int* grouped_layout,
40
  uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
41
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
 
56
  constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
57
 
58
  // Types
59
+ using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
60
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
61
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
62
 
 
66
  shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
67
 
68
  // Shared memory
69
+ static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
70
  static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
71
  static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
72
 
 
76
 
77
  // Configs
78
  const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
79
+ const uint32_t lane_idx = ptx::get_lane_idx();
80
 
81
  // Prefetch TMA descriptors at the very beginning
82
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
 
93
 
94
  // D/A/B shared memory
95
  auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
96
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
97
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
98
  });
99
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
100
  return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
101
  });
102
 
103
  // Fill barriers
104
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
105
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
106
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
107
 
108
  // Initialize barriers
109
  if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
 
124
  constexpr uint32_t kNumTMARegisters = 48;
125
  constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
126
 
127
+ // Wait for primary kernel completion
128
+ cudaGridDependencySynchronize();
129
+
130
  // Block scheduler
131
  uint32_t m_block_idx, n_block_idx;
132
+ auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
133
 
134
  // Pipeline and TMA phases
135
  uint32_t stage_idx = 0, phase = 0;
 
159
  const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
160
  DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
161
 
162
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
163
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
164
  // Wait consumer release
165
  empty_barriers[stage_idx]->wait(phase ^ 1);
 
167
  constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
168
  auto& full_barrier = *full_barriers[stage_idx];
169
 
170
+ const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
171
+ const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
172
 
173
  DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
174
+ uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
175
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
176
+ uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
177
  shape_k, BLOCK_K, k_block_idx, m_block_idx);
178
 
179
  // Issue TMAs
180
  constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
181
  const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
182
  if constexpr (kMajorA == cute::UMMA::Major::K)
183
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
184
  &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
185
  if constexpr (kMajorA == cute::UMMA::Major::MN)
186
+ tma::copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
187
  &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
188
  if constexpr (kMajorB == cute::UMMA::Major::K)
189
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
190
  &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
191
  if constexpr (kMajorB == cute::UMMA::Major::MN)
192
+ tma::copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
193
  &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
 
194
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
195
  }
196
  }
 
210
 
211
  // Merged stages only happens in NT normal GEMM cases
212
  constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
213
+ auto a_desc = mma::sm90::make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
214
+ auto b_desc = mma::sm90::make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
215
  const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
216
  const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
217
 
 
236
  };
237
 
238
  // TODO: remove some useless computation for unaligned Ms
239
+ const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
240
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
241
+ const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
242
+ const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
243
 
244
  // Wait TMA arrivals
245
  full_barriers[stage_idx]->wait(phase);
 
247
  // Commit WGMMA instructions
248
  #pragma unroll
249
  for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
250
+ ptx::warpgroup_fence_operand(accum[i]);
251
+ ptx::warpgroup_arrive();
252
  #pragma unroll
253
  for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
254
  auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
255
  #pragma unroll
256
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
257
+ const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
258
+ a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
259
  a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
260
+ b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
261
  b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
262
  WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
263
  }
264
  }
265
+ ptx::warpgroup_commit_batch();
266
  #pragma unroll
267
  for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
268
+ ptx::warpgroup_fence_operand(accum[i]);
269
+ ptx::warpgroup_wait<0>();
270
 
271
  // Notify barrier arrival
272
  empty_barrier_arrive(stage_idx);
 
331
  }
332
 
333
  // NOTES: only 16 lanes' addresses are used
334
+ ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
335
  __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
336
  __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
337
  smem_ptr
 
348
  auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
349
  #pragma unroll
350
  for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
351
+ ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
352
+ ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
353
  }
354
  }
355
  }
 
357
  cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
358
 
359
  // Use TMA store to write back to global memory
360
+ const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
361
  DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
362
  if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
363
  auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh CHANGED
@@ -4,26 +4,32 @@
4
  #include <cutlass/arch/barrier.h>
5
  #include <cutlass/arch/reg_reconfig.h>
6
 
 
7
  #include <deep_gemm/common/utils.cuh>
8
- #include <deep_gemm/common/sm90_utils.cuh>
 
 
 
 
 
 
 
9
 
10
  namespace deep_gemm {
11
 
12
- using namespace deep_gemm::sm90;
13
-
14
  template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
15
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
16
  uint32_t kSplitFactor,
17
  uint32_t kNumStages,
18
  uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
19
- __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
20
  sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
21
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
22
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
23
  float *d) {
24
  #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
25
  // Types
26
- using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
27
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
28
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
29
 
@@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
33
 
34
  // Configs
35
  const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
36
- const uint32_t lane_idx = get_lane_idx();
37
  DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
38
  DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
39
  DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
@@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
48
  // Align to 1024 bytes for swizzle-128B
49
  // Fill shared memory pointers
50
  extern __shared__ __align__(1024) uint8_t smem_buffer[];
51
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
52
  return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
53
  });
54
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
55
  return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
56
  });
57
 
58
  // Fill barriers
59
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
60
- auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
61
- auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
62
 
63
  // Initialize barriers
64
  if (warp_idx == 1 and cute::elect_one_sync()) {
@@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
80
  constexpr uint32_t kNumMathRegisters = 232;
81
 
82
  // Block indices
83
- const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
84
- const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
85
  const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
86
  const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
87
  const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
88
  const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
89
  const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
90
 
 
 
 
91
  if (warp_idx >= kNumMathThreads / 32) {
92
  // TMA warp-group for loading data
93
  cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
@@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
98
  #pragma unroll
99
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
100
  // Wait consumer release
101
- const auto& stage_idx = s % kNumStages;
102
  empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
103
 
104
  auto& full_barrier = *full_barriers[stage_idx];
105
- const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
106
- const uint32_t& k_idx = sk_idx % SHAPE_K;
107
- const uint32_t& s_idx = sk_idx / SHAPE_K;
108
 
109
  constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
110
- tma_copy<BLOCK_K, BLOCK_M, kSwizzle>(
111
  &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
112
- tma_copy<BLOCK_K, BLOCK_N, kSwizzle>(
113
  &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
114
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
115
  }
@@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
125
  // Launch MMAs
126
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
127
  // Wait TMA arrivals
128
- const auto& stage_idx = s % kNumStages;
129
  full_barriers[stage_idx]->wait((s / kNumStages) & 1);
130
 
131
  // Commit WGMMA instructions
132
  #pragma unroll
133
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
134
- warpgroup_fence_operand(accum[i]);
135
- warpgroup_arrive();
136
  #pragma unroll
137
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
138
- auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
139
- auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
140
  WGMMA::wgmma(desc_a, desc_b, accum, 1);
141
  }
142
- warpgroup_commit_batch();
143
  #pragma unroll
144
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
145
- warpgroup_fence_operand(accum[i]);
146
- warpgroup_wait<0>();
147
 
148
  // Notify barrier arrival at the last warpgroup wave
149
  empty_barriers[stage_idx]->arrive();
150
  }
151
 
152
- const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
153
- const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
154
  #pragma unroll
155
  for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
156
  if (col + i * 8 >= SHAPE_N)
 
4
  #include <cutlass/arch/barrier.h>
5
  #include <cutlass/arch/reg_reconfig.h>
6
 
7
+ #include <deep_gemm/common/math.cuh>
8
  #include <deep_gemm/common/utils.cuh>
9
+ #include <deep_gemm/common/tma_copy.cuh>
10
+ #include <deep_gemm/common/types.cuh>
11
+ #include <deep_gemm/mma/sm90.cuh>
12
+ #include <deep_gemm/epilogue/transform.cuh>
13
+ #include <deep_gemm/ptx/ld_st.cuh>
14
+ #include <deep_gemm/ptx/utils.cuh>
15
+ #include <deep_gemm/ptx/wgmma.cuh>
16
+ #include <deep_gemm/scheduler/gemm.cuh>
17
 
18
  namespace deep_gemm {
19
 
 
 
20
  template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
21
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
22
  uint32_t kSplitFactor,
23
  uint32_t kNumStages,
24
  uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
25
+ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
26
  sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
27
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
28
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
29
  float *d) {
30
  #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
31
  // Types
32
+ using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N>::type;
33
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
34
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
35
 
 
39
 
40
  // Configs
41
  const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
42
+ const uint32_t lane_idx = ptx::get_lane_idx();
43
  DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
44
  DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
45
  DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
 
54
  // Align to 1024 bytes for swizzle-128B
55
  // Fill shared memory pointers
56
  extern __shared__ __align__(1024) uint8_t smem_buffer[];
57
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
58
  return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
59
  });
60
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
61
  return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
62
  });
63
 
64
  // Fill barriers
65
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
66
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
67
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
68
 
69
  // Initialize barriers
70
  if (warp_idx == 1 and cute::elect_one_sync()) {
 
86
  constexpr uint32_t kNumMathRegisters = 232;
87
 
88
  // Block indices
89
+ const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
90
+ const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
91
  const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
92
  const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
93
  const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
94
  const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
95
  const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
96
 
97
+ // Wait for primary kernel completion
98
+ cudaGridDependencySynchronize();
99
+
100
  if (warp_idx >= kNumMathThreads / 32) {
101
  // TMA warp-group for loading data
102
  cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
 
107
  #pragma unroll
108
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
109
  // Wait consumer release
110
+ const auto stage_idx = s % kNumStages;
111
  empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
112
 
113
  auto& full_barrier = *full_barriers[stage_idx];
114
+ const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
115
+ const uint32_t k_idx = sk_idx % SHAPE_K;
116
+ const uint32_t s_idx = sk_idx / SHAPE_K;
117
 
118
  constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
119
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzle>(
120
  &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
121
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzle>(
122
  &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
123
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
124
  }
 
134
  // Launch MMAs
135
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
136
  // Wait TMA arrivals
137
+ const auto stage_idx = s % kNumStages;
138
  full_barriers[stage_idx]->wait((s / kNumStages) & 1);
139
 
140
  // Commit WGMMA instructions
141
  #pragma unroll
142
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
143
+ ptx::warpgroup_fence_operand(accum[i]);
144
+ ptx::warpgroup_arrive();
145
  #pragma unroll
146
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
147
+ auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
148
+ auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
149
  WGMMA::wgmma(desc_a, desc_b, accum, 1);
150
  }
151
+ ptx::warpgroup_commit_batch();
152
  #pragma unroll
153
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
154
+ ptx::warpgroup_fence_operand(accum[i]);
155
+ ptx::warpgroup_wait<0>();
156
 
157
  // Notify barrier arrival at the last warpgroup wave
158
  empty_barriers[stage_idx]->arrive();
159
  }
160
 
161
+ const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
162
+ const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
163
  #pragma unroll
164
  for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
165
  if (col + i * 8 >= SHAPE_N)
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh CHANGED
@@ -6,18 +6,26 @@
6
  #include <cutlass/arch/barrier.h>
7
  #include <cutlass/arch/reg_reconfig.h>
8
 
 
9
  #include <cute/arch/cluster_sm90.hpp>
10
  #include <cute/arch/copy_sm90_desc.hpp>
11
  #include <cute/arch/copy_sm90_tma.hpp>
12
 
 
 
13
  #include <deep_gemm/common/utils.cuh>
14
- #include <deep_gemm/common/scheduler.cuh>
15
- #include <deep_gemm/common/sm90_utils.cuh>
 
 
 
 
 
 
 
16
 
17
  namespace deep_gemm {
18
 
19
- using namespace deep_gemm::sm90;
20
-
21
  template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
22
  uint32_t kNumGroups,
23
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
@@ -27,7 +35,7 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
27
  uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
28
  uint32_t kNumSMs,
29
  GemmType kGemmType, typename cd_dtype_t>
30
- __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
31
  sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
32
  int* grouped_layout,
33
  cute::TmaDescriptor* tensor_map_buffer,
@@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
45
  DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
46
 
47
  // Types
48
- using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
49
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
50
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
51
 
@@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
55
  shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
56
 
57
  // Shared memory
58
- static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
59
  static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
60
  static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
61
  static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
62
  static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
63
  static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
64
- static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
65
  DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
66
 
67
  // Configs
@@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
83
  DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
84
 
85
  // Tensor maps on shared and global memory
86
- auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
87
- return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
88
- });
89
- auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
90
- return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
91
- });
92
- auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
93
- auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
94
 
95
  // Data on shared memory
96
  auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
97
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
98
- return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
99
  });
100
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
101
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
102
  });
103
  constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
104
- auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
105
  return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
106
  });
107
- auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
108
  return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
109
  });
110
 
111
  // Barriers on shared memory
112
  constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
113
- auto full_barriers = PatternVisitor([&](const uint32_t& i) {
114
  return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
115
  });
116
- auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
117
  return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
118
  });
119
 
120
  if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
121
  // Load tensormap A/B to shared memory
122
  if constexpr (kGemmType == GemmType::KGroupedContiguous) {
123
- *smem_tensor_map_a[0] = tensor_map_a_base;
124
- *smem_tensor_map_a[1] = tensor_map_a_base;
125
- *smem_tensor_map_b[0] = tensor_map_b_base;
126
- *smem_tensor_map_b[1] = tensor_map_b_base;
127
  }
128
 
129
  // Initialize barriers
@@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
149
  constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
150
  constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
151
 
 
 
 
152
  // Block scheduler
153
  uint32_t m_block_idx, n_block_idx;
154
- auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
155
 
156
  // TMA and MMA pipeline
157
- const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
158
  return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
159
  };
160
  uint32_t iter_idx = 0;
@@ -165,9 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
165
 
166
  // NOTES: only one thread (or warp) will be used
167
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
168
- const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
169
- const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
170
- uint32_t last_group_idx = kNumGroups, sum_k = 0;
171
 
172
  // Persistently schedule over blocks
173
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
@@ -177,35 +180,27 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
177
  const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
178
  const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
179
  DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
180
-
181
- const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
182
- const uint32_t& m_idx = m_block_idx * BLOCK_M;
183
- const uint32_t& n_idx = n_block_idx * BLOCK_N;
184
-
185
- if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
186
- const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
187
- const uint32_t& next_stage_idx = stage_idx ^ 1;
188
  last_group_idx = scheduler.current_group_idx;
189
 
190
- // Prepare next tensor map
191
- sum_k += scheduler.current_shape_k;
192
- if (scheduler.next_group_idx < kNumGroups) {
193
- tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast<uint64_t>(sum_k) * shape_m);
194
- tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast<uint64_t>(sum_k) * shape_n);
195
- tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
196
- tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
197
- *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
198
- *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
199
- tensor_map_release_cta();
200
- }
201
-
202
- // Get current tensor map
203
- if (scheduler.current_num_valid_groups > 0) {
204
- tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
205
- tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
206
- current_tensor_map_a = gmem_tensor_map_a[stage_idx];
207
- current_tensor_map_b = gmem_tensor_map_b[stage_idx];
208
- }
209
  }
210
 
211
  #pragma unroll kNumPipelineUnrolls
@@ -216,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
216
 
217
  // Issue TMA
218
  auto& full_barrier = *full_barriers[stage_idx];
219
- const uint32_t& k_idx = k_block_idx * BLOCK_K;
220
- const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
221
- tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
222
- tma_copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
223
- tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
224
- tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
 
 
225
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
226
  }
227
  }
@@ -248,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
248
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
249
  // Accumulation for WGMMA or CUDA promotion
250
  DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
251
- const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
252
- const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
253
- const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
254
  float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
255
  float2 scales_b[WGMMA::kNumAccum / 4];
256
 
@@ -272,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
272
 
273
  // Read A scales
274
  // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
275
- auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
276
- auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
277
 
278
  // Read B scales
279
  #pragma unroll
280
  for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
281
- scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
282
 
283
  // Commit WGMMA instructions
284
  #pragma unroll
285
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
286
- warpgroup_fence_operand(accum[i]);
287
- warpgroup_arrive();
288
  #pragma unroll
289
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
290
- auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
291
- auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
292
  WGMMA::wgmma(desc_a, desc_b, accum, k);
293
  }
294
- warpgroup_commit_batch();
295
  #pragma unroll
296
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
297
- warpgroup_fence_operand(accum[i]);
298
- warpgroup_wait<0>();
299
 
300
  // Notify barrier arrival
301
  empty_barrier_arrive(stage_idx);
@@ -318,12 +315,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
318
  cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
319
 
320
  // Store to D shared memory
321
- const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
322
- const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
323
  #pragma unroll
324
  for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
325
- st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
326
- st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
327
  }
328
  cute::tma_store_fence();
329
  cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
 
6
  #include <cutlass/arch/barrier.h>
7
  #include <cutlass/arch/reg_reconfig.h>
8
 
9
+ #include <cute/int_tuple.hpp>
10
  #include <cute/arch/cluster_sm90.hpp>
11
  #include <cute/arch/copy_sm90_desc.hpp>
12
  #include <cute/arch/copy_sm90_tma.hpp>
13
 
14
+ #include <deep_gemm/common/cute_tie.cuh>
15
+ #include <deep_gemm/common/math.cuh>
16
  #include <deep_gemm/common/utils.cuh>
17
+ #include <deep_gemm/common/tma_copy.cuh>
18
+ #include <deep_gemm/common/types.cuh>
19
+ #include <deep_gemm/mma/sm90.cuh>
20
+ #include <deep_gemm/epilogue/transform.cuh>
21
+ #include <deep_gemm/ptx/ld_st.cuh>
22
+ #include <deep_gemm/ptx/tma.cuh>
23
+ #include <deep_gemm/ptx/utils.cuh>
24
+ #include <deep_gemm/ptx/wgmma.cuh>
25
+ #include <deep_gemm/scheduler/gemm.cuh>
26
 
27
  namespace deep_gemm {
28
 
 
 
29
  template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
30
  uint32_t kNumGroups,
31
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
 
35
  uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
36
  uint32_t kNumSMs,
37
  GemmType kGemmType, typename cd_dtype_t>
38
+ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
39
  sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
40
  int* grouped_layout,
41
  cute::TmaDescriptor* tensor_map_buffer,
 
53
  DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
54
 
55
  // Types
56
+ using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
57
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
58
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
59
 
 
63
  shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
64
 
65
  // Shared memory
66
+ static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0);
67
  static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
68
  static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
69
  static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
70
  static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
71
  static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
72
+ static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
73
  DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
74
 
75
  // Configs
 
91
  DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
92
 
93
  // Tensor maps on shared and global memory
94
+ auto smem_tensor_map_a = reinterpret_cast<cute::TmaDescriptor*>(smem_buffer);
95
+ auto smem_tensor_map_b = smem_tensor_map_a + 1;
96
+ auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2;
97
+ auto gmem_tensor_map_b = gmem_tensor_map_a + 1;
 
 
 
 
98
 
99
  // Data on shared memory
100
  auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
101
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
102
+ return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
103
  });
104
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
105
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
106
  });
107
  constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
108
+ auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
109
  return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
110
  });
111
+ auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) {
112
  return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
113
  });
114
 
115
  // Barriers on shared memory
116
  constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
117
+ auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) {
118
  return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
119
  });
120
+ auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) {
121
  return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
122
  });
123
 
124
  if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
125
  // Load tensormap A/B to shared memory
126
  if constexpr (kGemmType == GemmType::KGroupedContiguous) {
127
+ *smem_tensor_map_a = tensor_map_a_base;
128
+ *smem_tensor_map_b = tensor_map_b_base;
 
 
129
  }
130
 
131
  // Initialize barriers
 
151
  constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
152
  constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
153
 
154
+ // Wait for primary kernel completion
155
+ cudaGridDependencySynchronize();
156
+
157
  // Block scheduler
158
  uint32_t m_block_idx, n_block_idx;
159
+ auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
160
 
161
  // TMA and MMA pipeline
162
+ const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
163
  return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
164
  };
165
  uint32_t iter_idx = 0;
 
170
 
171
  // NOTES: only one thread (or warp) will be used
172
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
173
+ uint32_t last_group_idx = kNumGroups;
 
 
174
 
175
  // Persistently schedule over blocks
176
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
 
180
  const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
181
  const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
182
  DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
183
+
184
+ const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
185
+ const uint32_t m_idx = m_block_idx * BLOCK_M;
186
+ const uint32_t n_idx = n_block_idx * BLOCK_N;
187
+
188
+ if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) {
 
 
189
  last_group_idx = scheduler.current_group_idx;
190
 
191
+ // Directly update current tensor map
192
+ const uint64_t current_k_offset = scheduler.current_k_cumsum;
193
+ ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m);
194
+ ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n);
195
+ ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k);
196
+ ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k);
197
+ *(gmem_tensor_map_a) = *(smem_tensor_map_a);
198
+ *(gmem_tensor_map_b) = *(smem_tensor_map_b);
199
+ ptx::tensor_map_release_gpu();
200
+
201
+ // Immediately acquire current tensor map
202
+ ptx::tensor_map_acquire_gpu(gmem_tensor_map_a);
203
+ ptx::tensor_map_acquire_gpu(gmem_tensor_map_b);
 
 
 
 
 
 
204
  }
205
 
206
  #pragma unroll kNumPipelineUnrolls
 
211
 
212
  // Issue TMA
213
  auto& full_barrier = *full_barriers[stage_idx];
214
+ const uint32_t k_idx = k_block_idx * BLOCK_K;
215
+ const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
216
+ const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base);
217
+ const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base);
218
+ tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
219
+ tma::copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
220
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
221
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
222
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
223
  }
224
  }
 
245
  while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
246
  // Accumulation for WGMMA or CUDA promotion
247
  DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
248
+ const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
249
+ const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
250
+ const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K);
251
  float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
252
  float2 scales_b[WGMMA::kNumAccum / 4];
253
 
 
269
 
270
  // Read A scales
271
  // NOTES: all shared memory reads must happen before `warpgroup_arrive`, so that the next scheduled block cannot pollute the results
272
+ auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0);
273
+ auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1);
274
 
275
  // Read B scales
276
  #pragma unroll
277
  for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
278
+ scales_b[i] = ptx::ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
279
 
280
  // Commit WGMMA instructions
281
  #pragma unroll
282
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
283
+ ptx::warpgroup_fence_operand(accum[i]);
284
+ ptx::warpgroup_arrive();
285
  #pragma unroll
286
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
287
+ auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
288
+ auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
289
  WGMMA::wgmma(desc_a, desc_b, accum, k);
290
  }
291
+ ptx::warpgroup_commit_batch();
292
  #pragma unroll
293
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
294
+ ptx::warpgroup_fence_operand(accum[i]);
295
+ ptx::warpgroup_wait<0>();
296
 
297
  // Notify barrier arrival
298
  empty_barrier_arrive(stage_idx);
 
315
  cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
316
 
317
  // Store to D shared memory
318
+ const auto smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
319
+ const auto smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
320
  #pragma unroll
321
  for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
322
+ ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
323
+ ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
324
  }
325
  cute::tma_store_fence();
326
  cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh CHANGED
@@ -10,17 +10,21 @@
10
  #include <cute/arch/copy_sm90_desc.hpp>
11
  #include <cute/arch/copy_sm90_tma.hpp>
12
 
13
- #include <deep_gemm/common/epilogue_utils.cuh>
14
  #include <deep_gemm/common/utils.cuh>
15
- #include <deep_gemm/common/scheduler.cuh>
16
- #include <deep_gemm/common/sm90_utils.cuh>
 
 
 
 
 
 
17
 
18
  namespace deep_gemm {
19
 
20
- using namespace deep_gemm::sm90;
21
-
22
  template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
23
- __device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
24
  if (num_former_iters == kNumFormerIters) {
25
  func(cute::Int<kNumFormerIters>{});
26
  return;
@@ -35,12 +39,12 @@ template <cute::UMMA::Major kMajorSFB,
35
  uint32_t kNumGroups,
36
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
37
  uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
38
- uint32_t kNumStages, uint32_t kNumLastStages,
39
  uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
40
  uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
41
  uint32_t kNumSMs, GemmType kGemmType,
42
  typename epilogue_type_t>
43
- __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
44
  sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
45
  uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
46
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
50
  #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
51
  // Scaling checks
52
  DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
53
- DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
 
 
54
 
55
  // Types
56
- using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
57
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
58
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
59
 
@@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
64
 
65
  // Shared memory
66
  static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
67
- static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
68
  static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
69
  static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
70
  static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
71
- static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
72
- const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
73
- const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K);
74
- const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
75
 
76
  // NOTES: Make sure we have enough shared memory for WGMMA padding
77
  static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
78
  DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
79
 
80
  // Configs
81
- const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
82
  const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
83
- const uint32_t lane_idx = get_lane_idx();
84
 
85
  // Prefetch TMA descriptors at the very beginning
86
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
97
 
98
  // Data on shared memory
99
  auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
100
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
101
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
102
  });
103
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
104
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
105
  });
106
  constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
107
- auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
108
  return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
109
  });
110
  auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
111
 
112
  // Fill barriers
113
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
114
- auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
115
- auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
116
 
117
  // Initialize barriers
118
  DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
@@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
136
  constexpr uint32_t kNumTMARegisters = 40;
137
  constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
138
 
 
 
 
139
  // Block scheduler
140
  uint32_t m_block_idx, n_block_idx;
141
- auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
142
 
143
  // Pipeline and TMA phases
144
  uint32_t stage_idx = 0, phase = 0;
@@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
177
  constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
178
  auto& full_barrier = *full_barriers[stage_idx];
179
  const uint32_t k_idx = k_block_idx * BLOCK_K;
180
- tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
181
  smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
182
  num_tma_multicast_a, batch_idx);
183
- tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
184
- smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
185
  num_tma_multicast_a);
186
 
187
  // Issue TMA B
188
- tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
189
  smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
190
  num_tma_multicast_b, batch_idx);
191
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
@@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
206
  const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
207
  const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
208
 
209
- auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
210
- auto b_desc = make_smem_desc(smem_b[0], 1);
211
  const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
212
  const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
213
 
@@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
225
  // Load B scales with math warp-groups
226
  // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
227
  if (threadIdx.x >= 32) {
228
- auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
229
  const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
230
  const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
231
  auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
232
 
233
  #pragma unroll
234
  for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
235
- st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb));
236
  }
237
  cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
238
 
@@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
259
  // Skip useless computations
260
  if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
261
  // The compiler must know the dynamic variable `num_former_iters`'s real value
262
- constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
263
- constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
264
  constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
265
 
266
  // Dispatch `num_former_iters` and launch MMAs
267
  dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
268
  #pragma unroll 8
269
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
270
- const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
271
- const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
272
 
273
  // Read B scales
274
- float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
275
  // NOTES: even though some blocks do not need to read the second row, we still load one to align with other blocks
276
  if constexpr (not kMustUseUniformedScaleB)
277
- scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
278
 
279
  // Wait TMA arrivals
280
  full_barriers[stage_idx]->wait(phase);
@@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
286
 
287
  // Read A scales
288
  // NOTES: all shared memory reads must happen before `warpgroup_arrive`, so that the next scheduled block cannot pollute the results
289
- auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
290
- auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
291
 
292
  // Commit WGMMA instructions
293
  #pragma unroll
294
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
295
- warpgroup_fence_operand(accum[i]);
296
- warpgroup_arrive();
297
  #pragma unroll
298
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
299
  a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
300
  b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
301
  WGMMA::wgmma(a_desc, b_desc, accum, k);
302
  }
303
- warpgroup_commit_batch();
304
  #pragma unroll
305
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
306
- warpgroup_fence_operand(accum[i]);
307
- warpgroup_wait<0>();
308
 
309
  // Notify barrier arrival at the last warpgroup wave
310
  if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
@@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
325
  #pragma unroll
326
  for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
327
  // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
328
- const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters;
329
  shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
330
  shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
331
  shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
@@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
399
  }
400
 
401
  // NOTES: only 16 lanes' addresses are used
402
- SM90_U32x2_STSM_N<nv_bfloat162>::copy(
403
  __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
404
  __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
405
  smem_ptr
 
10
  #include <cute/arch/copy_sm90_desc.hpp>
11
  #include <cute/arch/copy_sm90_tma.hpp>
12
 
13
+ #include <deep_gemm/common/math.cuh>
14
  #include <deep_gemm/common/utils.cuh>
15
+ #include <deep_gemm/common/tma_copy.cuh>
16
+ #include <deep_gemm/common/types.cuh>
17
+ #include <deep_gemm/mma/sm90.cuh>
18
+ #include <deep_gemm/epilogue/transform.cuh>
19
+ #include <deep_gemm/ptx/ld_st.cuh>
20
+ #include <deep_gemm/ptx/utils.cuh>
21
+ #include <deep_gemm/ptx/wgmma.cuh>
22
+ #include <deep_gemm/scheduler/gemm.cuh>
23
 
24
  namespace deep_gemm {
25
 
 
 
26
  template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
27
+ CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
28
  if (num_former_iters == kNumFormerIters) {
29
  func(cute::Int<kNumFormerIters>{});
30
  return;
 
39
  uint32_t kNumGroups,
40
  uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
41
  uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
42
+ uint32_t kNumStages,
43
  uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
44
  uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
45
  uint32_t kNumSMs, GemmType kGemmType,
46
  typename epilogue_type_t>
47
+ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
48
  sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
49
  uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
50
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
 
54
  #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
55
  // Scaling checks
56
  DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
57
+ DG_STATIC_ASSERT(
58
+ math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or
59
+ (math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
60
 
61
  // Types
62
+ using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
63
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
64
  DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
65
 
 
70
 
71
  // Shared memory
72
  static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
73
+ static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
74
  static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
75
  static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
76
  static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
77
+ static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
78
+ const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K);
79
+ const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K);
80
+ const uint32_t smem_sfb_size = math::align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
81
 
82
  // NOTES: Make sure we have enough shared memory for WGMMA padding
83
  static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
84
  DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
85
 
86
  // Configs
87
+ const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K);
88
  const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
89
+ const uint32_t lane_idx = ptx::get_lane_idx();
90
 
91
  // Prefetch TMA descriptors at the very beginning
92
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
 
103
 
104
  // Data on shared memory
105
  auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
106
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
107
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
108
  });
109
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
110
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
111
  });
112
  constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
113
+ auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
114
  return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
115
  });
116
  auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
117
 
118
  // Fill barriers
119
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
120
+ auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
121
+ auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
122
 
123
  // Initialize barriers
124
  DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
 
142
  constexpr uint32_t kNumTMARegisters = 40;
143
  constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
144
 
145
+ // Wait for primary kernel completion
146
+ cudaGridDependencySynchronize();
147
+
148
  // Block scheduler
149
  uint32_t m_block_idx, n_block_idx;
150
+ auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
151
 
152
  // Pipeline and TMA phases
153
  uint32_t stage_idx = 0, phase = 0;
 
186
  constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
187
  auto& full_barrier = *full_barriers[stage_idx];
188
  const uint32_t k_idx = k_block_idx * BLOCK_K;
189
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
190
  smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
191
  num_tma_multicast_a, batch_idx);
192
+ tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
193
+ smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
194
  num_tma_multicast_a);
195
 
196
  // Issue TMA B
197
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
198
  smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
199
  num_tma_multicast_b, batch_idx);
200
  full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
 
215
  const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
216
  const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
217
 
218
+ auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
219
+ auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1);
220
  const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
221
  const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
222
 
 
234
  // Load B scales with math warp-groups
235
  // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
236
  if (threadIdx.x >= 32) {
237
+ auto previous_group_offset = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
238
  const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
239
  const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
240
  auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
241
 
242
  #pragma unroll
243
  for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
244
+ ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]);
245
  }
246
  cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
247
 
 
268
  // Skip useless computations
269
  if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
270
  // The compiler must know the dynamic variable `num_former_iters`'s real value
271
+ constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
272
+ constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
273
  constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
274
 
275
  // Dispatch `num_former_iters` and launch MMAs
276
  dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
277
  #pragma unroll 8
278
  for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
279
+ const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
280
+ const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
281
 
282
  // Read B scales
283
+ float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1;
284
  // NOTES: even though some blocks do not need to read the second row, we still load one to align with other blocks
285
  if constexpr (not kMustUseUniformedScaleB)
286
+ scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales);
287
 
288
  // Wait TMA arrivals
289
  full_barriers[stage_idx]->wait(phase);
 
295
 
296
  // Read A scales
297
  // NOTES: all shared memory reads must happen before `warpgroup_arrive`, so that the next scheduled block cannot pollute the results
298
+ auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
299
+ auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
300
 
301
  // Commit WGMMA instructions
302
  #pragma unroll
303
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
304
+ ptx::warpgroup_fence_operand(accum[i]);
305
+ ptx::warpgroup_arrive();
306
  #pragma unroll
307
  for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
308
  a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
309
  b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
310
  WGMMA::wgmma(a_desc, b_desc, accum, k);
311
  }
312
+ ptx::warpgroup_commit_batch();
313
  #pragma unroll
314
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
315
+ ptx::warpgroup_fence_operand(accum[i]);
316
+ ptx::warpgroup_wait<0>();
317
 
318
  // Notify barrier arrival at the last warpgroup wave
319
  if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
 
334
  #pragma unroll
335
  for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
336
  // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
337
+ const bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
338
  shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
339
  shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
340
  shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
 
408
  }
409
 
410
  // NOTES: only 16 lanes' addresses are used
411
+ ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
412
  __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
413
  __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
414
  smem_ptr
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh CHANGED
@@ -7,36 +7,31 @@
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
  #include <cute/arch/mma_sm90_desc.hpp>
9
 
 
 
10
  #include <deep_gemm/common/utils.cuh>
11
- #include <deep_gemm/common/sm90_utils.cuh>
 
 
 
 
 
12
 
13
  namespace deep_gemm {
14
 
15
- using namespace deep_gemm::sm90;
16
-
17
- // ReSharper disable once CppNotAllPathsReturnValue
18
- template <uint32_t kHeadDim>
19
- static constexpr int to_swizzle_cute_type() {
20
- DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
21
- if constexpr (kHeadDim == 32)
22
- return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
23
- if constexpr (kHeadDim == 64)
24
- return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
25
- if constexpr (kHeadDim == 128)
26
- return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
27
- }
28
-
29
  template <uint32_t kNumHeads, uint32_t kHeadDim,
30
  bool kIsCompressedLogits,
31
  uint32_t BLOCK_Q, uint32_t BLOCK_KV,
32
  uint32_t kNumQStages, uint32_t kNumKVStages,
33
- uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
34
- __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
 
 
35
  void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
36
- const uint32_t max_seqlen_k, const uint64_t stride_logits,
37
  uint32_t* cu_seq_len_k_start,
38
  uint32_t* cu_seq_len_k_end,
39
- float* logits,
40
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
41
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
42
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
44
  // TODO: consider TMA multicast
45
  // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
46
  // Q should be loaded only once per block
47
- const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
48
 
49
  // Types
50
- using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
51
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
52
 
53
  // Prefetch TMA descriptors
@@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
74
  DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
75
 
76
  // Data on shared memory
77
- auto smem_q = PatternVisitor([&](const uint32_t& i) {
78
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
79
  SMEM_Q_SIZE_PER_STAGE * i);
80
  });
81
- auto smem_kv = PatternVisitor([&](const uint32_t& i) {
82
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
83
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
84
  });
85
- auto smem_weights = PatternVisitor([&](const uint32_t& i) {
86
  return reinterpret_cast<float*>(smem_buffer +
87
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
88
  });
89
- auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
90
  return reinterpret_cast<float*>(smem_buffer +
91
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
92
  SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
@@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
94
 
95
  // TMA barriers
96
  auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
97
- auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
98
- auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
99
- auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
100
- auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
101
 
102
  // Initialize barriers
103
- const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
104
  if (is_tma_load_warp and cute::elect_one_sync()) {
105
  #pragma unroll
106
  for (uint32_t i = 0; i < kNumQStages; ++ i) {
@@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
123
  constexpr uint32_t kNumMathRegisters = 112;
124
 
125
  // Block scheduler
126
- uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
127
- const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
128
- return {block_q_idx + gridDim.x, q_iter_idx + 1};
 
129
  };
130
  uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
131
- const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
132
  uint32_t start = cute::numeric_limits<uint32_t>::max();
133
  uint32_t end = cute::numeric_limits<uint32_t>::min();
134
 
135
  #pragma unroll
136
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
137
- const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
138
- seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
139
- seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
140
  start = min(start, min(seq_k_start[i], seq_len_kv));
141
  end = max(end, min(seq_k_end[i], seq_len_kv));
142
  }
 
143
  start = start / 4 * 4;
144
  return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
145
  ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
146
- start, ceil_div(end - start, BLOCK_KV)}; // Task info
147
  };
148
 
149
  // KV pipeline
150
  uint32_t num_total_kv_blocks = 0;
151
- const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
152
  return {
153
  (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
154
  ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
155
  };
156
  };
157
 
 
 
 
158
  if (threadIdx.x >= kNumMathThreads) {
159
  // TMA warp-group for loading data
160
  cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
@@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
165
 
166
  // Prefetch
167
  const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
168
- tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
169
- tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
170
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
171
  };
172
  if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
@@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
192
  empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
193
 
194
  // Issue TMA KV
195
- tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
196
  smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
197
- tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
198
  smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
199
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
200
  }
@@ -212,7 +212,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
212
  const auto& thread_idx = threadIdx.x % kNumMathThreads;
213
  const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
214
  const auto& warpgroup_idx = warp_idx / 4;
215
- const auto& lane_idx = get_lane_idx();
216
  float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
217
 
218
  const auto& warp_offset = warp_idx * 16;
@@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
230
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
231
  #pragma unroll
232
  for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
233
- weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
234
  }
235
 
236
  // Compute over KV blocks
@@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
242
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
243
 
244
  // Read per-KV scales
245
- float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
246
- float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
247
 
248
  // Issue WGMMA
249
  DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
250
  DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
251
  #pragma unroll
252
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
253
- warpgroup_fence_operand(accum[i]);
254
- warpgroup_arrive();
255
  #pragma unroll
256
  for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
257
- auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
258
- to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
259
- auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
260
- to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
 
 
261
  WGMMA::wgmma(desc_a, desc_b, accum, k);
262
  }
263
- warpgroup_commit_batch();
264
  #pragma unroll
265
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
266
- warpgroup_fence_operand(accum[i]);
267
- warpgroup_wait<0>();
268
 
269
  // Release KV empty
270
  empty_kv_barriers[kv_stage_idx]->arrive();
@@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
278
  #pragma unroll
279
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
280
  auto shifted_accum = accum + i * kNumAccumPerReduce;
281
- const auto& transform = [&](const uint32_t& j) {
282
  return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
283
  };
284
 
@@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
302
  }
303
 
304
  // Store into the global memory
305
- // NOTES: we have redundant writes here, consider more carefully
306
- const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
307
  if constexpr (kIsCompressedLogits) {
308
  if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
309
- logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
310
  if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
311
- logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
312
  } else {
313
- logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;
314
- logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1;
315
  }
316
  }
317
  }
 
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
  #include <cute/arch/mma_sm90_desc.hpp>
9
 
10
+ #include <deep_gemm/common/cute_tie.cuh>
11
+ #include <deep_gemm/common/math.cuh>
12
  #include <deep_gemm/common/utils.cuh>
13
+ #include <deep_gemm/common/tma_copy.cuh>
14
+ #include <deep_gemm/common/types.cuh>
15
+ #include <deep_gemm/mma/sm90.cuh>
16
+ #include <deep_gemm/ptx/ld_st.cuh>
17
+ #include <deep_gemm/ptx/utils.cuh>
18
+ #include <deep_gemm/ptx/wgmma.cuh>
19
 
20
  namespace deep_gemm {
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  template <uint32_t kNumHeads, uint32_t kHeadDim,
23
  bool kIsCompressedLogits,
24
  uint32_t BLOCK_Q, uint32_t BLOCK_KV,
25
  uint32_t kNumQStages, uint32_t kNumKVStages,
26
+ uint32_t kNumSMs,
27
+ uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
28
+ typename logits_dtype_t>
29
+ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
30
  void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
31
+ const uint32_t max_seqlen_k, const uint32_t stride_logits,
32
  uint32_t* cu_seq_len_k_start,
33
  uint32_t* cu_seq_len_k_end,
34
+ logits_dtype_t* logits,
35
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
36
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
37
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
 
39
  // TODO: consider TMA multicast
40
  // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
41
  // Q should be loaded only once per block
42
+ const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
43
 
44
  // Types
45
+ using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_Q * kNumHeads>::type;
46
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
47
 
48
  // Prefetch TMA descriptors
 
69
  DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
70
 
71
  // Data on shared memory
72
+ auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
73
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
74
  SMEM_Q_SIZE_PER_STAGE * i);
75
  });
76
+ auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
77
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
78
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
79
  });
80
+ auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
81
  return reinterpret_cast<float*>(smem_buffer +
82
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
83
  });
84
+ auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
85
  return reinterpret_cast<float*>(smem_buffer +
86
  SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
87
  SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
 
89
 
90
  // TMA barriers
91
  auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
92
+ auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
93
+ auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
94
+ auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
95
+ auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
96
 
97
  // Initialize barriers
98
+ const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
99
  if (is_tma_load_warp and cute::elect_one_sync()) {
100
  #pragma unroll
101
  for (uint32_t i = 0; i < kNumQStages; ++ i) {
 
118
  constexpr uint32_t kNumMathRegisters = 112;
119
 
120
  // Block scheduler
121
+ const auto sm_idx = blockIdx.x;
122
+ uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
123
+ const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
124
+ return {block_q_idx + kNumSMs, q_iter_idx + 1};
125
  };
126
  uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
127
+ const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
128
  uint32_t start = cute::numeric_limits<uint32_t>::max();
129
  uint32_t end = cute::numeric_limits<uint32_t>::min();
130
 
131
  #pragma unroll
132
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
133
+ const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
134
+ seq_k_start[i] = cu_seq_len_k_start[q_idx];
135
+ seq_k_end[i] = cu_seq_len_k_end[q_idx];
136
  start = min(start, min(seq_k_start[i], seq_len_kv));
137
  end = max(end, min(seq_k_end[i], seq_len_kv));
138
  }
139
+ // TMA alignment requirements for SF KV
140
  start = start / 4 * 4;
141
  return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
142
  ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
143
+ start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
144
  };
145
 
146
  // KV pipeline
147
  uint32_t num_total_kv_blocks = 0;
148
+ const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
149
  return {
150
  (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
151
  ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
152
  };
153
  };
154
 
155
+ // Wait for primary kernel completion
156
+ cudaGridDependencySynchronize();
157
+
158
  if (threadIdx.x >= kNumMathThreads) {
159
  // TMA warp-group for loading data
160
  cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
 
165
 
166
  // Prefetch
167
  const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
168
+ tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
169
+ tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
170
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
171
  };
172
  if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
 
192
  empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
193
 
194
  // Issue TMA KV
195
+ tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
196
  smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
197
+ tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
198
  smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
199
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
200
  }
 
212
  const auto& thread_idx = threadIdx.x % kNumMathThreads;
213
  const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
214
  const auto& warpgroup_idx = warp_idx / 4;
215
+ const auto& lane_idx = ptx::get_lane_idx();
216
  float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
217
 
218
  const auto& warp_offset = warp_idx * 16;
 
230
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
231
  #pragma unroll
232
  for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
233
+ weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
234
  }
235
 
236
  // Compute over KV blocks
 
242
  full_kv_barriers[kv_stage_idx]->wait(kv_phase);
243
 
244
  // Read per-KV scales
245
+ float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
246
+ float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
247
 
248
  // Issue WGMMA
249
  DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
250
  DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
251
  #pragma unroll
252
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
253
+ ptx::warpgroup_fence_operand(accum[i]);
254
+ ptx::warpgroup_arrive();
255
  #pragma unroll
256
  for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
257
+ auto desc_a = mma::sm90::make_smem_desc(
258
+ smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
259
+ mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
260
+ auto desc_b = mma::sm90::make_smem_desc(
261
+ smem_q[q_stage_idx] + k * WGMMA::K,
262
+ mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
263
  WGMMA::wgmma(desc_a, desc_b, accum, k);
264
  }
265
+ ptx::warpgroup_commit_batch();
266
  #pragma unroll
267
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
268
+ ptx::warpgroup_fence_operand(accum[i]);
269
+ ptx::warpgroup_wait<0>();
270
 
271
  // Release KV empty
272
  empty_kv_barriers[kv_stage_idx]->arrive();
 
280
  #pragma unroll
281
  for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
282
  auto shifted_accum = accum + i * kNumAccumPerReduce;
283
+ const auto transform = [&](const uint32_t& j) {
284
  return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
285
  };
286
 
 
304
  }
305
 
306
  // Store into the global memory
307
+ const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
 
308
  if constexpr (kIsCompressedLogits) {
309
  if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
310
+ logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_0);
311
  if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
312
+ logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_1);
313
  } else {
314
+ logits[q_offset + kv_offset + v_0_offset] = static_cast<logits_dtype_t>(v_0);
315
+ logits[q_offset + kv_offset + v_1_offset] = static_cast<logits_dtype_t>(v_1);
316
  }
317
  }
318
  }
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh CHANGED
@@ -6,133 +6,46 @@
6
  #include <cute/arch/cluster_sm90.hpp>
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
 
 
 
9
  #include <deep_gemm/common/utils.cuh>
10
- #include <deep_gemm/common/sm90_utils.cuh>
11
- #include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
 
 
 
 
 
12
 
13
  namespace deep_gemm {
14
 
15
// Pre-commit side of the diff: every code line keeps its `-` marker and the
// bare numbers are the diff viewer line numbers.
//
// Kernel bounded to 32 threads per block that fills `schedule_metadata` with
// one (q_idx, kv_split_idx) pair for every `sm_idx` in `[0, kNumSMs]`.
//   batch_size         - number of queries; indices >= batch_size count as length 0
//   next_n             - row width of `context_lens` when `is_context_lens_2d` is set
//   is_context_lens_2d - if set, the last entry of each `next_n`-wide row is read
//   context_lens       - per-query context lengths, read through `__ldg`
//   schedule_metadata  - output; indices `0 .. 2 * kNumSMs + 1` are written
- template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
16
- __global__ __launch_bounds__(32, 1)
17
- void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
18
- const uint32_t* context_lens, uint32_t* schedule_metadata) {
19
- DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
20
- const uint32_t lane_idx = get_lane_idx();
21
-
22
// Step 1: each lane counts the SPLIT_KV-sized segments of queries
// `k * 32 + lane_idx`, one query per group `k`.
- uint32_t num_segs[kAlignedBatchSize / 32];
23
- #pragma unroll
24
- for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
25
- const uint32_t q_idx = k * 32 + lane_idx;
26
// In the 2D layout the length is taken from column `next_n - 1` of row `q_idx`.
- const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
27
// Queries past `batch_size` contribute zero segments.
- const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0);
28
- num_segs[k] = ceil_div(context_len, SPLIT_KV);
29
- }
30
-
31
// Step 2: inclusive prefix sum of the segment counts over all queries,
// stored in shared memory. `sum` carries the running total between groups.
- __shared__ uint32_t prefix_sum[kAlignedBatchSize];
32
- uint32_t sum = 0;
33
- #pragma unroll
34
- for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
35
- uint32_t x = num_segs[k];
36
- #pragma unroll
37
// Shuffle-up scan with doubling offsets; lanes below `offset` add nothing.
- for (uint32_t offset = 1; offset < 32; offset <<= 1) {
38
- const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
39
- x += (lane_idx >= offset ? y : 0);
40
- }
41
- x += sum;
42
- prefix_sum[k * 32 + lane_idx] = x;
43
// Lane 31 holds the total up to and including this group.
- sum = __shfl_sync(0xffffffff, x, 31);
44
- }
45
-
46
// Step 3: place `kNumSMs + 1` boundaries over the `sum` segments. Boundary
// `sm_idx` sits at segment `sm_idx * q + min(sm_idx, r)`, so consecutive
// boundaries are `q + 1` apart for the first `r` indices and `q` apart after.
- const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
47
// The bound is inclusive (`<=`), so the closing boundary `kNumSMs` is written too.
- for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
48
- uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
49
- uint32_t q_idx = 0;
50
// Linear search for the first query whose inclusive prefix sum exceeds
// `seg_starts`; stops at `batch_size` if there is none.
// NOTE(review): `prefix_sum` entries written by other lanes are read here and
// the only `__syncwarp()` comes after these reads - confirm the earlier
// `__shfl_sync` calls give the required shared-memory visibility.
- while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
51
- ++ q_idx;
52
// Offset of the boundary inside query `q_idx`, counted in segments.
- const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
53
- __syncwarp();
54
-
55
// Pairs are stored interleaved: even index is the query, odd index is the split.
- schedule_metadata[sm_idx * 2] = q_idx;
56
- schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
57
- }
58
- }
59
-
60
// Pre-commit side of the diff: every code line keeps its `-` marker and the
// bare numbers are the diff viewer line numbers.
//
// Device-side iterator over (q_idx, kv_idx) tasks for one `sm_idx`. The range
// starts at the pair stored for `sm_idx` in `schedule_meta` and stops at the
// pair stored for `sm_idx + 1`, which is not itself returned as a task.
- template <uint32_t kNextN, bool kIsContextLens2D,
61
- uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
62
- struct PagedMQALogitsScheduler {
63
- uint32_t batch_size;
64
- const uint32_t* context_lens;
65
-
66
// Cursor, end marker, and the KV block count of the query under the cursor.
- uint32_t current_q_idx, current_kv_idx;
67
- uint32_t end_q_idx, end_kv_idx;
68
- uint32_t current_num_kv;
69
-
70
// Number of BLOCK_KV-sized blocks for `q_idx`, or 0 when `q_idx >= batch_size`.
// With kIsContextLens2D the length is read from column `kNextN - 1` of row `q_idx`.
- __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) {
71
- const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
72
- return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0;
73
- }
74
-
75
// Loads the start pair (entry `sm_idx`) and end pair (entry `sm_idx + 1`) from
// `schedule_meta` and primes the cursor.
- __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
76
- const uint32_t* context_lens, const uint32_t* schedule_meta) {
77
- this->batch_size = batch_size;
78
- this->context_lens = context_lens;
79
-
80
// Each entry is read as one `uint2`: `.x` is a query index, `.y` a KV index.
// NOTE(review): assumes `schedule_meta` is suitably aligned for `uint2` loads
// and holds at least `sm_idx + 2` pairs - confirm against the producer.
- const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
81
- const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
82
// The stored `.y` is scaled by kNumBlocksPerSplit, so it is presumably a split
// index rather than a block index - confirm.
- current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
83
- end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
84
-
85
- current_num_kv = get_num_kv(current_q_idx);
86
- }
87
-
88
// Writes the task under the cursor to the three outputs, then advances.
// Returns false once the cursor equals the end marker; the outputs are still
// overwritten in that case.
- __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
89
- q_idx = current_q_idx;
90
- kv_idx = current_kv_idx;
91
- num_kv = current_num_kv;
92
-
93
- if (q_idx == end_q_idx and kv_idx == end_kv_idx)
94
- return false;
95
-
96
// Step by kNumBlocksPerSplit blocks; once the current query is exhausted, move
// to block 0 of the next query and refresh its block count.
- current_kv_idx += kNumBlocksPerSplit;
97
- if (current_kv_idx >= current_num_kv) {
98
- ++ current_q_idx;
99
- current_kv_idx = 0;
100
- current_num_kv = get_num_kv(current_q_idx);
101
- }
102
-
103
- return true;
104
- }
105
-
106
// True when `q_idx` lies before the end query, or is the end query and the end
// marker is past its block 0. `and` binds tighter than `or` here.
- __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
107
- return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
108
- }
109
- };
110
-
111
- using namespace deep_gemm::sm90;
112
-
113
  template <uint32_t kNextN, uint32_t kNumHeads,
114
  uint32_t kHeadDim, uint32_t BLOCK_KV,
115
- bool kIsContextLens2D,
116
  uint32_t kNumQStages, uint32_t kNumKVStages,
117
  uint32_t SPLIT_KV,
118
- uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
119
- __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
 
120
  void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
121
- const uint64_t logits_stride, const uint64_t block_table_stride,
122
- const uint32_t* context_lens, float* logits,
123
- const uint32_t* block_table, const uint32_t* schedule_meta,
 
124
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
125
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
126
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
127
  const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
 
 
128
  // Types
129
- using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
130
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
131
 
132
  // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
133
- const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
134
- const auto& warpgroup_idx = warp_idx / 4;
135
- const auto& lane_idx = get_lane_idx();
136
 
137
  // Prefetch TMA descriptors
138
  static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
@@ -150,15 +63,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
150
  static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
151
  static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
152
  static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
153
- static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
154
  static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
155
- constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
156
 
157
  static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
158
  static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
159
- static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
160
  static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
161
- constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
162
 
163
  // Align to swizzling alignment bytes
164
  extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
@@ -166,31 +79,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
166
  DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
167
 
168
  // Q data and barriers on shared memory
169
- auto smem_q = PatternVisitor([&](const uint32_t& i) {
170
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
171
  });
172
- auto smem_weights = PatternVisitor([&](const uint32_t& i) {
173
  return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
174
  });
175
  auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
176
- auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
177
- auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
178
 
179
  // Separate math warpgroups and tma load warps into KV groups
180
  // Each math warpgroup corresponds to a tma load warp
181
- const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
182
 
183
  // Per group KV data and barriers on shared memory
184
- const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
185
- auto smem_kv = PatternVisitor([&](const uint32_t& i) {
186
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
187
  });
188
- auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
189
  return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
190
  });
191
  auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
192
- auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
193
- auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
194
 
195
  // Initialize barriers
196
  if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -218,15 +131,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
218
  constexpr uint32_t kNumTMARegisters = 64;
219
  constexpr uint32_t kNumMathRegisters = 104;
220
 
 
 
 
221
  // Scheduler
222
- auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
 
223
  DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
224
 
225
  // Q and KV pipeline
226
- const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
227
  return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
228
  };
229
- const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
230
  return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
231
  };
232
  uint32_t q_iter_idx = 0, kv_iter_idx = 0;
@@ -237,10 +154,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
237
  if (kv_group_idx >= kNumMathWarpGroups)
238
  return;
239
 
240
- const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
241
  if (kv_group_idx == 0 and cute::elect_one_sync()) {
242
- tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
243
- tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
244
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
245
  }
246
  };
@@ -259,7 +176,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
259
 
260
  while (fetched_next_task) {
261
  // Prefetch next Q when current Q changes
262
- bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
263
  q_idx = next_q_idx;
264
  kv_idx = next_kv_idx;
265
  num_kv = next_num_kv;
@@ -276,9 +193,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
276
  if (kv_idx == 0 or kv_block_idx_ptr == 32) {
277
  kv_block_idx_ptr = 0;
278
  kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
279
- __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
280
  }
281
- const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
282
 
283
  // Wait KV consumer release
284
  CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
@@ -286,10 +203,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
286
 
287
  // Issue TMA KV
288
  if (cute::elect_one_sync()) {
289
- tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
290
- smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
291
- tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
292
- smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
293
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
294
  }
295
 
@@ -301,9 +218,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
301
  cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
302
 
303
  float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
304
- const auto& sub_warp_offset = (warp_idx % 4) * 16;
305
- const auto& v_0_offset = lane_idx / 4 + 0;
306
- const auto& v_1_offset = lane_idx / 4 + 8;
307
 
308
  // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
309
  uint32_t q_idx = batch_size, kv_idx;
@@ -326,7 +243,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
326
  for (uint32_t i = 0; i < kNextN; ++ i) {
327
  #pragma unroll
328
  for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
329
- weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
330
  }
331
  }
332
 
@@ -335,7 +252,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
335
  kv_idx = next_kv_idx;
336
 
337
  // Calculate KV offset in advance
338
- auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
339
 
340
  // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
341
  // Wait TMA KV arrival
@@ -347,25 +264,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
347
  DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
348
  #pragma unroll
349
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
350
- warpgroup_fence_operand(accum[i]);
351
- warpgroup_arrive();
352
  #pragma unroll
353
  for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
354
- auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
355
- auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
 
 
 
 
356
  WGMMA::wgmma(desc_a, desc_b, accum, k);
357
  }
358
- warpgroup_commit_batch();
359
  #pragma unroll
360
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
361
- warpgroup_fence_operand(accum[i]);
362
 
363
  // Read per-KV scales
364
- float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
365
- float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
366
 
367
  // Wait WGMMA
368
- warpgroup_wait<0>();
369
 
370
  // Release KV empty
371
  empty_kv_barriers[kv_stage_idx]->arrive();
@@ -378,7 +299,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
378
  #pragma unroll
379
  for (uint32_t i = 0; i < kNextN; ++ i) {
380
  auto shifted_accum = accum + i * kNumAccumPerReduce;
381
- const auto& transform = [&](const uint32_t& j) {
382
  return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
383
  };
384
 
@@ -396,15 +317,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
396
  // Inter-thread reduction
397
  #pragma unroll
398
  for (uint32_t j = 0; j < 2; ++ j) {
399
- const auto& offset = static_cast<int>(1u << j);
400
  v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
401
  v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
402
  }
403
 
404
  // Store into the global memory
405
  // NOTES: we have redundant writes here, consider more carefully
406
- logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
407
- logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
408
  }
409
  }
410
  }
 
6
  #include <cute/arch/cluster_sm90.hpp>
7
  #include <cute/arch/copy_sm90_desc.hpp>
8
 
9
+ #include <deep_gemm/common/cute_tie.cuh>
10
+ #include <deep_gemm/common/math.cuh>
11
  #include <deep_gemm/common/utils.cuh>
12
+ #include <deep_gemm/common/tma_copy.cuh>
13
+ #include <deep_gemm/common/types.cuh>
14
+ #include <deep_gemm/mma/sm90.cuh>
15
+ #include <deep_gemm/ptx/ld_st.cuh>
16
+ #include <deep_gemm/ptx/utils.cuh>
17
+ #include <deep_gemm/ptx/wgmma.cuh>
18
+ #include <deep_gemm/scheduler/paged_mqa_logits.cuh>
19
 
20
  namespace deep_gemm {
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  template <uint32_t kNextN, uint32_t kNumHeads,
23
  uint32_t kHeadDim, uint32_t BLOCK_KV,
24
+ bool kIsContextLens2D, bool kIsVarlen,
25
  uint32_t kNumQStages, uint32_t kNumKVStages,
26
  uint32_t SPLIT_KV,
27
+ uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
28
+ typename logits_dtype_t>
29
+ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
30
  void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
31
+ const uint32_t logits_stride, const uint32_t block_table_stride,
32
+ const uint32_t* context_lens, logits_dtype_t* logits,
33
+ const uint32_t* block_table, const uint32_t* indices,
34
+ const uint32_t* schedule_meta,
35
  const __grid_constant__ cute::TmaDescriptor tensor_map_q,
36
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
37
  const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
38
  const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
39
+ DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
40
+
41
  // Types
42
+ using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
43
  using Barrier = cutlass::arch::ClusterTransactionBarrier;
44
 
45
  // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
46
+ const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
47
+ const auto warpgroup_idx = warp_idx / 4;
48
+ const auto lane_idx = ptx::get_lane_idx();
49
 
50
  // Prefetch TMA descriptors
51
  static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
 
63
  static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
64
  static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
65
  static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
66
+ static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
67
  static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
68
+ math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
69
 
70
  static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
71
  static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
72
+ static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
73
  static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
74
+ math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
75
 
76
  // Align to swizzling alignment bytes
77
  extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
 
79
  DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
80
 
81
  // Q data and barriers on shared memory
82
+ auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
83
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
84
  });
85
+ auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
86
  return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
87
  });
88
  auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
89
+ auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
90
+ auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
91
 
92
  // Separate math warpgroups and tma load warps into KV groups
93
  // Each math warpgroup corresponds to a tma load warp
94
+ const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
95
 
96
  // Per group KV data and barriers on shared memory
97
+ const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
98
+ auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
99
  return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
100
  });
101
+ auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
102
  return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
103
  });
104
  auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
105
+ auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
106
+ auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
107
 
108
  // Initialize barriers
109
  if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
 
131
  constexpr uint32_t kNumTMARegisters = 64;
132
  constexpr uint32_t kNumMathRegisters = 104;
133
 
134
+ // Wait for primary kernel completion
135
+ cudaGridDependencySynchronize();
136
+
137
  // Scheduler
138
+ auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
139
+ blockIdx.x, batch_size, context_lens, schedule_meta, indices);
140
  DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
141
 
142
  // Q and KV pipeline
143
+ const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
144
  return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
145
  };
146
+ const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
147
  return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
148
  };
149
  uint32_t q_iter_idx = 0, kv_iter_idx = 0;
 
154
  if (kv_group_idx >= kNumMathWarpGroups)
155
  return;
156
 
157
+ const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
158
  if (kv_group_idx == 0 and cute::elect_one_sync()) {
159
+ tma::copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
160
+ tma::copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN);
161
  full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
162
  }
163
  };
 
176
 
177
  while (fetched_next_task) {
178
  // Prefetch next Q when current Q changes
179
+ bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1));
180
  q_idx = next_q_idx;
181
  kv_idx = next_kv_idx;
182
  num_kv = next_num_kv;
 
193
  if (kv_idx == 0 or kv_block_idx_ptr == 32) {
194
  kv_block_idx_ptr = 0;
195
  kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
196
+ block_table[q_idx * static_cast<uint64_t>(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0);
197
  }
198
+ const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
199
 
200
  // Wait KV consumer release
201
  CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
 
203
 
204
  // Issue TMA KV
205
  if (cute::elect_one_sync()) {
206
+ tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
207
+ smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
208
+ tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
209
+ smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
210
  full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
211
  }
212
 
 
218
  cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
219
 
220
  float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
221
+ const auto sub_warp_offset = (warp_idx % 4) * 16;
222
+ const auto v_0_offset = lane_idx / 4 + 0;
223
+ const auto v_1_offset = lane_idx / 4 + 8;
224
 
225
  // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
226
  uint32_t q_idx = batch_size, kv_idx;
 
243
  for (uint32_t i = 0; i < kNextN; ++ i) {
244
  #pragma unroll
245
  for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
246
+ weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
247
  }
248
  }
249
 
 
252
  kv_idx = next_kv_idx;
253
 
254
  // Calculate KV offset in advance
255
+ auto kv_offset = q_idx * kNextN * static_cast<uint64_t>(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
256
 
257
  // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
258
  // Wait TMA KV arrival
 
264
  DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
265
  #pragma unroll
266
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
267
+ ptx::warpgroup_fence_operand(accum[i]);
268
+ ptx::warpgroup_arrive();
269
  #pragma unroll
270
  for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
271
+ auto desc_a = mma::sm90::make_smem_desc(
272
+ smem_kv[kv_stage_idx] + k * WGMMA::K,
273
+ mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
274
+ auto desc_b = mma::sm90::make_smem_desc(
275
+ smem_q[q_stage_idx] + k * WGMMA::K,
276
+ mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
277
  WGMMA::wgmma(desc_a, desc_b, accum, k);
278
  }
279
+ ptx::warpgroup_commit_batch();
280
  #pragma unroll
281
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
282
+ ptx::warpgroup_fence_operand(accum[i]);
283
 
284
  // Read per-KV scales
285
+ float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
286
+ float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
287
 
288
  // Wait WGMMA
289
+ ptx::warpgroup_wait<0>();
290
 
291
  // Release KV empty
292
  empty_kv_barriers[kv_stage_idx]->arrive();
 
299
  #pragma unroll
300
  for (uint32_t i = 0; i < kNextN; ++ i) {
301
  auto shifted_accum = accum + i * kNumAccumPerReduce;
302
+ const auto transform = [&](const uint32_t& j) {
303
  return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
304
  };
305
 
 
317
  // Inter-thread reduction
318
  #pragma unroll
319
  for (uint32_t j = 0; j < 2; ++ j) {
320
+ const auto offset = static_cast<int>(1u << j);
321
  v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
322
  v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
323
  }
324
 
325
  // Store into the global memory
326
  // NOTES: we have redundant writes here, consider more carefully
327
+ logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_0_offset] = static_cast<logits_dtype_t>(v_0);
328
+ logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_1_offset] = static_cast<logits_dtype_t>(v_1);
329
  }
330
  }
331
  }
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh CHANGED
@@ -5,20 +5,23 @@
5
  #include <cutlass/arch/barrier.h>
6
  #include <cutlass/arch/reg_reconfig.h>
7
 
8
- #include <deep_gemm/common/reduction.cuh>
9
  #include <deep_gemm/common/utils.cuh>
10
- #include <deep_gemm/common/sm90_utils.cuh>
 
 
 
 
 
11
 
12
  namespace deep_gemm {
13
 
14
- using namespace deep_gemm::sm90;
15
-
16
  template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
17
- __device__ __forceinline__
18
  uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
19
  constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
20
 
21
- const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
22
 
23
  constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
24
  constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
@@ -35,7 +38,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
35
  uint32_t kSwizzleCDMode,
36
  uint32_t kNumStages,
37
  uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
38
- __global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
39
  sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
40
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
41
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
56
 
57
  // Utils
58
  const auto warp_idx = cutlass::canonical_warp_idx_sync();
59
- const auto lane_idx = get_lane_idx();
60
 
61
  // Align to 1024 bytes for swizzle-128B
62
  extern __shared__ __align__(1024) uint8_t smem_buffer[];
@@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
76
  // Data on shared memory (layout as ordered below)
77
  // Fill D/A/B pointers
78
  auto smem_cd = reinterpret_cast<float*>(smem_buffer);
79
- auto smem_a = PatternVisitor([&](const uint32_t& i) {
80
  return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
81
  });
82
- auto smem_b = PatternVisitor([&](const uint32_t& i) {
83
  return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
84
  });
85
 
86
  // Fill barriers
87
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
88
- auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
89
- auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
90
 
91
  // Initialize barriers
92
  if (warp_idx == 1 and cute::elect_one_sync()) {
@@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
101
  }
102
  __syncthreads();
103
 
104
- constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
105
  constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
106
  constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
107
  const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
@@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
113
  constexpr uint32_t kNumTMARegisters = 40;
114
  constexpr uint32_t kNumMathRegisters = 256;
115
 
 
 
 
116
  // TMA load warp
117
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
118
  cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
119
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
120
  // Wait consumer release
121
- const auto& stage_idx = s % kNumStages;
122
  empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
123
 
124
  // Compute offsets
@@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
126
  uint32_t k_idx = k_offset + s * BLOCK_K;
127
 
128
  // Issue TMAs
129
- tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
130
- tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
131
 
132
  // Arrive at full barriers
133
  constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
@@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
135
  }
136
 
137
  for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
138
- const auto& stage_idx = s % kNumStages;
139
  empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
140
  }
141
  } else if (warp_idx < kNumMathThreads / 32) {
@@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
148
  constexpr uint32_t WGMMA_N = BLOCK_N;
149
  constexpr uint32_t WGMMA_K = 8;
150
 
151
- using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
152
  float accum[WGMMA::kNumAccum] = {0};
153
 
154
  constexpr uint32_t kNumBankGroupBytes = 16;
@@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
196
  sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
197
  }
198
 
199
- warpgroup_wait<0>();
200
  if (s > 0)
201
  empty_barriers[(s - 1) % kNumStages]->arrive();
202
 
203
  #pragma unroll
204
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
205
- warpgroup_fence_operand(accum[i]);
206
- warpgroup_arrive();
207
 
208
  constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
209
  constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
@@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
213
  for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
214
  #pragma unroll
215
  for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
216
- auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
 
217
  WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
218
  }
219
  }
220
- warpgroup_commit_batch();
221
  #pragma unroll
222
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
223
- warpgroup_fence_operand(accum[i]);
224
  }
225
 
226
- const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
227
- const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);
228
 
229
  const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
230
  if (lane_idx % 4 == 0) {
@@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
233
  if (m_idx + 8 < shape_m)
234
  sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
235
  }
236
- warpgroup_wait<0>();
237
  empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
238
 
239
  // Write accum to shared memory
@@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
260
 
261
  // 0/1 write to the same row, 2/3 write to another row
262
  auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
263
- st_shared(smem_ptr, values[0], values[1]);
264
- st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
265
  }
266
  cute::tma_store_fence();
267
  cutlass::arch::NamedBarrier::sync(128, 1);
 
5
  #include <cutlass/arch/barrier.h>
6
  #include <cutlass/arch/reg_reconfig.h>
7
 
8
+ #include <deep_gemm/common/math.cuh>
9
  #include <deep_gemm/common/utils.cuh>
10
+ #include <deep_gemm/common/tma_copy.cuh>
11
+ #include <deep_gemm/common/types.cuh>
12
+ #include <deep_gemm/mma/sm90.cuh>
13
+ #include <deep_gemm/ptx/ld_st.cuh>
14
+ #include <deep_gemm/ptx/utils.cuh>
15
+ #include <deep_gemm/ptx/wgmma.cuh>
16
 
17
  namespace deep_gemm {
18
 
 
 
19
  template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
20
+ CUTLASS_DEVICE
21
  uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
22
  constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
23
 
24
+ const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
25
 
26
  constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
27
  constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
 
38
  uint32_t kSwizzleCDMode,
39
  uint32_t kNumStages,
40
  uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
41
+ CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
42
  sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
43
  const __grid_constant__ cute::TmaDescriptor tensor_map_a,
44
  const __grid_constant__ cute::TmaDescriptor tensor_map_b,
 
59
 
60
  // Utils
61
  const auto warp_idx = cutlass::canonical_warp_idx_sync();
62
+ const auto lane_idx = ptx::get_lane_idx();
63
 
64
  // Align to 1024 bytes for swizzle-128B
65
  extern __shared__ __align__(1024) uint8_t smem_buffer[];
 
79
  // Data on shared memory (layout as ordered below)
80
  // Fill D/A/B pointers
81
  auto smem_cd = reinterpret_cast<float*>(smem_buffer);
82
+ auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
83
  return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
84
  });
85
+ auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
86
  return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
87
  });
88
 
89
  // Fill barriers
90
  auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
91
+ auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
92
+ auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
93
 
94
  // Initialize barriers
95
  if (warp_idx == 1 and cute::elect_one_sync()) {
 
104
  }
105
  __syncthreads();
106
 
107
+ constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
108
  constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
109
  constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
110
  const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
 
116
  constexpr uint32_t kNumTMARegisters = 40;
117
  constexpr uint32_t kNumMathRegisters = 256;
118
 
119
+ // Wait for primary kernel completion
120
+ cudaGridDependencySynchronize();
121
+
122
  // TMA load warp
123
  if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
124
  cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
125
  for (uint32_t s = 0; s < num_total_stages; ++ s) {
126
  // Wait consumer release
127
+ const auto stage_idx = s % kNumStages;
128
  empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
129
 
130
  // Compute offsets
 
132
  uint32_t k_idx = k_offset + s * BLOCK_K;
133
 
134
  // Issue TMAs
135
+ tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
136
+ tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
137
 
138
  // Arrive at full barriers
139
  constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
 
141
  }
142
 
143
  for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
144
+ const auto stage_idx = s % kNumStages;
145
  empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
146
  }
147
  } else if (warp_idx < kNumMathThreads / 32) {
 
154
  constexpr uint32_t WGMMA_N = BLOCK_N;
155
  constexpr uint32_t WGMMA_K = 8;
156
 
157
+ using WGMMA = typename mma::sm90::TF32MMASelector<WGMMA_N, true>::type;
158
  float accum[WGMMA::kNumAccum] = {0};
159
 
160
  constexpr uint32_t kNumBankGroupBytes = 16;
 
202
  sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
203
  }
204
 
205
+ ptx::warpgroup_wait<0>();
206
  if (s > 0)
207
  empty_barriers[(s - 1) % kNumStages]->arrive();
208
 
209
  #pragma unroll
210
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
211
+ ptx::warpgroup_fence_operand(accum[i]);
212
+ ptx::warpgroup_arrive();
213
 
214
  constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
215
  constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
 
219
  for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
220
  #pragma unroll
221
  for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
222
+ auto b_desc = mma::sm90::make_smem_desc(
223
+ smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
224
  WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
225
  }
226
  }
227
+ ptx::warpgroup_commit_batch();
228
  #pragma unroll
229
  for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
230
+ ptx::warpgroup_fence_operand(accum[i]);
231
  }
232
 
233
+ const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0);
234
+ const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1);
235
 
236
  const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
237
  if (lane_idx % 4 == 0) {
 
240
  if (m_idx + 8 < shape_m)
241
  sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
242
  }
243
+ ptx::warpgroup_wait<0>();
244
  empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
245
 
246
  // Write accum to shared memory
 
267
 
268
  // 0/1 write to the same row, 2/3 write to another row
269
  auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
270
+ ptx::st_shared(smem_ptr, values[0], values[1]);
271
+ ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
272
  }
273
  cute::tma_store_fence();
274
  cutlass::arch::NamedBarrier::sync(128, 1);
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh CHANGED
@@ -3,21 +3,24 @@
3
  #include <cutlass/arch/barrier.h>
4
  #include <cute/arch/cluster_sm90.hpp>
5
 
6
- #include <deep_gemm/common/utils.cuh>
 
7
 
8
  namespace deep_gemm {
9
 
10
- template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
11
- __global__ __launch_bounds__(kNumWarps * 32, 1)
12
  void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
13
- const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) {
14
- const uint32_t& num_sms = gridDim.x;
15
- const uint32_t& sm_idx = blockIdx.x;
16
- const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
17
- constexpr float neg_inf = -cute::numeric_limits<float>::infinity();
 
 
18
 
19
  // Allocate filled `-inf` shared memory
20
- extern __shared__ __align__(1024) float smem_buffer[];
21
  #pragma unroll
22
  for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
23
  smem_buffer[i] = neg_inf;
@@ -25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const
25
  __syncthreads();
26
 
27
  // Assign sequence to each warp
28
- const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx,
29
- const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
30
- const auto& per = total / num, rem = total % num;
31
- return {start + idx * per + min(idx, rem), per + (idx < rem)};
32
  };
33
  CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
34
  CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
35
 
 
 
 
36
  if (cute::elect_one_sync()) {
37
  for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
38
- const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
39
- const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
40
- const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
41
 
42
  for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
43
- const auto& right = min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
44
  if (right <= ks or ke <= left) {
45
- cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float));
46
  } else {
47
  if (left < aligned_ks)
48
- cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float));
49
  if (aligned_ke < right)
50
- cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float));
51
  }
52
  }
53
  }
54
  }
 
55
 
56
  for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
57
- const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
58
- const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
59
- const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
60
  for (uint32_t j = aligned_ks; j < ks; ++ j)
61
  logits[i * stride_logits + j] = neg_inf;
62
  for (uint32_t j = ke; j < aligned_ke; ++ j)
 
3
  #include <cutlass/arch/barrier.h>
4
  #include <cute/arch/cluster_sm90.hpp>
5
 
6
+ #include <deep_gemm/common/cute_tie.cuh>
7
+ #include <deep_gemm/common/math.cuh>
8
 
9
  namespace deep_gemm {
10
 
11
+ template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps, typename logits_dtype_t>
12
+ CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1)
13
  void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
14
+ const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) {
15
+ const uint32_t num_sms = gridDim.x;
16
+ const uint32_t sm_idx = blockIdx.x;
17
+ const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
18
+
19
+ constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t);
20
+ const logits_dtype_t neg_inf = -cute::numeric_limits<logits_dtype_t>::infinity();
21
 
22
  // Allocate filled `-inf` shared memory
23
+ extern __shared__ __align__(1024) logits_dtype_t smem_buffer[];
24
  #pragma unroll
25
  for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
26
  smem_buffer[i] = neg_inf;
 
28
  __syncthreads();
29
 
30
  // Assign sequence to each warp
31
+ const auto assign_task = [&](const uint32_t& num, const uint32_t& idx,
32
+ const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
33
+ const auto per = total / num, rem = total % num;
34
+ return {start + idx * per + cute::min(idx, rem), per + (idx < rem)};
35
  };
36
  CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
37
  CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
38
 
39
+ // Wait for primary kernel completion
40
+ cudaGridDependencySynchronize();
41
+
42
  if (cute::elect_one_sync()) {
43
  for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
44
+ const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
45
+ const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
46
+ const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
47
 
48
  for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
49
+ const auto right = cute::min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
50
  if (right <= ks or ke <= left) {
51
+ cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t));
52
  } else {
53
  if (left < aligned_ks)
54
+ cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t));
55
  if (aligned_ke < right)
56
+ cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t));
57
  }
58
  }
59
  }
60
  }
61
+ __syncwarp();
62
 
63
  for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
64
+ const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
65
+ const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
66
+ const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
67
  for (uint32_t j = aligned_ks; j < ks; ++ j)
68
  logits[i * stride_logits + j] = neg_inf;
69
  for (uint32_t j = ke; j < aligned_ke; ++ j)
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/impls/smxx_layout.cuh CHANGED
@@ -1,13 +1,16 @@
1
  #pragma once
2
 
 
3
  #include <deep_gemm/common/utils.cuh>
 
 
4
 
5
  namespace deep_gemm {
6
 
7
  template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
8
  uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
9
- __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
10
- typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
11
  constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
12
  constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
13
 
@@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
15
  extern __shared__ float smem_buffer[];
16
  constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
17
  const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
18
- const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);
19
 
20
  // Shift into the block
21
  sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
22
  out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
23
  const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
24
 
 
 
 
25
  // Load
26
  for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
27
- auto in_vec = __ldg(local_sf + i);
28
  const auto& in_values = reinterpret_cast<float*>(&in_vec);
29
 
30
  const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
@@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
39
  for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
40
  const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
41
  const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
42
- out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
43
  }
44
  }
45
 
46
  // NOTES: the two kernels below always pack the K dimension
47
 
48
  template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
49
- __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
50
  extern __shared__ uint32_t smem_buffer[];
51
 
52
  // Shapes and strides
53
- constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
54
  constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
55
  const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
56
- const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);
57
 
58
  // Shift into the group
59
  sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
60
  out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;
61
 
 
 
 
62
  // Load FP32 SFs
63
  DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
64
  const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
@@ -66,13 +75,13 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
66
  const auto num_uint4 = num_values / 4;
67
  #pragma unroll
68
  for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
69
- const auto& [x, y, z, w] = __ldg(reinterpret_cast<uint4*>(local_sf) + i);
70
- st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
71
  }
72
 
73
  // Fill unaligned values as well
74
  if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
75
- st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx));
76
  __syncthreads();
77
 
78
  // Pack into UE8M0 and store
@@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
85
  #pragma unroll
86
  for (uint32_t j = 0; j < 4; ++ j) {
87
  const auto sf_k_idx = sf_k_pack_idx * 4 + j;
88
- values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
89
  }
90
 
91
  // Pack and store
@@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
101
 
102
  template <uint32_t kNumGroups, uint32_t kNumThreads,
103
  uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
104
- __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
105
- const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) {
 
106
  // Always packing the K dimension
107
  // NOTES: should also assert `mn % 4 == 0` at launch
108
  DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
@@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
120
 
121
  // Each warp is responsible for a packed row
122
  const auto warp_idx = threadIdx.x / 32;
123
- const auto lane_idx = get_lane_idx();
124
  const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
125
  if (warp_idx >= in_block_packed_sf_k)
126
  return;
127
 
 
 
 
128
  // Make an offset on the input
129
  uint32_t input_offset = 0;
130
  if constexpr (kNumGroups > 1) {
@@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
134
  #pragma unroll
135
  for (uint32_t i = 0; i < 4; ++ i) {
136
  const auto group_idx = lane_idx * 4 + i;
137
- group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0;
138
  }
139
  __syncwarp();
140
 
141
  // Make the offset
142
  sf_k = 0;
143
- auto sum_packed_sf_k = 0;
144
  #pragma unroll
145
  for (uint32_t i = 0; i < kNumGroups; ++ i) {
146
- const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4);
147
  sf_k += sf_k_in_group;
148
- sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
149
  if (packed_sf_k_idx < sum_packed_sf_k)
150
  break;
151
  if (const auto remainder = sf_k_in_group % 4; remainder > 0)
@@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
153
  }
154
  }
155
 
156
- for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
157
  // Load
158
  uint4 values[4];
159
  #pragma unroll
160
  for (uint32_t j = 0; j < 4; ++ j) {
161
  values[j] = make_uint4(0, 0, 0, 0);
162
  if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
163
- values[j] = __ldg(reinterpret_cast<uint4*>(sf + sf_k_idx * mn) + mn_idx);
164
  }
165
 
166
  // Pack and store
 
1
  #pragma once
2
 
3
+ #include <deep_gemm/common/math.cuh>
4
  #include <deep_gemm/common/utils.cuh>
5
+ #include <deep_gemm/ptx/ld_st.cuh>
6
+ #include <deep_gemm/ptx/utils.cuh>
7
 
8
  namespace deep_gemm {
9
 
10
  template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
11
  uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
12
+ CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
13
+ typedef typename utils::Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
14
  constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
15
  constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
16
 
 
18
  extern __shared__ float smem_buffer[];
19
  constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
20
  const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
21
+ const auto tma_aligned_mn = math::align<uint32_t>(mn, kNumTMAAlignedElems);
22
 
23
  // Shift into the block
24
  sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
25
  out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
26
  const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
27
 
28
+ // Wait for primary kernel completion
29
+ cudaGridDependencySynchronize();
30
+
31
  // Load
32
  for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
33
+ auto in_vec = local_sf[i];
34
  const auto& in_values = reinterpret_cast<float*>(&in_vec);
35
 
36
  const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
 
45
  for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
46
  const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
47
  const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
48
+ out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
49
  }
50
  }
51
 
52
  // NOTES: the two kernels below always pack the K dimension
53
 
54
  template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
55
+ CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
56
  extern __shared__ uint32_t smem_buffer[];
57
 
58
  // Shapes and strides
59
+ constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u);
60
  constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
61
  const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
62
+ const auto tma_aligned_mn = math::align<uint64_t>(mn, kNumTMAAlignedElems);
63
 
64
  // Shift into the group
65
  sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
66
  out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;
67
 
68
+ // Wait for primary kernel completion
69
+ cudaGridDependencySynchronize();
70
+
71
  // Load FP32 SFs
72
  DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
73
  const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
 
75
  const auto num_uint4 = num_values / 4;
76
  #pragma unroll
77
  for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
78
+ const auto& [x, y, z, w] = reinterpret_cast<const uint4*>(local_sf)[i];
79
+ ptx::st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
80
  }
81
 
82
  // Fill unaligned values as well
83
  if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
84
+ ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]);
85
  __syncthreads();
86
 
87
  // Pack into UE8M0 and store
 
94
  #pragma unroll
95
  for (uint32_t j = 0; j < 4; ++ j) {
96
  const auto sf_k_idx = sf_k_pack_idx * 4 + j;
97
+ values[j] = sf_k_idx < SF_K ? ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
98
  }
99
 
100
  // Pack and store
 
110
 
111
  template <uint32_t kNumGroups, uint32_t kNumThreads,
112
  uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
113
+ CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
114
+ const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k,
115
+ const uint32_t gran_k) {
116
  // Always packing the K dimension
117
  // NOTES: should also assert `mn % 4 == 0` at launch
118
  DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
 
130
 
131
  // Each warp is responsible for a packed row
132
  const auto warp_idx = threadIdx.x / 32;
133
+ const auto lane_idx = ptx::get_lane_idx();
134
  const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
135
  if (warp_idx >= in_block_packed_sf_k)
136
  return;
137
 
138
+ // Wait for primary kernel completion
139
+ cudaGridDependencySynchronize();
140
+
141
  // Make an offset on the input
142
  uint32_t input_offset = 0;
143
  if constexpr (kNumGroups > 1) {
 
147
  #pragma unroll
148
  for (uint32_t i = 0; i < 4; ++ i) {
149
  const auto group_idx = lane_idx * 4 + i;
150
+ group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0;
151
  }
152
  __syncwarp();
153
 
154
  // Make the offset
155
  sf_k = 0;
156
+ uint32_t sum_packed_sf_k = 0;
157
  #pragma unroll
158
  for (uint32_t i = 0; i < kNumGroups; ++ i) {
159
+ const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4);
160
  sf_k += sf_k_in_group;
161
+ sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u);
162
  if (packed_sf_k_idx < sum_packed_sf_k)
163
  break;
164
  if (const auto remainder = sf_k_in_group % 4; remainder > 0)
 
166
  }
167
  }
168
 
169
+ for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
170
  // Load
171
  uint4 values[4];
172
  #pragma unroll
173
  for (uint32_t j = 0; j < 4; ++ j) {
174
  values[j] = make_uint4(0, 0, 0, 0);
175
  if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
176
+ values[j] = reinterpret_cast<const uint4*>(sf + sf_k_idx * mn)[mn_idx];
177
  }
178
 
179
  // Pack and store
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/mega_moe.cuh ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/numeric/math.hpp>
4
+
5
+ #include <deep_gemm/common/math.cuh>
6
+ #include <deep_gemm/common/exception.cuh>
7
+
8
+ namespace deep_gemm::layout {
9
+
10
+ static constexpr int kNumCandidateBlockMs = 7; // Number of entries in `kCandidateBlockM`
11
+ static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; // Candidate BLOCK_M values, in ascending order
12
+ static constexpr int kMaxCandidateBlockM = 192; // Largest entry of `kCandidateBlockM`
13
+ static constexpr int kMinCandidateBlockM = 8; // Smallest entry of `kCandidateBlockM`
14
+ static constexpr int kLCMCandidateBlockM = 384; // Least common multiple of the entries above (every entry divides 384)
15
+
16
+ // Pool capacity for the shared expert token pool: worst-case total tokens plus per-expert BLOCK_M alignment padding for the largest candidate BLOCK_M, aligned to `kLCMCandidateBlockM` so the result suits every candidate BLOCK_M
17
+ template <typename T>
18
+ CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
19
+ T num_experts_per_rank) {
20
+ const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank; // Every rank may send all of its tokens
21
+ const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank); // A token hits at most this many local experts
22
+ return math::constexpr_align( // NOTE(review): presumably rounds up to a multiple of the second argument — confirm in `math.cuh`
23
+ num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1), // Tokens + up to (192 - 1) padding slots per expert
24
+ static_cast<T>(kLCMCandidateBlockM));
25
+ }
26
+
27
+ // SF pool capacity: all experts share a contiguous SF region, sized as the number of pool blocks (`num_max_pool_tokens / block_m`) times `block_m` aligned to 128
28
+ template <typename T>
29
+ CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
30
+ return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast<T>(128)); // The division truncates for integral `T`; NOTE(review): `constexpr_align` presumably rounds up — confirm
31
+ }
32
+
33
+ // Per-token source metadata for combine write-back (three `uint32_t` indices)
34
+ struct TokenSrcMetadata {
35
+ uint32_t rank_idx; // NOTE(review): presumably the rank the token came from — confirm against the dispatch kernel
36
+ uint32_t token_idx; // NOTE(review): presumably the token index on that rank — confirm
37
+ uint32_t topk_idx; // NOTE(review): presumably the token's top-k slot — confirm
38
+ };
39
+
40
+ struct Workspace { // Describes the byte layout of a workspace region starting at `base` and hands out typed pointers into it
41
+ void* base;
42
+ uint32_t num_ranks, num_experts;
43
+ uint32_t num_experts_per_rank;
44
+ uint32_t num_max_tokens_per_rank;
45
+ uint32_t num_max_recv_tokens_per_expert;
46
+
47
+ // Pool capacity: all local experts share a contiguous token pool
48
+ uint32_t num_max_pool_tokens;
49
+ uint32_t num_max_pool_blocks;
50
+
51
+ // Size in bytes of the leading region used for both the grid barrier and the NVLink barrier
52
+ static constexpr uint64_t kNumBarrierSignalBytes = 32;
53
+
54
+ CUTLASS_HOST_DEVICE
55
+ Workspace(void* base,
56
+ const uint32_t& num_ranks,
57
+ const uint32_t& num_experts,
58
+ const uint32_t& num_max_tokens_per_rank,
59
+ const uint32_t& num_topk):
60
+ base(base),
61
+ num_ranks(num_ranks), num_experts(num_experts),
62
+ num_max_tokens_per_rank(num_max_tokens_per_rank) {
63
+ num_experts_per_rank = num_experts / num_ranks; // Integer division; NOTE(review): divisibility is not checked here
64
+ num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank; // Every rank may send all of its tokens to one expert
65
+ num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
66
+ num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM; // Block count at the smallest candidate BLOCK_M, i.e. the largest possible count
67
+ }
68
+
69
+ CUTLASS_HOST_DEVICE
70
+ uint64_t get_num_bytes() const { // Total workspace size; the sections are summed in the same order the accessors below walk them
71
+ uint64_t num_bytes = 0;
72
+
73
+ // Barrier
74
+ num_bytes += kNumBarrierSignalBytes;
75
+
76
+ // Expert send/recv count
77
+ num_bytes += num_experts * sizeof(uint64_t) * 2;
78
+
79
+ // Expert recv count sum
80
+ num_bytes += num_experts_per_rank * sizeof(uint64_t);
81
+
82
+ // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
83
+ num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
84
+
85
+ // L2 block arrival mask
86
+ num_bytes += num_max_pool_blocks * sizeof(uint64_t);
87
+
88
+ // Dispatch pulling source token-topk
89
+ num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int); // NOTE(review): the three `uint32_t` factors are multiplied in 32 bits before widening by `sizeof` — confirm this cannot overflow for large configurations
90
+
91
+ // Combine push source indices
92
+ num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
93
+
94
+ // Align to TMA descriptor requirements
95
+ num_bytes = math::align<uint64_t>(num_bytes, 16);
96
+ return num_bytes;
97
+ }
98
+
99
+ CUTLASS_HOST_DEVICE
100
+ void* get_end_ptr() const {
101
+ return math::advance_ptr(base, get_num_bytes()); // NOTE(review): `advance_ptr` presumably offsets by bytes — confirm in `math.cuh`
102
+ }
103
+
104
+ // Barrier signal region layout (the first `kNumBarrierSignalBytes` bytes of the workspace)
105
+ // [ 0..15]: 4 x `uint32_t` grid sync counters
106
+ // [16..19]: `uint32_t` NVLink barrier counter
107
+ // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1); bytes [28..31] are not addressed by the accessors below
108
+ static constexpr uint32_t kNumMaxGridSyncCounters = 4;
109
+
110
+ template <uint32_t kIndex = 0>
111
+ CUTLASS_DEVICE
112
+ uint32_t* get_grid_sync_count_ptr() const {
113
+ DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
114
+ return static_cast<uint32_t*>(base) + kIndex; // The counters sit at the very start of the workspace
115
+ }
116
+
117
+ CUTLASS_DEVICE
118
+ uint32_t* get_nvl_barrier_counter_ptr() const {
119
+ return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters; // The `uint32_t` right after the grid sync counters
120
+ }
121
+
122
+ CUTLASS_DEVICE
123
+ int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
124
+ // NOTES: the signal is signed, as we may subtract from it
125
+ return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int)); // Skips the grid sync counters and the NVLink counter, then selects the phase
126
+ }
127
+
128
+ CUTLASS_DEVICE
129
+ uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
130
+ return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx; // The send counts start right after the barrier region
131
+ }
132
+
133
+ CUTLASS_DEVICE
134
+ uint64_t* get_expert_recv_count_ptr(
135
+ const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
136
+ return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx; // Follows the `num_experts` send counts; rows of `num_experts_per_rank` entries per rank
137
+ }
138
+
139
+ CUTLASS_DEVICE
140
+ uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
141
+ return get_expert_send_count_ptr(num_experts * 2) + expert_idx; // Follows the send and recv counts (`num_experts` entries each)
142
+ }
143
+
144
+ CUTLASS_DEVICE
145
+ uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
146
+ const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank); // End of the recv count sums; this local shadows the member `base`
147
+ return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
148
+ }
149
+
150
+ CUTLASS_DEVICE
151
+ uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
152
+ // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
153
+ const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u)); // Same padded entry count as used in `get_num_bytes`
154
+ return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
155
+ }
156
+
157
+ // For dispatch pulling
158
+ CUTLASS_DEVICE
159
+ uint32_t* get_src_token_topk_idx_ptr(
160
+ const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
161
+ const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks); // End of the L2 arrival masks
162
+ return reinterpret_cast<uint32_t*>(base) +
163
+ expert_idx * (num_ranks * num_max_recv_tokens_per_expert) + // Indexed as [expert][rank][token]; NOTE(review): 32-bit index arithmetic — confirm the range
164
+ rank_idx * num_max_recv_tokens_per_expert + token_idx;
165
+ }
166
+
167
+ // For combine usages
168
+ CUTLASS_DEVICE
169
+ TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
170
+ const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank)); // End of the source token-topk table for all local experts
171
+ return base + pool_token_idx;
172
+ }
173
+ };
174
+
175
+ struct Data { // A fixed-size chunk of `num_bytes` bytes at `base`, optionally required to be a multiple of 16 bytes
175
+ uint32_t num_bytes;
176
+ bool require_tma_alignment;
177
+ void* base;
178
+
179
+ CUTLASS_HOST_DEVICE
180
+ constexpr explicit Data(
181
+ const uint32_t& num_bytes,
182
+ const bool& require_tma_alignment = true,
183
+ void* base = nullptr) :
184
+ num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) {
185
+ DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment); // The size must be a multiple of 16 bytes unless the alignment requirement is waived
186
+ }
187
+
188
+ template <typename dtype_t = uint32_t>
189
+ CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const { // Returns the size cast to `dtype_t` (e.g. `uint64_t` for wide offset arithmetic)
190
+ return static_cast<dtype_t>(num_bytes);
191
+ }
192
+
193
+ template <typename dtype_t = void>
194
+ CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { // Returns `base` as a `dtype_t*`
195
+ return static_cast<dtype_t*>(base);
196
+ }
197
+
198
+ CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) { // Rebinds the chunk to another address; the size is unchanged
199
+ base = ptr;
200
+ }
201
+ };
203
+
204
+ struct Buffer { // `num_ranks` x `num_max_tokens_per_rank` chunks of `data_layout` size each, laid out contiguously from `base`
204
+ Data data_layout;
205
+ uint32_t num_ranks;
206
+ uint32_t num_max_tokens_per_rank;
207
+
208
+ void* base;
209
+
210
+ CUTLASS_HOST_DEVICE
211
+ Buffer(const Data& data_layout,
212
+ const uint32_t& num_ranks,
213
+ const uint32_t& max_num_tokens_per_rank,
214
+ void* base = nullptr) :
215
+ data_layout(data_layout),
216
+ num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank),
217
+ base(base) {}
218
+
219
+ CUTLASS_HOST_DEVICE
220
+ uint64_t get_num_bytes_per_rank() const {
221
+ return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>(); // The per-token size is widened to 64 bits, so the product is computed in 64 bits
222
+ }
223
+
224
+ CUTLASS_HOST_DEVICE
225
+ uint64_t get_num_bytes() const { // Total size over all ranks
226
+ return get_num_bytes_per_rank() * num_ranks;
227
+ }
228
+
229
+ template <typename dtype_t = void>
230
+ CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { // Returns `base` as a `dtype_t*`
231
+ return static_cast<dtype_t*>(base);
232
+ }
233
+
234
+ CUTLASS_HOST_DEVICE
235
+ void* get_end_ptr() const {
236
+ return math::advance_ptr(base, get_num_bytes()); // NOTE(review): `advance_ptr` presumably offsets by bytes — confirm in `math.cuh`
237
+ }
238
+
239
+ CUTLASS_HOST_DEVICE
240
+ Buffer get_rank_buffer(const uint32_t& rank_idx) const { // Returns a single-rank view (`num_ranks == 1`) starting at the slice of `rank_idx`
241
+ return {
242
+ data_layout,
243
+ 1, num_max_tokens_per_rank,
244
+ math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx)
245
+ };
246
+ }
247
+
248
+ CUTLASS_HOST_DEVICE
249
+ Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const { // Returns the chunk at `token_idx`, counted from `base`
250
+ DG_DEVICE_ASSERT(num_ranks == 1 or global); // Indexing a multi-rank buffer requires `global`; NOTE(review): a device assert in a host/device function, while `Data` uses `DG_UNIFIED_ASSERT` — confirm this is intended
251
+ return Data(
252
+ data_layout.num_bytes,
253
+ data_layout.require_tma_alignment,
254
+ math::advance_ptr(base, data_layout.get_num_bytes<uint64_t>() * token_idx)
255
+ );
256
+ }
257
+ };
259
+
260
+ } // namespace deep_gemm::layout
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/layout/sym_buffer.cuh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/exception.cuh>
4
+
5
+ namespace deep_gemm::layout {
6
+
7
+ constexpr static uint32_t kNumMaxRanks = 72; // Capacity of `SymBuffer::offsets`, independent of the `kNumRanks` template argument
8
+
9
+ template <uint32_t kNumRanks = kNumMaxRanks>
10
+ struct SymBuffer { // Holds this rank's base address plus, for every rank, the address offset relative to that base
11
+ int64_t base;
12
+ int64_t offsets[kNumMaxRanks];
13
+ uint32_t rank_idx;
14
+
15
+ DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks");
16
+
17
+ SymBuffer() = default;
18
+
19
+ template <typename Container>
20
+ explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) { // `c` needs `size()` and `operator[]`; NOTE(review): presumably one address per rank — confirm against the caller
21
+ const auto size = static_cast<uint32_t>(c.size());
22
+ base = c[rank_idx]; // NOTE(review): `rank_idx` is not checked against `size` here
23
+ for (uint32_t i = 0; i < kNumMaxRanks; ++ i)
24
+ offsets[i] = i < size ? (c[i] - base) : 0; // Entries beyond the container are zero, so mapping to them is the identity
25
+ }
26
+
27
+ #if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
28
+ template <typename ptr_t = void*>
29
+ CUTLASS_DEVICE ptr_t get_base_ptr() const { // Reinterprets the stored integer `base` as a pointer
30
+ return reinterpret_cast<ptr_t>(base);
31
+ }
32
+
33
+ template <typename ptr_t>
34
+ CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const { // Shifts `ptr` by the offset of `dst_rank_idx`
35
+ int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast<int64_t>(ptr);
36
+ return *reinterpret_cast<ptr_t*>(&mapped_ptr); // Reads the 64-bit integer back as a `ptr_t`
37
+ }
38
+ #endif
39
+ };
40
+
41
+ } // namespace deep_gemm::layout
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm100.cuh ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/atom/mma_traits_sm100.hpp>
4
+ #include <cute/arch/mma_sm100_umma.hpp>
5
+
6
+ #include <deep_gemm/common/exception.cuh>
7
+ #include <deep_gemm/common/math.cuh>
8
+ #include <deep_gemm/common/tma_copy.cuh>
9
+
10
+ namespace deep_gemm::mma::sm100 {
11
+
12
+ /// Shared memory descriptor
13
+ CUTLASS_DEVICE
14
+ cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, // Builds a descriptor for `smem_ptr` with the given layout and SBO/LBO (both passed in bytes)
15
+ const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) {
16
+ cute::UMMA::SmemDescriptor desc;
17
+
18
+ // Set the version for SM100
19
+ desc.version_ = 1;
20
+
21
+ // Legacy mode
22
+ desc.lbo_mode_ = 0;
23
+
24
+ // Layout
25
+ desc.layout_type_ = static_cast<uint8_t>(layout);
26
+
27
+ // Start address, stored as `address >> 4` (16-byte units) and truncated to 16 bits
28
+ const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
29
+ desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
30
+
31
+ // Base offset
32
+ desc.base_offset_ = 0;
33
+
34
+ // SBO and LBO, both stored as `bytes >> 4` (16-byte units)
35
+ desc.stride_byte_offset_ = stride_byte_offset >> 4;
36
+ desc.leading_byte_offset_ = leading_byte_offset >> 4;
37
+
38
+ return desc;
39
+ }
40
+
41
+ CUTLASS_DEVICE
42
+ cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
43
+ // NOTES: the UTCCP layout is K-major by default
44
+ // Atom size: 8 x 128 bits
45
+ // {SBO, LBO} means the byte stride between atoms on {MN, K}
46
+ // Since the UTCCP we use is 128b-wide (only 1 atom on K), LBO can be zero
47
+ return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); // SBO = 8 rows x 16 bytes = one atom
48
+ }
49
+
50
+ CUTLASS_DEVICE
51
+ void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { // Overwrites only the start address of `desc`, using the same encoding as `make_smem_desc`
52
+ const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
53
+ desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4); // `address >> 4` (16-byte units), truncated to 16 bits
54
+ }
55
+
56
+ CUTLASS_DEVICE
57
+ static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { // Returns 32 for the 32B-base 128B swizzle layout and 16 for every other layout
58
+ return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16;
59
+ }
60
+
61
+ /// UMMA descriptors
62
+ // ReSharper disable once CppNotAllPathsReturnValue
63
+ template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
64
+ constexpr static cute::UMMA::LayoutType to_umma_layout_type() { // Maps a swizzle width in bytes (plus the base-32 flag) to the UMMA layout enum at compile time
65
+ DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
66
+ kSwizzleMode == 32 or kSwizzleMode == 64 or
67
+ kSwizzleMode == 128, "Invalid swizzling mode");
68
+ // A special case: MN-major `float` must be paired with `kUseBase32`, otherwise the assert below fails the build
69
+ if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
70
+ DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
71
+ return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
72
+ }
73
+
74
+ // Normal cases (both 0 and 16 map to `SWIZZLE_NONE`)
75
+ if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
76
+ if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
77
+ if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
78
+ if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B;
79
+ if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B;
80
+ }
81
+
82
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
83
+ CUTLASS_DEVICE
84
+ constexpr uint32_t get_umma_desc_stride_k() { // Element stride per K step: 1 for K-major, otherwise the inner block atom size from the TMA helper
85
+ return kMajorMode == cute::UMMA::Major::K ? 1 : tma::get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
86
+ }
87
+
88
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
89
+ CUTLASS_DEVICE
90
+ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { // Adds an element offset to `base` after converting it to bytes and then to 16-byte units (`>> 4`); NOTE(review): `base` is presumably the low 32-bit word of a descriptor — confirm at call sites
91
+ return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
92
+ }
93
+
94
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
95
+ CUTLASS_DEVICE
96
+ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { // Builds the descriptor for the tile at (`mn_idx`, `k_idx`) inside the block starting at `base_smem_ptr`
97
+ const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
98
+ const auto layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
99
+ const auto num_non_contiguous = 128 / get_atom_base(layout_type); // 8 for a 16-byte atom base, 4 for the 32-byte base
100
+ if constexpr (kMajorMode == cute::UMMA::Major::K) {
101
+ // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
102
+ // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
103
+ DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
104
+
105
+ // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
106
+ // {SBO, LBO} means the byte stride between atoms on {MN, K}
107
+ // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
108
+ const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
109
+ const uint32_t leading_byte_offset = 0;
110
+ return make_smem_desc(layout_type,
111
+ base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, // Element offset; `stride_k` is 1 in this branch
112
+ stride_byte_offset, leading_byte_offset);
113
+ } else {
114
+ constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
115
+
116
+ // Must have no in-atom MN-idx
117
+ // NOTES: no worries about the runtime assert, the `mn_idx` values are constants at compilation time
118
+ DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
119
+ DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
120
+
121
+ // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
122
+ // NOTES: `kSwizzleMode == 16` means non-swizzling but interleaving
123
+ // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
124
+ // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
125
+ uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
126
+ uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
127
+ if constexpr (kSwizzleMode == 16)
128
+ math::swap(stride_byte_offset, leading_byte_offset); // The non-swizzled case uses the opposite {SBO, LBO} meaning, see the notes above
129
+ return make_smem_desc(layout_type,
130
+ base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, // Same address expression as the K-major branch; `stride_k` is `BLOCK_MN_ATOM` here
131
+ stride_byte_offset, leading_byte_offset);
132
+ }
133
+ }
134
+
135
+ CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id( // Takes `desc` by value, so the caller's descriptor is left untouched
136
+ cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
137
+ desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; // Set the scale factor IDs for A and B on the local copy
138
+ return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32; // The 32-bit descriptor goes into the upper half of the returned 64-bit value
139
+ }
140
+
141
+ CUTLASS_DEVICE void update_instr_desc_with_umma_n( // Block-scaled overload: rewrites the N dimension field of `desc` in place
142
+ cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) {
143
+ desc.n_dim_ = umma_n >> 3; // The field stores `umma_n / 8`
144
+ }
145
+
146
+ CUTLASS_DEVICE void update_instr_desc_with_umma_n( // Plain descriptor overload with the same encoding
147
+ cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) {
148
+ desc.n_dim_ = umma_n >> 3; // The field stores `umma_n / 8`
149
+ }
150
+
151
+ } // namespace deep_gemm::mma::sm100
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/mma/sm90.cuh ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cute/arch/cluster_sm90.hpp>
4
+ #include <cute/arch/mma_sm90_desc.hpp>
5
+ #include <cute/arch/mma_sm90_gmma.hpp>
6
+ #include <cute/arch/mma_sm90_gmma_ext.hpp>
7
+ #include <cute/arch/mma_sm100_desc.hpp>
8
+
9
+ #include <deep_gemm/common/exception.cuh>
10
+
11
+ namespace deep_gemm::mma::sm90 {
12
+
13
+ /// MMA
14
+ template <int N_, typename MMA>
15
+ struct FP8MMA {
16
+ template <size_t ...Idx>
17
+ CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
18
+ using namespace cute::SM90::GMMA;
19
+ MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
20
+ }
21
+
22
+ CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
23
+ call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_ / 2>{});
24
+ }
25
+
26
+ static constexpr int M = 64;
27
+ static constexpr int N = N_;
28
+ static constexpr int K = 32;
29
+ static constexpr int kNumAccum = M * N / 128;
30
+ };
31
+
32
+ template <int N>
33
+ struct FP8MMASelector {
34
+ static constexpr auto select_mma() {
35
+ using namespace cute::SM90::GMMA;
36
+ if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
37
+ if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
38
+ if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
39
+ if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
40
+ if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN();
41
+ if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN();
42
+ if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN();
43
+ if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN();
44
+ if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN();
45
+ if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN();
46
+ if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN();
47
+ if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN();
48
+ if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN();
49
+ if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
50
+ if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
51
+ if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
52
+ if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
53
+ if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
54
+ if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
55
+ if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
56
+ if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN();
57
+ if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN();
58
+ if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN();
59
+ if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
60
+ if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN();
61
+ if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN();
62
+ if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN();
63
+ if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN();
64
+ if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN();
65
+ if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN();
66
+ if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN();
67
+ if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN();
68
+ }
69
+
70
+ static constexpr auto select_type() {
71
+ return FP8MMA<N, decltype(select_mma())>();
72
+ }
73
+
74
+ using type = decltype(select_type());
75
+ };
76
+
77
+ template <int N_, typename MMA>
78
+ struct BF16MMA {
79
+ template <size_t ...Idx>
80
+ CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
81
+ using namespace cute::SM90::GMMA;
82
+ MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
83
+ }
84
+
85
+ CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
86
+ call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
87
+ }
88
+
89
+ static constexpr int M = 64;
90
+ static constexpr int N = N_;
91
+ static constexpr int K = 16;
92
+ static constexpr int kNumAccum = M * N / 128;
93
+ };
94
+
95
+ template <cute::UMMA::Major kMajor>
96
+ constexpr cute::SM90::GMMA::Major to_sm90_major() {
97
+ DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness");
98
+ return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN;
99
+ }
100
+
101
+ template <int N,
102
+ cute::UMMA::Major kMajorA = cute::UMMA::Major::K,
103
+ cute::UMMA::Major kMajorB = cute::UMMA::Major::K>
104
+ struct BF16MMASelector {
105
+ static constexpr auto select_mma() {
106
+ using namespace cute::SM90::GMMA;
107
+ constexpr auto kGMMAMajorA = to_sm90_major<kMajorA>();
108
+ constexpr auto kGMMAMajorB = to_sm90_major<kMajorB>();
109
+ if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
110
+ if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
111
+ if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
112
+ if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
113
+ if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
114
+ if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
115
+ if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
116
+ if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
117
+ if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
118
+ if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
119
+ if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
120
+ if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
121
+ if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
122
+ if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
123
+ if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
124
+ if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
125
+ if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
126
+ if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
127
+ if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
128
+ if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
129
+ if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
130
+ if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
131
+ if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
132
+ if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
133
+ if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
134
+ if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
135
+ if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
136
+ if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
137
+ if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
138
+ if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
139
+ if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
140
+ if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
141
+ }
142
+
143
+ static constexpr auto select_type() {
144
+ return BF16MMA<N, decltype(select_mma())>();
145
+ }
146
+
147
+ using type = decltype(select_type());
148
+ };
149
+
150
+ template <int N_, typename MMA>
151
+ struct TF32MMARS {
152
+ template <size_t ...Idx>
153
+ CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
154
+ using namespace cute::SM90::GMMA;
155
+ MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
156
+ }
157
+
158
+ CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
159
+ call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
160
+ }
161
+
162
+ static constexpr int M = 64;
163
+ static constexpr int N = N_;
164
+ static constexpr int K = 8;
165
+ static constexpr int kNumAccum = M * N / 128;
166
+ };
167
+
168
+ template <int N, bool kUseRS = true>
169
+ struct TF32MMASelector {
170
+ static constexpr auto select_mma() {
171
+ using namespace cute::SM90::GMMA;
172
+ if constexpr (kUseRS) {
173
+ if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
174
+ if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
175
+ if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
176
+ if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
177
+ if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
178
+ if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
179
+ DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
180
+ }
181
+ }
182
+
183
+ static constexpr auto select_type() {
184
+ if constexpr (kUseRS) {
185
+ return TF32MMARS<N, decltype(select_mma())>();
186
+ } else {
187
+ DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
188
+ }
189
+ }
190
+
191
+ using type = decltype(select_type());
192
+ };
193
+
194
+ /// Shared memory descriptor
195
+ template <class PointerType>
196
+ CUTLASS_DEVICE cute::GmmaDescriptor
197
+ make_smem_desc(PointerType smem_ptr, const int& layout_type,
198
+ const uint32_t& leading_byte_offset = 0,
199
+ const uint32_t& stride_byte_offset = 1024) {
200
+ // NOTES: the default LBO and SBO are for K-major types
201
+ cute::GmmaDescriptor desc;
202
+ const auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
203
+ desc.bitfield.start_address_ = uint_ptr >> 4;
204
+ desc.bitfield.layout_type_ = layout_type;
205
+ desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
206
+ desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
207
+ desc.bitfield.base_offset_ = 0;
208
+ return desc;
209
+ }
210
+
211
/// Returns the inner-dimension atom size in elements: the whole `BLOCK_INNER` when there is no
/// swizzling (`kSwizzleMode == 0`), otherwise `kSwizzleMode` bytes divided by the element size.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if constexpr (kSwizzleMode == 0) {
        return BLOCK_INNER;
    } else {
        return kSwizzleMode / sizeof(dtype_t);
    }
}
215
+
216
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
217
+ CUTLASS_DEVICE
218
+ constexpr uint32_t get_gmma_desc_stride_k() {
219
+ return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
220
+ }
221
+
222
+ // ReSharper disable once CppNotAllPathsReturnValue
223
+ template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, typename dtype_t>
224
+ constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
225
+ DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
226
+ kSwizzleMode == 32 or kSwizzleMode == 64 or
227
+ kSwizzleMode == 128, "Invalid swizzling mode");
228
+
229
+ // Normal cases
230
+ if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
231
+ if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
232
+ if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32;
233
+ if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64;
234
+ if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128;
235
+ }
236
+
237
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
238
+ CUTLASS_DEVICE
239
+ uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
240
+ return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
241
+ }
242
+
243
+ template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
244
+ CUTLASS_DEVICE
245
+ cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
246
+ const uint32_t stride_k = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
247
+ const auto layout_type = to_gmma_layout_type<kMajorMode, kSwizzleMode, dtype_t>();
248
+ constexpr uint32_t num_non_contiguous = 128 / 16;
249
+ if constexpr (kMajorMode == cute::UMMA::Major::K) {
250
+ // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
251
+ DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
252
+
253
+ // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
254
+ // {SBO, LBO} means the byte stride between atoms on {MN, K}
255
+ // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
256
+ const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
257
+ const uint32_t leading_byte_offset = 0;
258
+ return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
259
+ leading_byte_offset, stride_byte_offset);
260
+ } else {
261
+ constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
262
+
263
+ // Must have no in-atom MN-idx
264
+ // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
265
+ DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
266
+ DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
267
+
268
+ // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
269
+ // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
270
+ // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
271
+ // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
272
+ uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
273
+ uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
274
+ if constexpr (kSwizzleMode == 16)
275
+ math::swap(stride_byte_offset, leading_byte_offset);
276
+ return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
277
+ leading_byte_offset, stride_byte_offset);
278
+ }
279
+ }
280
+
281
+ // ReSharper disable once CppNotAllPathsReturnValue
282
+ template <uint32_t kHeadDim>
283
+ static constexpr int to_swizzle_cute_type() {
284
+ DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
285
+ if constexpr (kHeadDim == 32)
286
+ return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
287
+ if constexpr (kHeadDim == 64)
288
+ return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
289
+ if constexpr (kHeadDim == 128)
290
+ return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
291
+ }
292
+
293
+ } // namespace deep_gemm::mma::sm90
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/ld_st.cuh ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda/std/cstdint>
4
+ #include <cuda_bf16.h>
5
+
6
+ namespace deep_gemm::ptx {
7
+
8
+ // Compatibility: 256 bits LD/ST instructions
9
+ #if defined(CUDART_VERSION) and CUDART_VERSION >= 13000
10
+ using longlong4_t = longlong4_32a;
11
+ #define make_longlong4_t make_longlong4_32a
12
+ #else
13
+ struct alignas(32) longlong4_t { long long x, y, z, w; };
14
+ CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t(
15
+ const long long& x, const long long& y, const long long& z, const long long& w) {
16
+ return {x, y, z, w};
17
+ }
18
+ #endif
19
+
20
+ /// LD/ST matrix
21
+ // TODO: remove `struct`
22
+ struct SM90_U32x2_LDSM_N {
23
+ CUTLASS_DEVICE static void
24
+ copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
25
+ asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
26
+ : "=r"(dst_0), "=r"(dst_1)
27
+ : "l"(__cvta_generic_to_shared(smem_src)));
28
+ }
29
+ };
30
+
31
+ struct SM90_U32x4_LDSM_N {
32
+ CUTLASS_DEVICE static void
33
+ copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
34
+ asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
35
+ : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
36
+ : "l"(__cvta_generic_to_shared(smem_src)));
37
+ }
38
+ };
39
+
40
+ template <typename dtype_t>
41
+ struct SM90_U32x2_STSM_N {
42
+ CUTLASS_DEVICE static void
43
+ copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
44
+ DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
45
+ const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
46
+ asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
47
+ :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
48
+ }
49
+ };
50
+
51
+ template <typename dtype_t>
52
+ struct SM90_U32x4_STSM_T {
53
+ CUTLASS_DEVICE static void
54
+ copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
55
+ DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
56
+ const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
57
+ *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
58
+ asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n"
59
+ :: "l"(__cvta_generic_to_shared(smem_dst)),
60
+ "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
61
+ }
62
+ };
63
+
64
+ template <typename dtype_t>
65
+ struct SM100_U8x4_STSM_T {
66
+ __device__ __forceinline__ static void
67
+ copy(dtype_t src_0, void* smem_dst) {
68
+ DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
69
+ const uint32_t src = *reinterpret_cast<uint32_t*>(&src_0);
70
+ asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n"
71
+ :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src));
72
+ }
73
+ };
74
+
75
+ template <typename dtype_t>
76
+ struct SM100_U8x8_STSM_T {
77
+ __device__ __forceinline__ static void
78
+ copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
79
+ DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
80
+ const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
81
+ asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n"
82
+ :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
83
+ }
84
+ };
85
+
86
+ /// Shared memory
87
+ CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) {
88
+ uint32_t ret;
89
+ asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
90
+ return ret;
91
+ }
92
+
93
+ CUTLASS_DEVICE float2 ld_shared(const float2* ptr) {
94
+ float2 ret;
95
+ asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr)));
96
+ return ret;
97
+ }
98
+
99
+ CUTLASS_DEVICE float4 ld_shared(const float4* ptr) {
100
+ float4 ret;
101
+ asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
102
+ return ret;
103
+ }
104
+
105
+ CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) {
106
+ uint4 ret;
107
+ asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
108
+ return ret;
109
+ }
110
+
111
+ CUTLASS_DEVICE float ld_shared(const float* ptr) {
112
+ float ret;
113
+ asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr)));
114
+ return ret;
115
+ }
116
+
117
+ CUTLASS_DEVICE void st_shared(const float* ptr, float val) {
118
+ asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val));
119
+ }
120
+
121
+ CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) {
122
+ asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y));
123
+ }
124
+
125
+ CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) {
126
+ asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val));
127
+ }
128
+
129
+ CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) {
130
+ asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y));
131
+ }
132
+
133
+ CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
134
+ asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
135
+ }
136
+
137
+ CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) {
138
+ asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
139
+ }
140
+
141
+ CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) {
142
+ // `size` must be 64-bit before PTX ISA 9.0
143
+ asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" ::
144
+ "l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast<uint64_t>(num_bytes)));
145
+ }
146
+
147
+ /// Global memory
148
+ CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) {
149
+ uint64_t ret;
150
+ asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
151
+ return ret;
152
+ }
153
+
154
+ CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) {
155
+ uint32_t ret;
156
+ asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
157
+ return ret;
158
+ }
159
+
160
+ CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) {
161
+ uint64_t ret;
162
+ asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
163
+ return ret;
164
+ }
165
+
166
+ CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) {
167
+ asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value));
168
+ }
169
+
170
+ /// Atomics
171
+ CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) {
172
+ uint64_t ret;
173
+ asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value));
174
+ return ret;
175
+ }
176
+
177
+ CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) {
178
+ uint64_t ret;
179
+ asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value));
180
+ return ret;
181
+ }
182
+
183
+ CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) {
184
+ uint32_t ret;
185
+ asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
186
+ return ret;
187
+ }
188
+
189
+ CUTLASS_DEVICE void red_add(const int* ptr, const int& value) {
190
+ asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
191
+ }
192
+
193
+ CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) {
194
+ asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
195
+ }
196
+
197
+ CUTLASS_DEVICE void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) {
198
+ asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value));
199
+ }
200
+
201
+ CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) {
202
+ asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value));
203
+ }
204
+
205
+ CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) {
206
+ asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
207
+ }
208
+
209
+ CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) {
210
+ asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
211
+ }
212
+
213
+ CUTLASS_DEVICE int ld_acq_sys(const int* ptr) {
214
+ int ret;
215
+ asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
216
+ return ret;
217
+ }
218
+
219
+ CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) {
220
+ uint32_t ret;
221
+ asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
222
+ return ret;
223
+ }
224
+
225
+ CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) {
226
+ uint64_t ret;
227
+ asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
228
+ return ret;
229
+ }
230
+
231
+ /// Predicated loads
232
+ CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) {
233
+ longlong4_t ret = make_longlong4_t(0, 0, 0, 0);
234
+ asm volatile(
235
+ "{\n\t"
236
+ " .reg .pred p;\n\t"
237
+ " setp.ge.s32 p, %5, 0;\n\t"
238
+ " @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t"
239
+ "}"
240
+ : "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w)
241
+ : "l"(ptr), "r"(pred)
242
+ : "memory");
243
+ return ret;
244
+ }
245
+
246
+ /// Prefetch
247
+ CUTLASS_DEVICE void prefetch_l1(void *ptr) {
248
+ asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
249
+ }
250
+
251
+ } // namespace deep_gemm::ptx
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tcgen05.cuh ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace deep_gemm::ptx {
4
+
5
+ /// UMMA versions with relaxed assertions
6
+ struct SM100_MMA_F16BF16_SS {
7
+ CUTLASS_DEVICE static void
8
+ fma(uint64_t const& desc_a,
9
+ uint64_t const& desc_b,
10
+ uint32_t const& tmem_c,
11
+ uint32_t const& scale_c,
12
+ uint64_t const& desc) {
13
+ asm volatile(
14
+ "{\n\t"
15
+ ".reg .pred p;\n\t"
16
+ "setp.ne.b32 p, %4, 0;\n\t"
17
+ "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
18
+ "}\n"
19
+ :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
20
+ }
21
+ };
22
+
23
+ struct SM100_MMA_F16BF16_2x1SM_SS {
24
+ CUTLASS_DEVICE static void
25
+ fma(uint64_t const& desc_a,
26
+ uint64_t const& desc_b,
27
+ uint32_t const& tmem_c,
28
+ uint32_t const& scale_c,
29
+ uint64_t const& desc) {
30
+ asm volatile(
31
+ "{\n\t"
32
+ ".reg .pred p;\n\t"
33
+ "setp.ne.b32 p, %4, 0;\n\t"
34
+ "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
35
+ "}\n"
36
+ :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
37
+ }
38
+ };
39
+
40
+ struct SM100_MMA_MXF8F6F4_SS {
41
+ CUTLASS_DEVICE static void
42
+ fma(uint64_t const& desc_a,
43
+ uint64_t const& desc_b,
44
+ uint32_t const& tmem_c,
45
+ uint32_t const& scale_c,
46
+ uint64_t const& desc,
47
+ uint32_t const& tmem_sfa,
48
+ uint32_t const& tmem_sfb) {
49
+ asm volatile(
50
+ "{\n\t"
51
+ ".reg .pred p;\n\t"
52
+ "setp.ne.b32 p, %4, 0;\n\t"
53
+ "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
54
+ "}\n"
55
+ :
56
+ : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
57
+ "r"(tmem_sfa), "r"(tmem_sfb));
58
+ }
59
+ };
60
+
61
+ struct SM100_MMA_MXF8F6F4_2x1SM_SS {
62
+ CUTLASS_DEVICE static void
63
+ fma(uint64_t const& desc_a,
64
+ uint64_t const& desc_b,
65
+ uint32_t const& tmem_c,
66
+ uint32_t const& scale_c,
67
+ uint64_t const& desc,
68
+ uint32_t const& tmem_sfa,
69
+ uint32_t const& tmem_sfb) {
70
+ asm volatile(
71
+ "{\n\t"
72
+ ".reg .pred p;\n\t"
73
+ "setp.ne.b32 p, %4, 0;\n\t"
74
+ "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
75
+ "}\n"
76
+ :
77
+ : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
78
+ "r"(tmem_sfa), "r"(tmem_sfb));
79
+ }
80
+ };
81
+
82
+ struct SM100_MMA_F8F6F4_SS {
83
+ CUTLASS_DEVICE static void
84
+ fma(uint64_t const& desc_a,
85
+ uint64_t const& desc_b,
86
+ uint32_t const& tmem_c,
87
+ uint32_t const& scale_c,
88
+ uint64_t const& desc) {
89
+ asm volatile(
90
+ "{\n\t"
91
+ ".reg .pred p;\n\t"
92
+ "setp.ne.b32 p, %4, 0;\n\t"
93
+ "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t"
94
+ "}\n"
95
+ :
96
+ : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
97
+ }
98
+ };
99
+
100
+ struct SM100_MMA_F8F6F4_2x1SM_SS {
101
+ CUTLASS_DEVICE static void
102
+ fma(uint64_t const& desc_a,
103
+ uint64_t const& desc_b,
104
+ uint32_t const& tmem_c,
105
+ uint32_t const& scale_c,
106
+ uint64_t const& desc) {
107
+ asm volatile(
108
+ "{\n\t"
109
+ ".reg .pred p;\n\t"
110
+ "setp.ne.b32 p, %4, 0;\n\t"
111
+ "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t"
112
+ "}\n"
113
+ :
114
+ : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
115
+ }
116
+ };
117
+
118
+ struct SM100_MMA_MXF4_SS {
119
+ CUTLASS_DEVICE static void
120
+ fma(uint64_t const& desc_a,
121
+ uint64_t const& desc_b,
122
+ uint32_t const& tmem_c,
123
+ uint32_t const& scale_c,
124
+ uint64_t const& desc,
125
+ uint32_t const& tmem_sfa,
126
+ uint32_t const& tmem_sfb) {
127
+ asm volatile(
128
+ "{\n\t"
129
+ ".reg .pred p;\n\t"
130
+ "setp.ne.b32 p, %4, 0;\n\t"
131
+ #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
132
+ "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t"
133
+ #else
134
+ "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
135
+ #endif
136
+ "}\n"
137
+ :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
138
+ "r"(tmem_sfa), "r"(tmem_sfb));
139
+ }
140
+ };
141
+
142
+ struct SM100_MMA_F16BF16_WS_SS {
143
+ CUTLASS_DEVICE static void
144
+ fma(uint64_t const& desc_a,
145
+ uint64_t const& desc_b,
146
+ uint32_t const& tmem_c,
147
+ uint32_t const& scale_c,
148
+ uint64_t const& desc) {
149
+ asm volatile(
150
+ "{\n\t"
151
+ ".reg .pred p;\n\t"
152
+ "setp.ne.b32 p, %4, 0;\n\t"
153
+ "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
154
+ "}\n"
155
+ :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
156
+ }
157
+ };
158
+
159
+ /// Tensor memory operations
160
+ CUTLASS_DEVICE void tcgen05_before_thread_sync() {
161
+ asm volatile("tcgen05.fence::before_thread_sync;");
162
+ }
163
+
164
+ CUTLASS_DEVICE void tcgen05_after_thread_sync() {
165
+ asm volatile("tcgen05.fence::after_thread_sync;");
166
+ }
167
+
168
+ } // namespace deep_gemm::ptx
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/tma.cuh ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cutlass/arch/barrier.h>
4
+ #include <cute/arch/copy_sm90_desc.hpp>
5
+
6
+ namespace deep_gemm::ptx {
7
+
8
+ // Tensor-map instructions
9
+ CUTLASS_DEVICE void tensor_map_release_gpu() {
10
+ asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory");
11
+ }
12
+
13
+ CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) {
14
+ auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
15
+ asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory");
16
+ }
17
+
18
+ CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
19
+ auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
20
+ const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
21
+ asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
22
+ }
23
+
24
+ CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
25
+ auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
26
+ asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
27
+ #if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
28
+ asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
29
+ #else
30
+ DG_STATIC_ASSERT(false, "Invalid CUDA version");
31
+ #endif
32
+ }
33
+
34
+ /// TMA instructions
35
+ CUTLASS_DEVICE void mbarrier_arrive(
36
+ cutlass::arch::ClusterTransactionBarrier* ptr) {
37
+ asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" ::
38
+ "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))));
39
+ }
40
+
41
+ CUTLASS_DEVICE void mbarrier_arrive_and_set_tx(
42
+ cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) {
43
+ asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" ::
44
+ "r"(num_bytes), "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))));
45
+ }
46
+
47
+ CUTLASS_DEVICE void mbarrier_wait_and_flip_phase(
48
+ cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) {
49
+ asm volatile(
50
+ "{\n\t"
51
+ ".reg .pred P1; \n\t"
52
+ "LAB_WAIT: \n\t"
53
+ "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
54
+ "@P1 bra DONE; \n\t"
55
+ "bra LAB_WAIT; \n\t"
56
+ "DONE: \n\t"
57
+ "}" ::
58
+ "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))),
59
+ "r"(phase), "r"(0x989680));
60
+ phase ^= 1;
61
+ }
62
+
63
+ CUTLASS_DEVICE void tma_load_1d(
64
+ const void* dst_ptr, const void* src_ptr,
65
+ cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr,
66
+ const uint32_t& num_bytes,
67
+ const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) {
68
+ // NOTES: normally, the loaded part will be evicted soon
69
+ asm volatile(
70
+ "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" ::
71
+ "r"(static_cast<uint32_t>(__cvta_generic_to_shared(dst_ptr))),
72
+ "l"(src_ptr),
73
+ "r"(num_bytes),
74
+ "r"(static_cast<uint32_t>(__cvta_generic_to_shared(mbarrier_ptr))),
75
+ "l"(hint)
76
+ : "memory");
77
+ }
78
+
79
+ CUTLASS_DEVICE void tma_store_1d(
80
+ const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes,
81
+ const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) {
82
+ // NOTES: normally, the stored part will be used soon
83
+ asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" ::
84
+ "l"(dst_ptr),
85
+ "r"(static_cast<uint32_t>(__cvta_generic_to_shared(src_ptr))),
86
+ "r"(num_bytes),
87
+ "l"(hint)
88
+ : "memory");
89
+ }
90
+
91
+ template <int kNumRemainingWaits = 0>
92
+ __forceinline__ __device__ void tma_store_wait() {
93
+ // NOTES: this function does not have `.read`
94
+ asm volatile("cp.async.bulk.wait_group %0;" ::"n"(kNumRemainingWaits) : "memory");
95
+ }
96
+
97
+ CUTLASS_DEVICE
98
+ void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier,
99
+ void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) {
100
+ const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
101
+ const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
102
+ asm volatile(
103
+ "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
104
+ :
105
+ : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
106
+ "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
107
+ "r"(mbarrier_addr), "l"(cache_hint)
108
+ : "memory"
109
+ );
110
+ }
111
+
112
+ } // namespace deep_gemm::ptx
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/utils.cuh ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cuda/std/cstdint>
4
+ #include <cuda_bf16.h>
5
+
6
+ #include <deep_gemm/common/exception.cuh>
7
+
8
+ namespace deep_gemm::ptx {
9
+
10
+ CUTLASS_DEVICE uint32_t get_sm_idx() {
11
+ uint32_t sm_idx;
12
+ asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
13
+ return sm_idx;
14
+ }
15
+
16
+ CUTLASS_DEVICE uint32_t get_lane_idx() {
17
+ uint32_t lane_id;
18
+ asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
19
+ return lane_id;
20
+ }
21
+
22
+ CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) {
23
+ asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads));
24
+ }
25
+
26
+ CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) {
27
+ asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads));
28
+ }
29
+
30
+ template <typename dtype_t>
31
+ CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) {
32
+ DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, "");
33
+ const auto send_int_values = reinterpret_cast<uint32_t*>(&ptr);
34
+ dtype_t recv_dtype;
35
+ auto recv_int_values = reinterpret_cast<uint32_t*>(&recv_dtype);
36
+ #pragma unroll
37
+ for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i)
38
+ recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast<int>(src_lane_idx));
39
+ return recv_dtype;
40
+ }
41
+
42
+ CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) {
43
+ #if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
44
+ // Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100
45
+ asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast<uint16_t*>(&b.x)));
46
+ asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast<uint16_t*>(&b.y)));
47
+ #else
48
+ const auto [x, y] = __bfloat1622float2(b);
49
+ a.x += x, a.y += y;
50
+ #endif
51
+ }
52
+
53
+ } // namespace deep_gemm::ptx
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/ptx/wgmma.cuh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/exception.cuh>
4
+
5
+ namespace deep_gemm::ptx {
6
+
7
+ CUTLASS_DEVICE void warpgroup_arrive() {
8
+ asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
9
+ }
10
+
11
+ CUTLASS_DEVICE void warpgroup_commit_batch() {
12
+ asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
13
+ }
14
+
15
+ CUTLASS_DEVICE void warpgroup_fence_operand(float& reg) {
16
+ asm volatile("" : "+f"(reg) :: "memory");
17
+ }
18
+
19
+ template <int N>
20
+ CUTLASS_DEVICE void warpgroup_wait() {
21
+ DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
22
+ asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
23
+ }
24
+
25
+ } // namespace deep_gemm::ptx
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/gemm.cuh ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/math.cuh>
4
+ #include <deep_gemm/common/types.cuh>
5
+
6
+ namespace deep_gemm::sched {
7
+
8
/// Selects which group offset `Scheduler::get_global_idx` applies.
enum class IndexType {
    MN = 0,    // offset scaled by the M/N shape dimension
    K = 1,     // offset taken from the accumulated K sizes
    SF_K = 2,  // offset taken from the accumulated scale-factor K sizes
};
13
+
14
+ template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
15
+ static constexpr uint32_t get_num_1d_blocks_per_group() {
16
+ // Select the best from candidates
17
+ uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
18
+ for (const auto candidate: {8u, 16u}) {
19
+ const auto usage = kIsMulticastOnA ?
20
+ candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
21
+ candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
22
+ if (usage < min_usage)
23
+ min_usage = usage, num_best_blocks = candidate;
24
+ }
25
+ return num_best_blocks;
26
+ }
27
+
28
+ #pragma clang diagnostic push
29
+ #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
30
+ template <GemmType kGemmType,
31
+ uint32_t BLOCK_M, uint32_t BLOCK_N,
32
+ uint32_t kNumGroups,
33
+ uint32_t kNumMulticast, bool kIsMulticastOnA,
34
+ uint32_t kNumSMs,
35
+ uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 on SM90 (float SF), gran_k * 4 on SM100 (packed UE8M0 SF)
36
+ uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
37
+ struct Scheduler {
38
+ int current_iter = -1;
39
+
40
+ // Block configs
41
+ uint32_t num_blocks;
42
+ uint32_t num_m_blocks;
43
+ uint32_t num_n_blocks;
44
+
45
+ // For SM90 multicast checks
46
+ uint32_t num_blocks_in_group;
47
+ bool is_peer_cta_alive = true;
48
+
49
+ // For grouped GEMM
50
+ int* grouped_layout;
51
+ uint32_t current_group_idx = 0;
52
+ // Only used for masked layout
53
+ uint32_t current_m_cumsum = 0;
54
+ // Only used for contiguous psum layout
55
+ uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
56
+ // Only used for k-grouped layout
57
+ uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
58
+ uint32_t next_group_idx, next_shape_k;
59
+
60
+ // Only used for k-grouped gemm
61
+ CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
62
+ for (; group_idx < kNumGroups; ++ group_idx) {
63
+ shape_k = grouped_layout[group_idx];
64
+ if (shape_k > 0)
65
+ break;
66
+ }
67
+ }
68
+
69
+ // ReSharper disable once CppPossiblyUninitializedMember
70
+ CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
71
+ const uint32_t& shape_k, int* grouped_layout = nullptr) {
72
+ num_m_blocks = math::ceil_div(shape_m, BLOCK_M);
73
+ num_n_blocks = math::ceil_div(shape_n, BLOCK_N);
74
+ current_shape_k = shape_k;
75
+ if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
76
+ num_blocks = num_m_blocks * num_n_blocks;
77
+ } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
78
+ num_blocks = num_m_blocks * num_n_blocks;
79
+ this->grouped_layout = grouped_layout;
80
+ } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
81
+ this->grouped_layout = grouped_layout;
82
+ } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
83
+ this->grouped_layout = grouped_layout;
84
+ current_psum_m = grouped_layout[0];
85
+ num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M);
86
+ } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
87
+ num_blocks = num_m_blocks * num_n_blocks;
88
+ this->grouped_layout = grouped_layout;
89
+ get_next_k_group(current_group_idx, current_shape_k);
90
+ next_group_idx = current_group_idx + 1;
91
+ get_next_k_group(next_group_idx, next_shape_k);
92
+ }
93
+ }
94
+
95
+ CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
96
+ DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");
97
+
98
+ // Swizzle for better L2 usages
99
+ const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
100
+ const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
101
+ const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
102
+ const auto group_idx = block_idx / num_blocks_per_group;
103
+ auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
104
+ auto in_group_idx = block_idx % num_blocks_per_group;
105
+ num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
106
+
107
+ // Fix unaligned TMA multicast
108
+ // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
109
+ // while SM100 uses 2-CTA, which can not be dynamically disabled
110
+ #if __CUDA_ARCH__ < 1000
111
+ if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
112
+ if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
113
+ num_blocks_in_group = num_blocks_in_group ^ 1;
114
+ } else {
115
+ in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
116
+ first_block_idx += num_blocks_in_group ^ 1;
117
+ num_blocks_in_group = 1;
118
+ }
119
+ }
120
+ #endif
121
+
122
+ // Convert to final M/N block indices
123
+ // `kIsMulticastOnA == true` leads to groups on N
124
+ if constexpr (kIsMulticastOnA) {
125
+ m_block_idx = in_group_idx / num_blocks_in_group;
126
+ n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
127
+ } else {
128
+ m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
129
+ n_block_idx = in_group_idx / num_blocks_in_group;
130
+ }
131
+ }
132
+
133
+ template <bool kWithGroupOffset, IndexType kIndexType = IndexType::MN>
134
+ CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
135
+ const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
136
+ if constexpr (kGemmType == GemmType::Normal) {
137
+ return block_idx * block_size;
138
+ } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
139
+ const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0;
140
+ return offset * shape_dim + block_idx * block_size;
141
+ } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
142
+ const auto offset = kWithGroupOffset ? current_group_idx : 0;
143
+ return offset * shape_dim + block_idx * block_size;
144
+ } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
145
+ auto offset = 0;
146
+ if constexpr (kWithGroupOffset) {
147
+ if constexpr (kIndexType == IndexType::MN)
148
+ offset = current_group_idx * shape_dim;
149
+ else if constexpr (kIndexType == IndexType::K)
150
+ offset = current_k_cumsum;
151
+ else if constexpr (kIndexType == IndexType::SF_K)
152
+ offset = current_sf_k_cumsum;
153
+ }
154
+ return offset + block_idx * block_size;
155
+ } else if constexpr (kGemmType == GemmType::Batched) {
156
+ // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K
157
+ const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
158
+ return offset * shape_dim + block_idx * block_size;
159
+ }
160
+ }
161
+
162
+ // For swap A/B and psum layout only
163
+ CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const {
164
+ constexpr uint32_t UMMA_STEP_N = 16;
165
+ DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment");
166
+ if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout)
167
+ return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N);
168
+ return BLOCK_M;
169
+ }
170
+
171
+ CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
172
+ const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;
173
+
174
+ if constexpr (kGemmType == GemmType::MGroupedMasked) {
175
+ while (true) {
176
+ // End of the task
177
+ if (current_group_idx == kNumGroups)
178
+ return false;
179
+
180
+ // Within current group
181
+ num_m_blocks = math::ceil_div(static_cast<uint32_t>(grouped_layout[current_group_idx]), BLOCK_M);
182
+ const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
183
+ if (next_block_idx < current_m_block_cumsum * num_n_blocks)
184
+ break;
185
+
186
+ // Move to check the next group
187
+ current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
188
+ }
189
+
190
+ get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
191
+ } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
192
+ while (true) {
193
+ // Within current group
194
+ if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
195
+ break;
196
+
197
+ // Move to check the next group
198
+ if (++ current_group_idx == kNumGroups)
199
+ return false;
200
+
201
+ // NOTES: `num_m_blocks` varies with the increase of the group index
202
+ last_psum_m = math::align(current_psum_m, BLOCK_M);
203
+ current_psum_m = grouped_layout[current_group_idx];
204
+ current_m_block_cumsum += num_m_blocks;
205
+ num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M);
206
+ }
207
+
208
+ get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
209
+
210
+ // NOTES: `last_psum_m` is aligned with block M
211
+ m_block_idx += last_psum_m / BLOCK_M;
212
+ } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
213
+ while (true) {
214
+ // End of the task
215
+ if (current_group_idx == kNumGroups)
216
+ return false;
217
+
218
+ // Within current group
219
+ if (next_block_idx < (current_num_valid_groups + 1) * num_blocks)
220
+ break;
221
+
222
+ // Move to check the next group
223
+ current_k_cumsum += current_shape_k;
224
+ current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT);
225
+ current_num_valid_groups ++;
226
+
227
+ current_group_idx = next_group_idx ++;
228
+ current_shape_k = next_shape_k;
229
+ get_next_k_group(next_group_idx, next_shape_k);
230
+ }
231
+
232
+ get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx);
233
+ } else if constexpr (kGemmType == GemmType::Batched) {
234
+ if (next_block_idx >= num_blocks * kNumGroups)
235
+ return false;
236
+
237
+ current_group_idx = next_block_idx / num_blocks;
238
+ const auto block_idx = next_block_idx - current_group_idx * num_blocks;
239
+ if constexpr (kIsMulticastOnA) {
240
+ m_block_idx = block_idx / num_n_blocks;
241
+ n_block_idx = block_idx % num_n_blocks;
242
+ } else {
243
+ m_block_idx = block_idx % num_m_blocks;
244
+ n_block_idx = block_idx / num_m_blocks;
245
+ }
246
+ } else {
247
+ if (next_block_idx >= num_blocks)
248
+ return false;
249
+
250
+ // For SM90 only
251
+ // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
252
+ is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass)
253
+ num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass)
254
+ (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
255
+ get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
256
+ }
257
+ return true;
258
+ }
259
+
260
+ // For SM90 only
261
+ CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
262
+ if (num_blocks_in_group == 1)
263
+ return false;
264
+ if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
265
+ kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
266
+ kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
267
+ return true;
268
+ } else {
269
+ DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
270
+ if constexpr (kIsMulticastOnA) {
271
+ return true;
272
+ } else {
273
+ const auto group_idx = grouped_layout[m_block_idx * BLOCK_M];
274
+ const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M];
275
+ return group_idx == peer_group_idx;
276
+ }
277
+ }
278
+ }
279
+
280
+ // For SM90 only
281
+ // ReSharper disable once CppNotAllPathsReturnValue
282
+ CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
283
+ if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
284
+ return true;
285
+ } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
286
+ return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0;
287
+ } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
288
+ return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx];
289
+ } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
290
+ return m_offset + m_block_idx * BLOCK_M < current_psum_m;
291
+ } else {
292
+ // Unreachable
293
+ DG_TRAP_ONLY_DEVICE_ASSERT(false);
294
+ }
295
+ }
296
+ };
297
+
298
+ #pragma clang diagnostic pop
299
+
300
+ } // namespace deep_gemm::sched
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/mega_moe.cuh ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/cute_tie.cuh>
4
+ #include <deep_gemm/common/math.cuh>
5
+ #include <deep_gemm/common/types.cuh>
6
+ #include <deep_gemm/layout/mega_moe.cuh>
7
+ #include <deep_gemm/ptx/ld_st.cuh>
8
+ #include <deep_gemm/ptx/utils.cuh>
9
+
10
+ namespace deep_gemm::sched {
11
+
12
+ // Computation phase for the current block
13
/// Computation phase of the block handed out by `MegaMoEScheduler::get_next_block`.
enum class BlockPhase {
    None    = 0,  // returned once all waves and experts have been processed
    Linear1 = 1,  // block belongs to the first linear layer (L1 shapes)
    Linear2 = 2   // block belongs to the second linear layer (L2 shapes)
};
18
+
19
+ template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
20
+ uint32_t L1_SHAPE_N, uint32_t L1_SHAPE_K,
21
+ uint32_t L2_SHAPE_N, uint32_t L2_SHAPE_K,
22
+ uint32_t kNumExpertsPerRank,
23
+ uint32_t kNumExpertsPerWave,
24
+ uint32_t kNumSMs, uint32_t kNumRanks,
25
+ uint32_t kNumExpertsPerLane = math::constexpr_ceil_div(kNumExpertsPerRank, 32u),
26
+ uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N,
27
+ uint32_t kNumL2BlockNs = L2_SHAPE_N / BLOCK_N,
28
+ uint32_t kNumL1BlockKs = L1_SHAPE_K / BLOCK_K,
29
+ uint32_t kNumL2BlockKs = L2_SHAPE_K / BLOCK_K>
30
+ struct MegaMoEScheduler {
31
+ DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape");
32
+ DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape");
33
+ DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape");
34
+ DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape");
35
+ DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config");
36
+
37
+ // NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster
38
+ // always land on the same m_block_idx with n_block_idx differing by 1
39
+ DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster");
40
+ DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster");
41
+ DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster");
42
+
43
+ // Arrival counts
44
+ const layout::Workspace& workspace;
45
+
46
+ // Scheduler state
47
+ BlockPhase next_phase = BlockPhase::Linear1;
48
+
49
+ // Current expert and block indices
50
+ uint32_t current_local_expert_idx = 0;
51
+ uint32_t current_num_tokens = 0;
52
+ uint32_t current_pool_block_offset = 0;
53
+ uint32_t block_idx = 0;
54
+ uint32_t m_block_idx = 0;
55
+ uint32_t n_block_idx = 0;
56
+
57
+ // Pre-cached per-expert token counts (filled during `for_each_block` init)
58
+ // Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count
59
+ uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {};
60
+
61
+ CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) {
62
+ block_idx = blockIdx.x;
63
+ }
64
+
65
+ CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const {
66
+ return math::align(current_local_expert_idx + 1, kNumExpertsPerWave);
67
+ }
68
+
69
+ CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const {
70
+ uint32_t valid_value;
71
+ #pragma unroll
72
+ for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
73
+ valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ?
74
+ stored_num_tokens_per_expert[i] : valid_value;
75
+ }
76
+ return ptx::exchange(valid_value, expert_idx % 32);
77
+ }
78
+
79
+ // Get pool block offset for a given expert index from a per-lane token count array
80
+ CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) {
81
+ uint32_t num_blocks = 0;
82
+ #pragma unroll
83
+ for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
84
+ if (i * 32 + ptx::get_lane_idx() < expert_idx)
85
+ num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M);
86
+ }
87
+ return __reduce_add_sync(0xffffffff, num_blocks);
88
+ }
89
+
90
+ CUTLASS_DEVICE void advance_expert_idx() {
91
+ current_pool_block_offset += get_current_num_m_blocks();
92
+ current_local_expert_idx += 1;
93
+ current_num_tokens = get_num_tokens(current_local_expert_idx);
94
+ }
95
+
96
+ CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) {
97
+ current_local_expert_idx = expert_idx;
98
+ current_num_tokens = get_num_tokens(expert_idx);
99
+ current_pool_block_offset = get_pool_block_offset(expert_idx);
100
+ }
101
+
102
+ CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const {
103
+ return current_pool_block_offset;
104
+ }
105
+
106
+ CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const {
107
+ return math::ceil_div(current_num_tokens, BLOCK_M);
108
+ }
109
+
110
+ template <bool kDoUMMAAligned = false>
111
+ CUTLASS_DEVICE uint32_t get_valid_m() const {
112
+ const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M);
113
+ return kDoUMMAAligned ? math::align(m, 16u) : m;
114
+ }
115
+
116
+ CUTLASS_DEVICE bool fetch_next_l1_block() {
117
+ const auto wave_end_expert_idx = get_wave_expert_end_idx();
118
+ while (current_local_expert_idx < wave_end_expert_idx) {
119
+ const auto num_m_blocks = get_current_num_m_blocks();
120
+ m_block_idx = block_idx / kNumL1BlockNs;
121
+ if (m_block_idx < num_m_blocks)
122
+ return true;
123
+
124
+ // Current expert is fully assigned, move to the next
125
+ block_idx -= num_m_blocks * kNumL1BlockNs;
126
+ advance_expert_idx();
127
+ }
128
+ return false;
129
+ }
130
+
131
+ CUTLASS_DEVICE bool fetch_next_l2_block() {
132
+ const auto wave_end_expert_idx = get_wave_expert_end_idx();
133
+ while (current_local_expert_idx < wave_end_expert_idx) {
134
+ const auto num_m_blocks = get_current_num_m_blocks();
135
+ if (block_idx < num_m_blocks * kNumL2BlockNs) {
136
+ m_block_idx = block_idx / kNumL2BlockNs;
137
+ return true;
138
+ }
139
+
140
+ // Current expert is fully assigned, move to the next
141
+ block_idx -= num_m_blocks * kNumL2BlockNs;
142
+ advance_expert_idx();
143
+ }
144
+ return false;
145
+ }
146
+
147
+ // Core state machine: assigns the next block
148
+ CUTLASS_DEVICE cute::tuple<BlockPhase, uint32_t, uint32_t, uint32_t> get_next_block() {
149
+ while (true) {
150
+ if (current_local_expert_idx >= kNumExpertsPerRank)
151
+ break;
152
+
153
+ if (next_phase == BlockPhase::Linear1) {
154
+ if (fetch_next_l1_block()) {
155
+ // Found a new L1 block
156
+ n_block_idx = block_idx - m_block_idx * kNumL1BlockNs;
157
+ // Jump to next block
158
+ block_idx += kNumSMs;
159
+ return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx};
160
+ } else {
161
+ // L1 for the current wave is complete, transition to L2
162
+ next_phase = BlockPhase::Linear2;
163
+ set_expert_idx(math::align<uint32_t, false>(current_local_expert_idx - 1, kNumExpertsPerWave));
164
+ }
165
+ } else {
166
+ if (fetch_next_l2_block()) {
167
+ // Found a new L2 block
168
+ n_block_idx = block_idx - m_block_idx * kNumL2BlockNs;
169
+ // Jump to next block
170
+ block_idx += kNumSMs;
171
+ return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx};
172
+ } else {
173
+ // Move to L1 of the next wave
174
+ next_phase = BlockPhase::Linear1;
175
+ }
176
+ }
177
+ }
178
+
179
+ // All waves and experts are fully processed
180
+ return {BlockPhase::None, 0, 0, 0};
181
+ }
182
+
183
+ CUTLASS_DEVICE void fetch_expert_recv_count() {
184
+ // NOTES: each lane caches experts at indices (i * 32 + lane_idx)
185
+ #pragma unroll
186
+ for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
187
+ const auto expert_idx = i * 32 + ptx::get_lane_idx();
188
+ uint64_t value = 0;
189
+ if (expert_idx < kNumExpertsPerRank) {
190
+ do {
191
+ value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx));
192
+ } while (static_cast<uint32_t>(value >> 32) != kNumSMs * kNumRanks);
193
+ }
194
+ stored_num_tokens_per_expert[i] = static_cast<uint32_t>(value);
195
+ }
196
+ __syncwarp();
197
+ }
198
+
199
+ template <typename Func>
200
+ CUTLASS_DEVICE void for_each_block(Func&& func) {
201
+ // Wait for all expert counters to be finalized
202
+ fetch_expert_recv_count();
203
+
204
+ // Initialize current expert with 0
205
+ set_expert_idx(0);
206
+
207
+ // Iterate over all blocks
208
+ // TODO: add swizzle within expert waves for better L2 cache utilization
209
+ while (true) {
210
+ CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx);
211
+ if (block_phase == BlockPhase::None)
212
+ break;
213
+
214
+ func(block_phase, current_local_expert_idx,
215
+ block_phase == BlockPhase::Linear2 ? kNumL2BlockKs : kNumL1BlockKs,
216
+ m_block_idx, n_block_idx);
217
+ }
218
+ }
219
+ };
220
+
221
+ } // namespace deep_gemm::sched
build/torch210-cxx11-cu128-x86_64-linux/include/deep_gemm/scheduler/paged_mqa_logits.cuh ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <deep_gemm/common/math.cuh>
4
+ #include <deep_gemm/common/types.cuh>
5
+ #include <deep_gemm/ptx/utils.cuh>
6
+
7
+ namespace deep_gemm::sched {
8
+ // Single-warp kernel that splits the KV segments of all queries evenly across kNumSMs SMs.
+ // For every sm_idx in [0, kNumSMs] it writes the pair (q_atom_idx, kv_split_idx) at which that
+ // SM's share begins to schedule_metadata[2 * sm_idx] and schedule_metadata[2 * sm_idx + 1];
+ // entry kNumSMs is the end sentinel, so schedule_metadata must hold 2 * (kNumSMs + 1) values.
+ // One segment covers SPLIT_KV context positions. With kIsVarlen, two adjacent tokens with the same
+ // `indices` value are merged into one atom that uses the context length of the second token.
+ // NOTE(review): batch_size (and the atom count) must not exceed kAlignedBatchSize; this is not checked here.
9
+ template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs, bool kIsVarlen = false>
10
+ CUTLASS_GLOBAL __launch_bounds__(32, 1)
11
+ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
12
+ const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) {
13
+ DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
14
+ const uint32_t lane_idx = ptx::get_lane_idx();
15
+
16
+ // Wait for primary kernel completion
17
+ cudaGridDependencySynchronize();
18
+
19
+ __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
20
+ __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
21
+ __shared__ uint32_t varlen_num_atoms_shared;
22
+ uint32_t num_items; // Number of schedulable items: atoms (varlen) or queries (non-varlen)
23
+
24
+ if constexpr (kIsVarlen) {
25
+ if (lane_idx == 0) { // Atom construction is sequential, so a single lane performs it
26
+ uint32_t t = 0, atom_count = 0;
27
+ while (t < batch_size) {
28
+ varlen_atom_token_start[atom_count] = t;
29
+ const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]);
30
+ varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t];
31
+ t += is_paired ? 2 : 1;
32
+ ++ atom_count;
33
+ }
34
+ varlen_num_atoms_shared = atom_count;
35
+ }
36
+ __syncwarp(); // Publish lane 0's shared-memory writes to the rest of the warp
37
+ num_items = varlen_num_atoms_shared;
38
+ } else {
39
+ num_items = batch_size;
40
+ }
41
+
42
+ // Compute num_segs and prefix sum
43
+ uint32_t num_segs[kAlignedBatchSize / 32];
44
+ #pragma unroll
45
+ for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
46
+ const uint32_t q_idx = k * 32 + lane_idx;
47
+ uint32_t context_len;
48
+ if constexpr (kIsVarlen) {
49
+ context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0);
50
+ } else {
51
+ const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); // 2D layout: use the last of the next_n entries of this query
52
+ context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
53
+ }
54
+ num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
55
+ }
56
+
57
+ __shared__ uint32_t prefix_sum[kAlignedBatchSize]; // Inclusive prefix sum of num_segs over all items
58
+ uint32_t sum = 0;
59
+ #pragma unroll
60
+ for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
61
+ uint32_t x = num_segs[k];
62
+ #pragma unroll
63
+ for (uint32_t offset = 1; offset < 32; offset <<= 1) { // Inclusive warp scan via shuffles
64
+ const uint32_t y = __shfl_up_sync(0xffffffff, x, offset);
65
+ x += (lane_idx >= offset ? y : 0);
66
+ }
67
+ x += sum; // Add the running total of all previous 32-item chunks
68
+ prefix_sum[k * 32 + lane_idx] = x;
69
+ sum = __shfl_sync(0xffffffff, x, 31);
70
+ }
71
+ __syncwarp(); // Make every lane's prefix_sum[] store visible before other lanes read it below; the shuffles above give no memory-ordering guarantee
72
+ // SM work distribution
73
+ if constexpr (kIsVarlen) {
74
+ const uint32_t total = sum;
75
+ const uint32_t q = total / kNumSMs, r = total % kNumSMs; // Each SM gets q segments; the first r SMs get one more
76
+ for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
77
+ uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
78
+ uint32_t lo = 0, hi = num_items;
79
+ while (lo < hi) { // Binary search: first atom whose inclusive prefix sum exceeds seg_starts
80
+ const uint32_t mid = (lo + hi) / 2;
81
+ const bool pred = prefix_sum[mid] <= seg_starts;
82
+ lo = pred ? mid + 1 : lo;
83
+ hi = pred ? hi : mid;
84
+ }
85
+ const uint32_t atom_idx = lo;
86
+ const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]);
87
+ const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size); // Past-the-end atoms map to batch_size
88
+ __syncwarp();
89
+
90
+ schedule_metadata[sm_idx * 2] = q_atom_idx;
91
+ schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
92
+ }
93
+ } else {
94
+ const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1;
95
+ const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom);
96
+ const uint32_t total = sum * num_next_n_atoms; // Every query contributes its segments once per atom
97
+ const uint32_t q = total / kNumSMs, r = total % kNumSMs;
98
+ for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
99
+ uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
100
+ uint32_t lo = 0, hi = batch_size;
101
+ while (lo < hi) { // Binary search: first query whose scaled inclusive prefix sum exceeds seg_starts
102
+ const uint32_t mid = (lo + hi) / 2;
103
+ const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts;
104
+ lo = pred ? mid + 1 : lo;
105
+ hi = pred ? hi : mid;
106
+ }
107
+ const uint32_t q_idx = lo;
108
+ const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
109
+ const uint32_t num_segs_q = (q_idx >= batch_size ? 0u : (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1])); // The end sentinel (q_idx == batch_size) owns no segments; never read prefix_sum[batch_size], which is out of bounds when batch_size == kAlignedBatchSize
110
+ const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
111
+ const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
112
+ const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
113
+ __syncwarp();
114
+
115
+ schedule_metadata[sm_idx * 2] = q_atom_idx;
116
+ schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
117
+ }
118
+ }
119
+ }
120
+
121
+ // Conditional storage for varlen indices pointer (EBO: zero cost when unused). Used as a base class selected by a bool: only the primary template carries the pointer member.
122
+ template <bool kHasIndices>
123
+ struct IndicesStorage {
124
+ const uint32_t* indices; // Non-owning pointer; the struct never allocates or frees it
125
+ };
126
+
127
+ template <>
128
+ struct IndicesStorage<false> {}; // Empty specialization: no `indices` member exists in this case
129
+ // Per-SM task iterator: walks (q_atom_idx, kv_idx) positions from the schedule entry of this SM up to the entry of the next SM.
130
+ template <uint32_t kNextN, bool kIsContextLens2D, bool kIsVarlen,
131
+ uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit,
132
+ uint32_t kNumNextNAtoms>
133
+ struct PagedMQALogitsScheduler : IndicesStorage<kIsVarlen> {
134
+ const uint32_t* context_lens;
135
+ uint32_t batch_size;
136
+
137
+ uint32_t current_q_atom_idx, current_kv_idx; // Position of the next task to hand out
138
+ uint32_t end_q_atom_idx, end_kv_idx; // Exclusive end position, i.e. the start of the next SM
139
+ uint32_t current_num_kv; // Cached get_num_kv() result for the current atom
140
+
141
+ CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) { // Converts an atom index into a token index
142
+ if constexpr (kIsVarlen) {
143
+ return q_atom_idx;
144
+ } else {
145
+ static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
146
+ static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
147
+ if constexpr (kPadOddN) {
148
+ return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom; // Odd kNextN >= 3: index by (query, atom within query) with kNextN tokens per query
149
+ } else {
150
+ return q_atom_idx * kNextNAtom;
151
+ }
152
+ }
153
+ }
154
+
155
+ CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) { // Varlen: the atom index itself; otherwise the query index q_atom_idx / kNumNextNAtoms
156
+ if constexpr (kIsVarlen) {
157
+ return q_atom_idx;
158
+ } else {
159
+ return q_atom_idx / kNumNextNAtoms;
160
+ }
161
+ }
162
+
163
+ CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { // Returns ceil(context length / BLOCK_KV) for the given atom
164
+ if constexpr (kIsVarlen) {
165
+ const bool is_paired = (q_atom_idx + 1 < batch_size and
166
+ this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]);
167
+ const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx]; // A paired atom uses the context length of its second token
168
+ return math::ceil_div(ctx_len, BLOCK_KV);
169
+ } else {
170
+ const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
171
+ const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); // 2D layout: use the last of the kNextN entries of this query
172
+ return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
173
+ }
174
+ }
175
+
176
+ CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
177
+ const uint32_t* context_lens,
178
+ const uint32_t* schedule_meta, const uint32_t* indices) {
179
+ this->context_lens = context_lens;
180
+ this->batch_size = batch_size;
181
+ if constexpr (kIsVarlen) {
182
+ this->indices = indices; // The member only exists in the varlen instantiation
183
+ }
184
+
185
+ const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx]; // schedule_meta is read as (x, y) pairs: x = q atom index, y = KV split index
186
+ const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1]; // Reads entry sm_idx + 1, so the table needs one entry more than the largest sm_idx used
187
+ current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; // Convert the split index into KV-block units
188
+ end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
189
+
190
+ current_num_kv = get_num_kv(current_q_atom_idx); // NOTE(review): evaluated even when the start atom is past the last one; presumably context_lens is padded or the metadata prevents this - confirm
191
+ }
192
+
193
+ // Advance step in q_atom_idx space when moving to the next atom.
194
+ // Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence.
195
+ // Non-varlen: always 1 (one atom unit).
196
+ CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
197
+ if constexpr (kIsVarlen) {
198
+ return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1; // `bound` keeps the look-ahead read of indices below the given limit
199
+ } else {
200
+ return 1;
201
+ }
202
+ }
203
+
204
+ // Whether num_kv should be refreshed after advancing to q_atom_idx.
205
+ // Varlen: always refresh (each atom may have a different context_len).
206
+ // Non-varlen: only at atom-group boundaries (atoms within a group share context_len).
207
+ CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const {
208
+ if constexpr (kIsVarlen) {
209
+ return true;
210
+ } else {
211
+ return q_atom_idx % kNumNextNAtoms == 0;
212
+ }
213
+ }
214
+
215
+ CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) { // Outputs the current task and advances; returns false at the end position (the outputs are still written then)
216
+ q_atom_idx = current_q_atom_idx;
217
+ kv_idx = current_kv_idx;
218
+ num_kv = current_num_kv;
219
+
220
+ if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx)
221
+ return false;
222
+
223
+ current_kv_idx += kNumBlocksPerSplit;
224
+ if (current_kv_idx >= current_num_kv) { // All KV blocks of this atom are consumed: move to the next atom
225
+ current_kv_idx = 0;
226
+ current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx);
227
+ if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) { // Skip the context_lens read when the new atom is outside this SM's range
228
+ current_num_kv = get_num_kv(current_q_atom_idx);
229
+ }
230
+ }
231
+ return true;
232
+ }
233
+
234
+ CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const { // True if this SM still has work on the atom: it lies before the end atom, or it is the end atom and end_kv_idx > 0
235
+ return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx);
236
+ }
237
+ };
238
+
239
+ } // namespace deep_gemm::sched
build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp DELETED
@@ -1,904 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cutlass/cutlass.h"
34
- #include "cutlass/workspace.h"
35
- #include "cutlass/fast_math.h"
36
- #include "cutlass/kernel_hardware_info.hpp"
37
- #include "cute/arch/cluster_sm90.hpp"
38
- #include "cutlass/arch/arch.h"
39
- #include "cutlass/arch/reg_reconfig.h"
40
- #include "cutlass/arch/mma_sm90.h"
41
- #include "cutlass/epilogue/collective/detail.hpp"
42
- #include "cutlass/gemm/gemm.h"
43
- #include "cutlass/gemm/dispatch_policy.hpp"
44
- #include "cutlass/gemm/kernel/tile_scheduler.hpp"
45
- #include "cutlass/pipeline/pipeline.hpp"
46
- #include "cute/tensor.hpp"
47
- #include "cutlass/trace.h"
48
- #include "cutlass/gemm/kernel/gemm_universal_decl.h"
49
- #include "cutlass/arch/grid_dependency_control.h"
50
-
51
- ///////////////////////////////////////////////////////////////////////////////
52
-
53
- namespace cutlass::gemm::kernel {
54
-
55
- ///////////////////////////////////////////////////////////////////////////////
56
-
57
- template <
58
- class ProblemShape_,
59
- class CollectiveMainloop_,
60
- class CollectiveEpilogue_,
61
- class TileSchedulerTag_
62
- >
63
- class GemmUniversal<
64
- ProblemShape_,
65
- CollectiveMainloop_,
66
- CollectiveEpilogue_,
67
- TileSchedulerTag_,
68
- cute::enable_if_t<
69
- cutlass::detail::is_asymmetric_dma_kernel_tag_of_v<typename CollectiveMainloop_::DispatchPolicy::Schedule,
70
- KernelTmaWarpSpecializedCooperativeSparseSm120> ||
71
- cutlass::detail::is_asymmetric_dma_kernel_tag_of_v<typename CollectiveMainloop_::DispatchPolicy::Schedule,
72
- KernelTmaWarpSpecializedCooperativeSparseBlockScaledSm120>>>
73
- {
74
- public:
75
- //
76
- // Type Aliases
77
- //
78
- using ProblemShape = ProblemShape_;
79
- static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
80
- "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
81
-
82
- // Mainloop derived types
83
- using CollectiveMainloop = CollectiveMainloop_;
84
- using TileShape = typename CollectiveMainloop::TileShape;
85
- using TiledMma = typename CollectiveMainloop::TiledMma;
86
- using ArchTag = typename CollectiveMainloop::ArchTag;
87
- using ElementA = typename CollectiveMainloop::ElementA;
88
- using StrideA = typename CollectiveMainloop::StrideA;
89
- using ElementB = typename CollectiveMainloop::ElementB;
90
- using StrideB = typename CollectiveMainloop::StrideB;
91
- using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
92
- using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
93
- using ClusterShape = typename DispatchPolicy::ClusterShape;
94
- using MainloopArguments = typename CollectiveMainloop::Arguments;
95
- using MainloopParams = typename CollectiveMainloop::Params;
96
-
97
- // Epilogue derived types
98
- using CollectiveEpilogue = CollectiveEpilogue_;
99
- using ElementC = typename CollectiveEpilogue::ElementC;
100
- using StrideC = typename CollectiveEpilogue::StrideC;
101
- using ElementD = typename CollectiveEpilogue::ElementD;
102
- using StrideD = typename CollectiveEpilogue::StrideD;
103
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
104
- using EpilogueParams = typename CollectiveEpilogue::Params;
105
-
106
- static_assert(ArchTag::kMinComputeCapability >= 90);
107
- static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount;
108
-
109
- using TileSchedulerTag = TileSchedulerTag_;
110
- using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
111
- using TileScheduler = typename detail::TileSchedulerSelector<
112
- TileSchedulerTag, ArchTag, TileShape, ClusterShape
113
- ,TileSchedulerPipelineStageCount
114
- >::Scheduler;
115
-
116
- using TileSchedulerArguments = typename TileScheduler::Arguments;
117
- using TileSchedulerParams = typename TileScheduler::Params;
118
-
119
- // Asymmetric buffering
120
- // Tensor A/B could have different buffering, with number of KBLOCK, aka TILEK,
121
- // and STAGEs. It let AsymmetricKRatio, equals KBLOCK_A / KBLOCK_B, to control
122
- // the balance of A/B loading, make sure A/B's pipeline keep same cadence
123
- // when produce / consume data.
124
- // Currently, AsymmetricKRatio = {1, 2} is the only support.
125
- static constexpr bool isAsymmetric = DispatchPolicy::Schedule::isAsymmetric;
126
- static constexpr uint32_t AsymmetricKRatio = isAsymmetric ? 2 : 1;
127
-
128
- // Warp specialization thread count per threadblock
129
- static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp
130
- static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 8 warps
131
- static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp * 2; // 2 warp
132
- static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C
133
-
134
- static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent;
135
- static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
136
-
137
- static constexpr uint32_t NumLoadWarpGroups = 1;
138
- static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup;
139
- static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
140
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
141
- static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups;
142
-
143
- /// Register requirement for Load and Math WGs
144
- static constexpr uint32_t LoadRegisterRequirement = 40;
145
- static constexpr uint32_t MmaRegisterRequirement = 232;
146
-
147
- // 1 stage ordered sequence between mainloop and epilogue producer load threads
148
- using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
149
-
150
- using TileSchedulerPipeline = typename TileScheduler::Pipeline;
151
- using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState;
152
- using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline;
153
- using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState;
154
- using TileSchedulerStorage = typename TileScheduler::SharedStorage;
155
-
156
- // Kernel level shared memory storage
157
- struct SharedStorage {
158
- struct PipelineStorage : cute::aligned_struct<16, _1> {
159
- using MainloopPipelineStorageMK = typename CollectiveMainloop::PipelineStorageMK;
160
- using MainloopPipelineStorageNK = typename CollectiveMainloop::PipelineStorageNK;
161
- using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
162
-
163
- alignas(16) MainloopPipelineStorageMK mainloop_mk;
164
- alignas(16) MainloopPipelineStorageNK mainloop_nk;
165
- alignas(16) EpiLoadPipelineStorage epi_load;
166
- alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
167
- } pipelines;
168
-
169
- alignas(16) TileSchedulerStorage scheduler;
170
-
171
- struct TensorStorage : cute::aligned_struct<128, _1> {
172
- using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
173
- using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
174
-
175
- EpilogueTensorStorage epilogue;
176
- MainloopTensorStorage mainloop;
177
- } tensors;
178
- };
179
-
180
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
181
- static_assert(SharedStorageSize <= cutlass::arch::sm120_smem_capacity_bytes, "SMEM usage exceeded capacity.");
182
-
183
- // Device side arguments
184
- struct Arguments {
185
- GemmUniversalMode mode{};
186
- ProblemShape problem_shape{};
187
- MainloopArguments mainloop{};
188
- EpilogueArguments epilogue{};
189
- KernelHardwareInfo hw_info{};
190
- TileSchedulerArguments scheduler{};
191
- };
192
-
193
- // Kernel entry point API
194
- struct Params {
195
- GemmUniversalMode mode{};
196
- ProblemShape problem_shape{};
197
- MainloopParams mainloop{};
198
- EpilogueParams epilogue{};
199
- KernelHardwareInfo hw_info{};
200
- TileSchedulerParams scheduler{};
201
- void* workspace{nullptr};
202
- };
203
-
204
- //
205
- // Methods
206
- //
207
-
208
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
209
- static
210
- Params
211
- to_underlying_arguments(Arguments const& args, void* workspace) {
212
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
213
-
214
- auto problem_shape = args.problem_shape;
215
- if constexpr (detail::Has_SwapAB_v<CollectiveMainloop>) {
216
- // swap M/N
217
- get<0>(problem_shape) = get<1>(args.problem_shape);
218
- get<1>(problem_shape) = get<0>(args.problem_shape);
219
- }
220
- auto problem_shape_MNKL = append<4>(problem_shape, 1);
221
-
222
- // Get SM count if needed, otherwise use user supplied SM count
223
- int sm_count = args.hw_info.sm_count;
224
- if (sm_count <= 0) {
225
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
226
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
227
- sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
228
- }
229
-
230
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
231
-
232
- KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
233
-
234
- // Calculate workspace pointers
235
- uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
236
- size_t workspace_offset = 0;
237
-
238
- void* epilogue_workspace = workspace_ptr + workspace_offset;
239
- workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
240
- workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
241
-
242
- void* mainloop_workspace = nullptr;
243
-
244
- void* scheduler_workspace = workspace_ptr + workspace_offset;
245
- workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
246
- args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
247
- workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
248
-
249
- // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used
250
- // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means
251
- // subtile will not be used, therefore separate reduction will not be enabled.
252
- constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
253
- TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(
254
- problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles
255
- );
256
-
257
- return {
258
- args.mode,
259
- problem_shape,
260
- CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
261
- CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace),
262
- hw_info,
263
- scheduler,
264
- workspace
265
- };
266
- }
267
-
268
- static bool
269
- can_implement(Arguments const& args) {
270
- bool implementable = (args.mode == GemmUniversalMode::kGemm) or
271
- (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
272
- if (!implementable) {
273
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
274
- return implementable;
275
- }
276
- implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
277
- implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
278
- implementable &= TileScheduler::can_implement(args.scheduler);
279
- return implementable;
280
- }
281
-
282
- static size_t
283
- get_workspace_size(Arguments const& args) {
284
- size_t workspace_size = 0;
285
- constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
286
-
287
- workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
288
- workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
289
-
290
- workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
291
- args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
292
- workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
293
-
294
- return workspace_size;
295
- }
296
-
297
- static cutlass::Status
298
- initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
299
- CudaHostAdapter* cuda_adapter = nullptr) {
300
- Status status = Status::kSuccess;
301
- uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
302
- size_t workspace_offset = 0;
303
- constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
304
- static constexpr uint32_t NumAccumulatorMtxs = 1;
305
-
306
- status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
307
- workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
308
- workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
309
- if (status != Status::kSuccess) {
310
- return status;
311
- }
312
-
313
- status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(
314
- args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter);
315
- workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
316
- args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
317
- workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
318
- if (status != Status::kSuccess) {
319
- return status;
320
- }
321
-
322
- return status;
323
- }
324
-
325
- // Computes the kernel launch grid shape based on runtime parameters
326
- static dim3
327
- get_grid_shape(Params const& params) {
328
- // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently
329
- TileSchedulerArguments args{};
330
- if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
331
- args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
332
- }
333
- args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM;
334
- return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args);
335
- }
336
-
337
- static dim3
338
- get_block_shape() {
339
- return dim3(MaxThreadsPerBlock, 1, 1);
340
- }
341
-
342
- CUTLASS_DEVICE
343
- void
344
- operator()(Params const& params, char* smem_buf) {
345
- using namespace cute;
346
- using X = Underscore;
347
-
348
- // Preconditions
349
- static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
350
- static_assert(size<0>(TileShape{}) >= 128,
351
- "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
352
-
353
- static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
354
- static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
355
- static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
356
- static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
357
-
358
- /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
359
- enum class WarpGroupRole {
360
- Producer = 0,
361
- Consumer0 = 1,
362
- Consumer1 = 2
363
- };
364
- enum class ProducerWarpRole {
365
- LoadMK = 0,
366
- Warp1 = 1,
367
- LoadNK = 2,
368
- LoadMN = 3
369
- };
370
-
371
- // Kernel level shared memory storage
372
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
373
-
374
- int thread_idx = int(threadIdx.x);
375
- int lane_idx = canonical_lane_idx();
376
- int warp_idx = canonical_warp_idx_sync();
377
- int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
378
- int mma_thread_idx = thread_idx % NumMMAThreads;
379
- auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
380
- auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
381
- int lane_predicate = cute::elect_one_sync();
382
- uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
383
-
384
- // Issue Tma Descriptor Prefetch from a single thread
385
- if ((warp_idx == 0) && lane_predicate) {
386
- CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
387
- CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
388
- }
389
-
390
- CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
391
- bool is_epi_load_needed = collective_epilogue.is_producer_load_needed();
392
- // TileScheduler pipeline
393
- typename TileSchedulerPipeline::Params scheduler_pipeline_params;
394
- typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params;
395
- if constexpr (IsSchedDynamicPersistent) {
396
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp1) {
397
- scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::ProducerConsumer;
398
- }
399
- else {
400
- scheduler_pipeline_params.role = TileSchedulerPipeline::ThreadCategory::Consumer;
401
- }
402
- scheduler_pipeline_params.producer_blockid = 0;
403
- scheduler_pipeline_params.producer_arv_count = 1;
404
- scheduler_pipeline_params.consumer_arv_count = NumSchedThreads + (NumMainloopLoadThreads + NumMMAThreads);
405
-
406
- if (is_epi_load_needed) {
407
- scheduler_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads;
408
- }
409
- scheduler_pipeline_params.transaction_bytes = sizeof(typename TileScheduler::CLCResponse);
410
-
411
- scheduler_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads;
412
- scheduler_throttle_pipeline_params.consumer_arv_count = NumSchedThreads;
413
- scheduler_throttle_pipeline_params.dst_blockid = 0;
414
- scheduler_throttle_pipeline_params.initializing_warp = 1;
415
- if (warp_group_role == WarpGroupRole::Producer &&
416
- producer_warp_role == ProducerWarpRole::Warp1) {
417
- scheduler_throttle_pipeline_params.role =
418
- TileSchedulerThrottlePipeline::ThreadCategory::Consumer;
419
- }
420
- // set role when it is for DMA warp in Mainloop
421
- else if (warp_group_role == WarpGroupRole::Producer &&
422
- (producer_warp_role == ProducerWarpRole::LoadMK ||
423
- producer_warp_role == ProducerWarpRole::LoadNK)) {
424
- scheduler_throttle_pipeline_params.role =
425
- TileSchedulerThrottlePipeline::ThreadCategory::Producer;
426
- }
427
- }
428
- TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params, ClusterShape{});
429
- TileSchedulerPipelineState scheduler_pipe_consumer_state;
430
-
431
- TileSchedulerThrottlePipeline scheduler_throttle_pipeline(shared_storage.scheduler.throttle_pipeline(), scheduler_throttle_pipeline_params);
432
- TileSchedulerThrottlePipelineState scheduler_pipe_throttle_consumer_state;
433
- TileSchedulerThrottlePipelineState scheduler_pipe_throttle_producer_state = cutlass::make_producer_start_state<TileSchedulerThrottlePipeline>();
434
-
435
- // Mainloop Load pipeline
436
- using MainloopPipelineMK = typename CollectiveMainloop::MainloopPipelineMK;
437
- using MainloopPipelineNK = typename CollectiveMainloop::MainloopPipelineNK;
438
- typename MainloopPipelineMK::Params mainloop_pipeline_params_mk;
439
- typename MainloopPipelineNK::Params mainloop_pipeline_params_nk;
440
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadMK) {
441
- mainloop_pipeline_params_mk.role = MainloopPipelineMK::ThreadCategory::Producer;
442
- mainloop_pipeline_params_mk.is_leader = cute::elect_one_sync();
443
- mainloop_pipeline_params_mk.transaction_bytes = params.mainloop.tma_transaction_bytes_mk;
444
- }
445
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadNK) {
446
- mainloop_pipeline_params_nk.role = MainloopPipelineNK::ThreadCategory::Producer;
447
- mainloop_pipeline_params_nk.is_leader = cute::elect_one_sync();
448
- mainloop_pipeline_params_nk.transaction_bytes = params.mainloop.tma_transaction_bytes_nk;
449
- }
450
- if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
451
- mainloop_pipeline_params_mk.role = MainloopPipelineMK::ThreadCategory::Consumer;
452
- mainloop_pipeline_params_nk.role = MainloopPipelineNK::ThreadCategory::Consumer;
453
- }
454
- mainloop_pipeline_params_mk.num_consumers = NumMMAThreads;
455
- mainloop_pipeline_params_nk.num_consumers = NumMMAThreads;
456
-
457
- MainloopPipelineMK mainloop_pipeline_mk(shared_storage.pipelines.mainloop_mk, mainloop_pipeline_params_mk, ClusterShape{});
458
- MainloopPipelineNK mainloop_pipeline_nk(shared_storage.pipelines.mainloop_nk, mainloop_pipeline_params_nk, ClusterShape{});
459
-
460
- // Epilogue Load pipeline
461
- using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
462
- typename EpiLoadPipeline::Params epi_load_pipeline_params;
463
- if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadMN) {
464
- epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
465
- }
466
- if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
467
- epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
468
- }
469
- epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
470
- epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads;
471
- epi_load_pipeline_params.consumer_arv_count = NumMMAThreads;
472
- if constexpr (CollectiveEpilogue::RequiresTransactionBytes) {
473
- epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes;
474
- }
475
- EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params);
476
-
477
- // Epilogue Store pipeline
478
- using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
479
- typename EpiStorePipeline::Params epi_store_pipeline_params;
480
- epi_store_pipeline_params.always_wait = true;
481
- EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
482
-
483
- typename LoadWarpOrderBarrier::Params params_load_order_barrier;
484
- // 2 warps (LoadMK / LoadNK) are ordered before 1 warp (LoadMN) and will signal arrival.
485
- params_load_order_barrier.group_id = (
486
- producer_warp_role == ProducerWarpRole::LoadMK ||
487
- producer_warp_role == ProducerWarpRole::LoadNK) ? 0 : 1;
488
- params_load_order_barrier.group_size = NumThreadsPerWarp * 2;
489
- LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
490
-
491
- // Initialize starting pipeline states for the collectives
492
- // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
493
- typename CollectiveMainloop::PipelineStateMK mainloop_pipe_consumer_state_mk;
494
- typename CollectiveMainloop::PipelineStateNK mainloop_pipe_consumer_state_nk;
495
- typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
496
-
497
- // For the DMA Load (producer) we start with an opposite phase
498
- // i.e., we skip all waits since we know that the buffer is indeed empty
499
- typename CollectiveMainloop::PipelineStateMK mainloop_pipe_producer_state_mk = cutlass::make_producer_start_state<MainloopPipelineMK>();
500
- typename CollectiveMainloop::PipelineStateNK mainloop_pipe_producer_state_nk = cutlass::make_producer_start_state<MainloopPipelineNK>();
501
- PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>();
502
- PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
503
-
504
- auto cluster_wait_fn = [] () {
505
- // We need this to guarantee that the Pipeline init is visible
506
- // To all producers and consumer thread blocks in the Cluster
507
- if constexpr (size(ClusterShape{}) > 1) {
508
- cute::cluster_arrive_relaxed();
509
- return [] () { cute::cluster_wait(); };
510
- }
511
- else {
512
- __syncthreads();
513
- return [] () {}; // do nothing
514
- }
515
- } ();
516
-
517
- // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
518
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
519
-
520
- // Get the appropriate blocks for this thread block -- potential for thread block locality
521
- TiledMma tiled_mma;
522
- auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
523
-
524
- TileScheduler scheduler{params.scheduler};
525
- if constexpr (IsSchedDynamicPersistent) {
526
- scheduler.set_data_ptr(shared_storage.scheduler.data());
527
- }
528
- // Declare work_tile_info, then define it in each of warps that use it.
529
- typename TileScheduler::WorkTileInfo work_tile_info;
530
-
531
- // In a warp specialized kernel, collectives expose data movement and compute operations separately
532
- CollectiveMainloop collective_mainloop;
533
-
534
- // Prepare and partition the input tensors. Expects a tuple of tensors where:
535
- // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
536
- // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
537
- auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
538
- static_assert(cute::tuple_size_v<decltype(load_inputs)> >= 2, "Output of load_init must have at least two elements (A, B)");
539
-
540
- // Extract out partitioned A and B.
541
- Tensor gA_mkl = get<0>(load_inputs);
542
- Tensor gB_nkl = get<1>(load_inputs);
543
-
544
- // Wait for all thread blocks in the Cluster
545
- cluster_wait_fn();
546
-
547
- if (warp_group_role == WarpGroupRole::Producer) {
548
- cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
549
-
550
- // Scheduler Producer Warp
551
- if (producer_warp_role == ProducerWarpRole::Warp1) {
552
- work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
553
-
554
- if constexpr (IsSchedDynamicPersistent) {
555
- bool requires_clc_query = true;
556
- TileSchedulerPipelineState scheduler_pipe_producer_state = cutlass::make_producer_start_state<TileSchedulerPipeline>();
557
-
558
- cutlass::arch::wait_on_dependent_grids();
559
-
560
- while (work_tile_info.is_valid()) {
561
- if (requires_clc_query) {
562
- // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers.
563
- scheduler_throttle_pipeline.consumer_wait(scheduler_pipe_throttle_consumer_state);
564
- scheduler_throttle_pipeline.consumer_release(scheduler_pipe_throttle_consumer_state);
565
- ++scheduler_pipe_throttle_consumer_state;
566
-
567
- // Query next clcID and update producer state
568
- scheduler_pipe_producer_state = scheduler.advance_to_next_work(scheduler_pipeline, scheduler_pipe_producer_state);
569
- }
570
- // Fetch next work tile
571
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(
572
- work_tile_info,
573
- scheduler_pipeline,
574
- scheduler_pipe_consumer_state
575
- );
576
- requires_clc_query = increment_pipe;
577
- if (increment_pipe) {
578
- ++scheduler_pipe_consumer_state;
579
- }
580
- work_tile_info = next_work_tile_info;
581
- }
582
- scheduler_pipeline.producer_tail(scheduler_pipe_producer_state);
583
- }
584
- } // Scheduler Producer Warp End
585
- else
586
- // Producer Warp to LoadMK
587
- if (producer_warp_role == ProducerWarpRole::LoadMK) {
588
- work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
589
-
590
- // Ensure that the prefetched kernel does not touch
591
- // unflushed global memory prior to this instruction
592
- cutlass::arch::wait_on_dependent_grids();
593
- bool do_load_order_arrive = true;
594
- bool requires_clc_query = true;
595
- while (work_tile_info.is_valid()) {
596
- if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
597
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info);
598
- work_tile_info = next_work_tile_info;
599
- continue;
600
- }
601
-
602
- // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
603
- auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
604
- auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
605
- auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
606
- auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
607
-
608
- // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
609
- auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
610
- auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
611
- auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
612
-
613
- if (requires_clc_query) {
614
- scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state);
615
- scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state);
616
- ++scheduler_pipe_throttle_producer_state;
617
- }
618
-
619
- collective_mainloop.load_MK(
620
- params.mainloop,
621
- mainloop_pipeline_mk,
622
- mainloop_pipe_producer_state_mk,
623
- load_inputs,
624
- blk_coord,
625
- k_tile_iter, work_k_tile_count,
626
- lane_idx,
627
- block_rank_in_cluster,
628
- shared_storage.tensors.mainloop
629
- );
630
- // Update starting pipeline state for the next tile
631
- mainloop_pipe_producer_state_mk.advance(work_k_tile_count);
632
-
633
- // Signal for the epilogue load warp to begin
634
- if (do_load_order_arrive) {
635
- load_order_barrier.arrive();
636
- do_load_order_arrive = false;
637
- }
638
- // Get next work tile
639
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info
640
- ,scheduler_pipeline
641
- ,scheduler_pipe_consumer_state
642
- );
643
- work_tile_info = next_work_tile_info;
644
- if constexpr (IsSchedDynamicPersistent) {
645
- requires_clc_query = increment_pipe;
646
- if (increment_pipe) {
647
- ++scheduler_pipe_consumer_state;
648
- }
649
- }
650
- } // Scheduler work fetch loop
651
-
652
- // Make sure all Consumer Warp Groups have been waited upon
653
- collective_mainloop.load_tail(mainloop_pipeline_mk, mainloop_pipe_producer_state_mk);
654
-
655
- } // Producer Warp LoadMK End
656
-
657
- // LoadNK Producer Warp
658
- if (producer_warp_role == ProducerWarpRole::LoadNK) {
659
- work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
660
-
661
- // Ensure that the prefetched kernel does not touch
662
- // unflushed global memory prior to this instruction
663
- cutlass::arch::wait_on_dependent_grids();
664
-
665
- bool do_load_order_arrive = true;
666
- bool requires_clc_query = true;
667
- while (work_tile_info.is_valid()) {
668
- if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
669
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info);
670
- work_tile_info = next_work_tile_info;
671
- continue;
672
- }
673
-
674
- // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
675
- auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
676
- auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
677
- auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
678
- auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
679
-
680
- // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
681
- auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape) * AsymmetricKRatio;
682
- auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info) * AsymmetricKRatio;
683
- auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
684
-
685
- if (requires_clc_query) {
686
- scheduler_throttle_pipeline.producer_acquire(scheduler_pipe_throttle_producer_state);
687
- scheduler_throttle_pipeline.producer_commit(scheduler_pipe_throttle_producer_state);
688
- ++scheduler_pipe_throttle_producer_state;
689
- }
690
-
691
- collective_mainloop.load_NK(
692
- params.mainloop,
693
- mainloop_pipeline_nk,
694
- mainloop_pipe_producer_state_nk,
695
- load_inputs,
696
- blk_coord,
697
- k_tile_iter, work_k_tile_count,
698
- lane_idx,
699
- block_rank_in_cluster,
700
- shared_storage.tensors.mainloop
701
- );
702
- // Update starting pipeline state for the next tile
703
- mainloop_pipe_producer_state_nk.advance(work_k_tile_count);
704
-
705
- // Signal for the epilogue load warp to begin
706
- if (do_load_order_arrive) {
707
- load_order_barrier.arrive();
708
- do_load_order_arrive = false;
709
- }
710
- // Get next work tile
711
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info
712
- ,scheduler_pipeline
713
- ,scheduler_pipe_consumer_state
714
- );
715
- work_tile_info = next_work_tile_info;
716
- if constexpr (IsSchedDynamicPersistent) {
717
- requires_clc_query = increment_pipe;
718
- if (increment_pipe) {
719
- ++scheduler_pipe_consumer_state;
720
- }
721
- }
722
- } // Scheduler work fetch loop
723
-
724
- // Make sure all Consumer Warp Groups have been waited upon
725
- collective_mainloop.load_tail(mainloop_pipeline_nk, mainloop_pipe_producer_state_nk);
726
-
727
- } // Producer Warp LoadNK End
728
- // Epilogue Producer Warp
729
- else if (producer_warp_role == ProducerWarpRole::LoadMN &&
730
- is_epi_load_needed) {
731
- work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
732
-
733
- // Ensure that the prefetched kernel does not touch
734
- // unflushed global memory prior to this instruction
735
- cutlass::arch::wait_on_dependent_grids();
736
-
737
- if (!TileScheduler::requires_separate_reduction(params.scheduler) && work_tile_info.is_valid()) {
738
- load_order_barrier.wait();
739
- }
740
- CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
741
-
742
- while (work_tile_info.is_valid()) {
743
- if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
744
- // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
745
- auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
746
- auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
747
- auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
748
- auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
749
-
750
- epi_load_pipe_producer_state =
751
- collective_epilogue.load(
752
- epi_load_pipeline,
753
- epi_load_pipe_producer_state,
754
- problem_shape_MNKL,
755
- blk_shape,
756
- blk_coord,
757
- tiled_mma,
758
- lane_idx,
759
- shared_storage.tensors.epilogue,
760
- work_tile_info.reduction_subtile_idx()
761
- );
762
- }
763
-
764
- // Get next work tile
765
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info
766
- ,scheduler_pipeline
767
- ,scheduler_pipe_consumer_state
768
- );
769
- work_tile_info = next_work_tile_info;
770
- if constexpr (IsSchedDynamicPersistent) {
771
- if (increment_pipe) {
772
- ++scheduler_pipe_consumer_state;
773
- }
774
- }
775
- } // Scheduler work fetch loop
776
-
777
- // Make sure all Consumer Warp Groups have been waited upon
778
- collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state);
779
- } // Producer Warp LoadMN End
780
- } // Producer Warp Group End
781
-
782
- else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
783
- work_tile_info = scheduler.initial_work_tile_info(ClusterShape{});
784
-
785
- cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
786
-
787
- CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue);
788
-
789
- // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it
790
- bool do_store_tail = false;
791
- while (work_tile_info.is_valid()) {
792
- // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
793
- auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
794
- auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
795
- auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
796
- auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
797
-
798
- // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
799
- auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
800
- auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
801
- auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
802
-
803
- // Allocate the accumulators for the (M,N) blk_shape
804
- //
805
- // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
806
- auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
807
- if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
808
-
809
- collective_mainloop.mma(
810
- mainloop_pipeline_mk,
811
- mainloop_pipe_consumer_state_mk,
812
- mainloop_pipeline_nk,
813
- mainloop_pipe_consumer_state_nk,
814
- accumulators,
815
- k_tile_iter,
816
- work_k_tile_count,
817
- mma_thread_idx,
818
- shared_storage.tensors.mainloop,
819
- params.mainloop,
820
- blk_coord,
821
- problem_shape_MNKL
822
- );
823
-
824
- // Make sure the math instructions are done and free buffers before entering the epilogue
825
- collective_mainloop.mma_tail(
826
- mainloop_pipeline_mk,
827
- mainloop_pipe_consumer_state_mk,
828
- mainloop_pipeline_nk,
829
- mainloop_pipe_consumer_state_nk,
830
- work_k_tile_count
831
- );
832
-
833
- // Update starting mainloop pipeline state for the next tile
834
- mainloop_pipe_consumer_state_mk.advance(work_k_tile_count);
835
- mainloop_pipe_consumer_state_nk.advance(work_k_tile_count * AsymmetricKRatio);
836
- }
837
- #ifdef CUTLASS_ENABLE_GDC_FOR_SM90
838
- if (scheduler.is_last_tile(work_tile_info)) {
839
- // Hint on an early release of global memory resources.
840
- // The timing of calling this function only influences performance,
841
- // not functional correctness.
842
- cutlass::arch::launch_dependent_grids();
843
-
844
- }
845
- #endif
846
-
847
- // Index of warp group within consumer warp groups
848
- int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
849
-
850
- // Perform reduction across splits, if needed
851
- TileScheduler::fixup(
852
- params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx);
853
-
854
- if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
855
- // Epilogue and write to gD
856
- auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] =
857
- collective_epilogue.store(
858
- epi_load_pipeline,
859
- epi_load_pipe_consumer_state,
860
- epi_store_pipeline,
861
- epi_store_pipe_producer_state,
862
- problem_shape_MNKL,
863
- blk_shape,
864
- blk_coord,
865
- accumulators,
866
- tiled_mma,
867
- mma_thread_idx,
868
- shared_storage.tensors.epilogue,
869
- work_tile_info.reduction_subtile_idx()
870
- );
871
- epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
872
- epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
873
- do_store_tail = true;
874
- }
875
-
876
- // Get next work tile
877
- auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info
878
- ,scheduler_pipeline
879
- ,scheduler_pipe_consumer_state
880
- );
881
- work_tile_info = next_work_tile_info;
882
- if constexpr (IsSchedDynamicPersistent) {
883
- if (increment_pipe) {
884
- ++scheduler_pipe_consumer_state;
885
- }
886
- }
887
- } // Scheduler work fetch loop
888
-
889
- if (do_store_tail) {
890
- collective_epilogue.store_tail(
891
- epi_load_pipeline,
892
- epi_load_pipe_consumer_state,
893
- epi_store_pipeline,
894
- epi_store_pipe_producer_state
895
- );
896
- }
897
- } // Consumer Warp Groups End
898
- }
899
-
900
- };
901
-
902
- ///////////////////////////////////////////////////////////////////////////////
903
-
904
- } // namespace cutlass::gemm::kernel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm.hpp DELETED
@@ -1,270 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cutlass/cutlass.h"
34
- #include "cutlass/kernel_hardware_info.hpp"
35
- #include "cutlass/gemm/gemm.h"
36
- #include "cutlass/gemm/dispatch_policy.hpp"
37
-
38
- #include "cute/tensor.hpp"
39
-
40
- namespace cutlass::gemm::kernel {
41
-
42
- ///////////////////////////////////////////////////////////////////////////////
43
-
44
- template <
45
- class ProblemShape_,
46
- class CollectiveMainloop_,
47
- class CollectiveEpilogue_,
48
- class TileScheduler_
49
- >
50
- class GemmUniversal<
51
- ProblemShape_,
52
- CollectiveMainloop_,
53
- CollectiveEpilogue_,
54
- TileScheduler_,
55
- cute::enable_if_t<cute::is_base_of_v<KernelMultistage, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
56
- {
57
- public:
58
- //
59
- // Type Aliases
60
- //
61
- using ProblemShape = ProblemShape_;
62
- static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
63
- "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
64
-
65
- // Mainloop derived types
66
- using CollectiveMainloop = CollectiveMainloop_;
67
- using TileShape = typename CollectiveMainloop::TileShape;
68
- using TiledMma = typename CollectiveMainloop::TiledMma;
69
- using ArchTag = typename CollectiveMainloop::ArchTag;
70
- using ElementA = typename CollectiveMainloop::ElementA;
71
- using StrideA = typename CollectiveMainloop::StrideA;
72
- using ElementB = typename CollectiveMainloop::ElementB;
73
- using StrideB = typename CollectiveMainloop::StrideB;
74
- using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
75
- using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
76
- using MainloopArguments = typename CollectiveMainloop::Arguments;
77
- using MainloopParams = typename CollectiveMainloop::Params;
78
-
79
- using TileSchedulerTag = TileScheduler_;
80
- using TileScheduler = typename detail::TileSchedulerSelector<
81
- TileScheduler_, ArchTag, TileShape,
82
- cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
83
- using TileSchedulerArguments = typename TileScheduler::Arguments;
84
- static constexpr bool IsGdcEnabled = false;
85
-
86
- static constexpr bool is_valid_tile_scheduler =
87
- cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>;
88
- static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler.");
89
-
90
- // Epilogue derived types
91
- using CollectiveEpilogue = CollectiveEpilogue_;
92
- using ElementC = typename CollectiveEpilogue::ElementC;
93
- using StrideC = typename CollectiveEpilogue::StrideC;
94
- using ElementD = typename CollectiveEpilogue::ElementD;
95
- using StrideD = typename CollectiveEpilogue::StrideD;
96
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
97
- using EpilogueParams = typename CollectiveEpilogue::Params;
98
- static_assert(cute::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
99
- "Mainloop and epilogue do not agree on accumulator value type.");
100
-
101
- // MSVC requires the cast to fix a warning-as-error.
102
- static constexpr int SharedStorageSize = static_cast<int>(cute::max(
103
- sizeof(typename CollectiveMainloop::SharedStorage),
104
- sizeof(typename CollectiveEpilogue::SharedStorage)));
105
-
106
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{}));
107
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
108
-
109
- // Device side arguments
110
- struct Arguments {
111
- GemmUniversalMode mode{};
112
- ProblemShape problem_shape{};
113
- MainloopArguments mainloop{};
114
- EpilogueArguments epilogue{};
115
- KernelHardwareInfo hw_info{};
116
- TileSchedulerArguments scheduler{};
117
- };
118
-
119
- // Kernel entry point API
120
- struct Params {
121
- GemmUniversalMode mode{};
122
- ProblemShape problem_shape{};
123
- MainloopParams mainloop{};
124
- EpilogueParams epilogue{};
125
- };
126
-
127
- //
128
- // Methods
129
- //
130
-
131
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
132
- static
133
- Params
134
- to_underlying_arguments(Arguments const& args, void* workspace) {
135
- (void) workspace;
136
-
137
- KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count};
138
- auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{});
139
-
140
- return {
141
- args.mode,
142
- args.problem_shape,
143
- CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
144
- CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
145
- };
146
- }
147
-
148
- static bool
149
- can_implement(Arguments const& args) {
150
- bool mode_implementable = args.mode == GemmUniversalMode::kGemm or
151
- (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
152
- return mode_implementable && TileScheduler::can_implement(args.scheduler);
153
- }
154
-
155
- static size_t
156
- get_workspace_size(Arguments const& args) {
157
- size_t workspace_size = 0;
158
- return workspace_size;
159
- }
160
-
161
- static
162
- cutlass::Status
163
- initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
164
- CudaHostAdapter* cuda_adapter = nullptr) {
165
- cutlass::Status status = Status::kSuccess;
166
-
167
- return status;
168
- }
169
-
170
- static dim3
171
- get_grid_shape(Params const& params) {
172
- int batch_count = 1;
173
- if constexpr (cute::rank(ProblemShape{}) == 4) {
174
- batch_count = cute::size<3>(params.problem_shape);
175
- }
176
-
177
- return dim3(
178
- cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
179
- cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
180
- batch_count
181
- );
182
- }
183
-
184
- static dim3
185
- get_block_shape() {
186
- return dim3(MaxThreadsPerBlock, 1, 1);
187
- }
188
-
189
- CUTLASS_DEVICE
190
- void
191
- operator()(Params const& params, char* smem_buf) {
192
- using namespace cute;
193
- using X = Underscore;
194
-
195
- // Preconditions
196
- CUTE_STATIC_ASSERT(is_static<TileShape>::value);
197
-
198
- // Separate out problem shape for convenience
199
- // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
200
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
201
- auto [M,N,K,L] = problem_shape_MNKL;
202
-
203
- // Preconditions
204
- static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
205
- static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
206
- static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
207
- static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
208
-
209
- // Get the appropriate blocks for this thread block -- potential for thread block locality
210
- int thread_idx = int(threadIdx.x);
211
- auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
212
- auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
213
- auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l)
214
-
215
- // Represent the full tensors
216
- Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l)
217
- Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l)
218
-
219
- // Get batch slice
220
- Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
221
- Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)
222
-
223
- // Slice to get the tiles this thread block is responsible for
224
- Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
225
- Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
226
-
227
- // Compute tile residues for predication
228
- auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
229
- auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
230
- auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
231
- auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
232
-
233
- // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
234
- TiledMma tiled_mma;
235
- Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
236
- clear(accumulators);
237
-
238
- auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
239
- int k_tile_count = size<2>(gA);
240
-
241
- // Perform the collective scoped MMA
242
- CollectiveMainloop collective_mma;
243
- collective_mma(
244
- accumulators,
245
- gA,
246
- gB,
247
- accumulators,
248
- k_tile_iter, k_tile_count,
249
- residue_mnk,
250
- thread_idx,
251
- smem_buf
252
- );
253
- // Epilogue and write to gD
254
- CollectiveEpilogue epilogue{params.epilogue};
255
- epilogue(
256
- problem_shape_MNKL,
257
- blk_shape,
258
- blk_coord_mnkl,
259
- accumulators,
260
- tiled_mma,
261
- residue_mnk,
262
- thread_idx,
263
- smem_buf
264
- );
265
- }
266
- };
267
-
268
- ///////////////////////////////////////////////////////////////////////////////
269
-
270
- } // namespace cutlass::gemm::kernel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch29-cxx11-cu130-x86_64-linux/include/third-party/cutlass/include/cutlass/gemm/kernel/sm70_gemm_array.hpp DELETED
@@ -1,279 +0,0 @@
1
- /***************************************************************************************************
2
- * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- * SPDX-License-Identifier: BSD-3-Clause
4
- *
5
- * Redistribution and use in source and binary forms, with or without
6
- * modification, are permitted provided that the following conditions are met:
7
- *
8
- * 1. Redistributions of source code must retain the above copyright notice, this
9
- * list of conditions and the following disclaimer.
10
- *
11
- * 2. Redistributions in binary form must reproduce the above copyright notice,
12
- * this list of conditions and the following disclaimer in the documentation
13
- * and/or other materials provided with the distribution.
14
- *
15
- * 3. Neither the name of the copyright holder nor the names of its
16
- * contributors may be used to endorse or promote products derived from
17
- * this software without specific prior written permission.
18
- *
19
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
- *
30
- **************************************************************************************************/
31
- #pragma once
32
-
33
- #include "cutlass/cutlass.h"
34
- #include "cutlass/kernel_hardware_info.hpp"
35
- #include "cutlass/gemm/gemm.h"
36
- #include "cutlass/gemm/dispatch_policy.hpp"
37
-
38
- #include "cute/tensor.hpp"
39
-
40
- namespace cutlass::gemm::kernel {
41
-
42
- ///////////////////////////////////////////////////////////////////////////////
43
-
44
- template <
45
- class ProblemShape_,
46
- class CollectiveMainloop_,
47
- class CollectiveEpilogue_,
48
- class TileScheduler_
49
- >
50
- class GemmUniversal<
51
- ProblemShape_,
52
- CollectiveMainloop_,
53
- CollectiveEpilogue_,
54
- TileScheduler_,
55
- cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayMultistage, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
56
- {
57
- public:
58
- //
59
- // Type Aliases
60
- //
61
- using ProblemShape = ProblemShape_;
62
- static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 4,
63
- "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
64
-
65
- // Mainloop derived types
66
- using CollectiveMainloop = CollectiveMainloop_;
67
- using TileShape = typename CollectiveMainloop::TileShape;
68
- using TiledMma = typename CollectiveMainloop::TiledMma;
69
- using ArchTag = typename CollectiveMainloop::ArchTag;
70
- using ElementA = typename CollectiveMainloop::ElementA;
71
- using StrideA = typename CollectiveMainloop::StrideA;
72
- using InternalStrideA = typename CollectiveMainloop::InternalStrideA;
73
- using ElementB = typename CollectiveMainloop::ElementB;
74
- using StrideB = typename CollectiveMainloop::StrideB;
75
- using InternalStrideB = typename CollectiveMainloop::InternalStrideB;
76
- using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
77
- using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
78
- using MainloopArguments = typename CollectiveMainloop::Arguments;
79
- using MainloopParams = typename CollectiveMainloop::Params;
80
-
81
- using TileSchedulerTag = TileScheduler_;
82
- using TileScheduler = typename detail::TileSchedulerSelector<
83
- TileScheduler_, ArchTag, TileShape,
84
- cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
85
- using TileSchedulerArguments = typename TileScheduler::Arguments;
86
- static constexpr bool IsGdcEnabled = false;
87
-
88
- static constexpr bool is_valid_tile_scheduler =
89
- cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>;
90
- static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializing the tile scheduler.");
91
-
92
- // Epilogue derived types
93
- using CollectiveEpilogue = CollectiveEpilogue_;
94
- using ElementC = typename CollectiveEpilogue::ElementC;
95
- using StrideC = typename CollectiveEpilogue::StrideC;
96
- using InternalStrideC = typename CollectiveEpilogue::InternalStrideC;
97
- using ElementD = typename CollectiveEpilogue::ElementD;
98
- using StrideD = typename CollectiveEpilogue::StrideD;
99
- using InternalStrideD = typename CollectiveEpilogue::InternalStrideD;
100
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
101
- using EpilogueParams = typename CollectiveEpilogue::Params;
102
- static_assert(cute::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
103
- "Mainloop and epilogue do not agree on accumulator value type.");
104
-
105
- // MSVC requires the cast to fix a warning-as-error.
106
- static constexpr int SharedStorageSize = static_cast<int>(cute::max(
107
- sizeof(typename CollectiveMainloop::SharedStorage),
108
- sizeof(typename CollectiveEpilogue::SharedStorage)));
109
-
110
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{}));
111
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
112
-
113
- // Device side arguments
114
- struct Arguments {
115
- GemmUniversalMode mode{};
116
- ProblemShape problem_shape{};
117
- MainloopArguments mainloop{};
118
- EpilogueArguments epilogue{};
119
- KernelHardwareInfo hw_info{};
120
- TileSchedulerArguments scheduler{};
121
- };
122
-
123
- // Kernel entry point API
124
- struct Params {
125
- GemmUniversalMode mode{};
126
- typename ProblemShape::UnderlyingProblemShape problem_shape{};
127
- MainloopParams mainloop{};
128
- EpilogueParams epilogue{};
129
- };
130
-
131
- //
132
- // Methods
133
- //
134
-
135
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
136
- static
137
- Params
138
- to_underlying_arguments(Arguments const& args, void* workspace) {
139
- (void) workspace;
140
- typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape();
141
-
142
- KernelHardwareInfo hw_info{args.hw_info.device_id, args.hw_info.sm_count};
143
- auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{});
144
-
145
- return {
146
- args.mode,
147
- problem_shape,
148
- CollectiveMainloop::to_underlying_arguments(problem_shape, args.mainloop, workspace),
149
- CollectiveEpilogue::to_underlying_arguments(problem_shape, args.epilogue, workspace)
150
- };
151
- }
152
-
153
- static bool
154
- can_implement(Arguments const& args) {
155
-
156
- bool implementable = (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
157
- if (!implementable) {
158
- CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
159
- return implementable;
160
- }
161
- typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape();
162
- implementable &= TileScheduler::can_implement(args.scheduler);
163
- return implementable;
164
- }
165
-
166
- static size_t
167
- get_workspace_size(Arguments const& args) {
168
- size_t workspace_size = 0;
169
- return workspace_size;
170
- }
171
-
172
- static
173
- cutlass::Status
174
- initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
175
- CudaHostAdapter* cuda_adapter = nullptr) {
176
- cutlass::Status status = Status::kSuccess;
177
-
178
- return status;
179
- }
180
-
181
- static dim3
182
- get_grid_shape(Params const& params) {
183
- int batch_count = cute::size<3>(params.problem_shape);
184
- return dim3(
185
- cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
186
- cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
187
- batch_count
188
- );
189
- }
190
-
191
- static dim3
192
- get_block_shape() {
193
- return dim3(MaxThreadsPerBlock, 1, 1);
194
- }
195
-
196
- CUTLASS_DEVICE
197
- void
198
- operator()(Params const& params, char* smem_buf) {
199
- using namespace cute;
200
- using X = Underscore;
201
-
202
- // Preconditions
203
- CUTE_STATIC_ASSERT(is_static<TileShape>::value);
204
-
205
- // Separate out problem shape for convenience
206
- // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
207
- auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
208
- auto [M,N,K,L] = problem_shape_MNKL;
209
-
210
- // Preconditions
211
- static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
212
- static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
213
- static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
214
- static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
215
-
216
- // Get the appropriate blocks for this thread block -- potential for thread block locality
217
- int thread_idx = int(threadIdx.x);
218
- auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
219
- auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
220
- auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l)
221
-
222
- // Represent the full tensors
223
- Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A[l_coord]), make_shape(M,K,1), params.mainloop.dA); //(m,k,l)
224
- Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B[l_coord]), make_shape(N,K,1), params.mainloop.dB); //(n,k,l)
225
-
226
- // Get batch slice
227
- Tensor mA_mk = mA_mkl(_,_,0); // (m,k)
228
- Tensor mB_nk = mB_nkl(_,_,0); // (n,k)
229
-
230
- // Slice to get the tiles this thread block is responsible for
231
- Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
232
- Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
233
-
234
- // Compute tile residues for predication
235
- auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
236
- auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
237
- auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
238
- auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
239
-
240
- // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
241
- TiledMma tiled_mma;
242
- Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
243
- clear(accumulators);
244
-
245
- auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
246
- int k_tile_count = size<2>(gA);
247
-
248
-
249
- // Perform the collective scoped MMA
250
- CollectiveMainloop collective_mma;
251
- collective_mma(
252
- accumulators,
253
- gA,
254
- gB,
255
- accumulators,
256
- k_tile_iter, k_tile_count,
257
- residue_mnk,
258
- thread_idx,
259
- smem_buf
260
- );
261
-
262
- // Epilogue and write to gD
263
- CollectiveEpilogue epilogue{params.epilogue};
264
- epilogue(
265
- problem_shape_MNKL,
266
- blk_shape,
267
- blk_coord_mnkl,
268
- accumulators,
269
- tiled_mma,
270
- residue_mnk,
271
- thread_idx,
272
- smem_buf
273
- );
274
- }
275
- };
276
-
277
- ///////////////////////////////////////////////////////////////////////////////
278
-
279
- } // namespace cutlass::gemm::kernel