Kernels
kernels-bot committed on
Commit
dbbf646
·
verified ·
1 Parent(s): af53096

Uploaded using `kernel-builder`.

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch211-cxx11-cu126-x86_64-linux/__init__.py +5 -2
  2. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py +0 -0
  3. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py +0 -0
  4. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py +574 -0
  5. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py +53 -0
  6. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py +5 -0
  7. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py +567 -0
  8. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py +0 -0
  9. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py +0 -0
  10. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py +46 -0
  11. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py +100 -0
  12. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py +752 -0
  13. build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py +47 -0
  14. build/torch211-cxx11-cu126-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} +2 -2
  15. build/torch211-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  16. build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py +9 -9
  17. build/torch211-cxx11-cu126-x86_64-linux/metadata.json +3 -2
  18. build/torch211-cxx11-cu128-x86_64-linux/__init__.py +5 -2
  19. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/__init__.py +0 -0
  20. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py +0 -0
  21. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py +574 -0
  22. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py +53 -0
  23. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py +5 -0
  24. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py +567 -0
  25. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/__init__.py +0 -0
  26. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py +0 -0
  27. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py +46 -0
  28. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py +100 -0
  29. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py +752 -0
  30. build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py +47 -0
  31. build/torch211-cxx11-cu128-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} +2 -2
  32. build/torch211-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  33. build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py +9 -9
  34. build/torch211-cxx11-cu128-x86_64-linux/metadata.json +2 -1
  35. build/torch211-cxx11-cu130-x86_64-linux/__init__.py +5 -2
  36. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py +0 -0
  37. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py +0 -0
  38. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py +574 -0
  39. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py +53 -0
  40. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py +5 -0
  41. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py +567 -0
  42. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py +0 -0
  43. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py +0 -0
  44. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py +46 -0
  45. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py +100 -0
  46. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py +752 -0
  47. build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py +47 -0
  48. build/torch211-cxx11-cu130-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} +2 -2
  49. build/torch211-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  50. build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py +9 -9
build/torch211-cxx11-cu126-x86_64-linux/__init__.py CHANGED
@@ -3,7 +3,9 @@
3
 
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
  from .grouped_gemm import backend as gg_backend
9
  from .grouped_gemm import ops as gg_ops
@@ -136,7 +138,8 @@ def sort(
136
  Returns:
137
  The sorted values tensor
138
  """
139
- return ops.sort(x, end_bit, x_out, iota_out)
 
140
 
141
 
142
  # Convenience functions for common use cases
 
3
 
4
  import torch
5
 
6
+ # Stable alias: bare `ops` is shadowed by `from . import layers` below.
7
+ from ._ops import ops as _compiled_ops
8
+ from . import ops
9
 
10
  from .grouped_gemm import backend as gg_backend
11
  from .grouped_gemm import ops as gg_ops
 
138
  Returns:
139
  The sorted values tensor
140
  """
141
+ _compiled_ops.sort(x, end_bit, x_out, iota_out)
142
+ return x_out
143
 
144
 
145
  # Convenience functions for common use cases
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/__init__.py ADDED
File without changes
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py ADDED
File without changes
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+
5
+ # Imports.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ # Python standard library
9
+ import functools
10
+
11
+ # Triton
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ # AITER
16
+ from ..configs import CONFIGS as _CONFIGS
17
+ from ..utils._triton import arch_info
18
+ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
19
+
20
+ # Kernel config.
21
+ # ------------------------------------------------------------------------------
22
+
23
+
24
@functools.lru_cache()
def get_config(
    gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
) -> dict[str, int]:
    """Return the tuned launch configuration for the current GPU architecture.

    Args:
        gmm_type: Kernel variant, one of ``"gmm"``, ``"ptgmm"`` or ``"nptgmm"``.
        M, K, N, G: Problem shape. They are not used to select a configuration
            here; they only take part in the ``lru_cache`` key.
        accumulate: Prefer the ``"accumulate"`` configuration of the variant.
            Variants that ship no such entry fall back to ``"default"``
            instead of failing with a bare ``KeyError``.

    Returns:
        The configuration dict stored in ``_CONFIGS``. This is the shared
        object (not a copy), so callers should treat it as read-only.

    Raises:
        AssertionError: If ``gmm_type`` is unknown, if no configuration exists
            for the detected architecture, or if the variant has no
            ``"default"`` entry.
    """
    assert gmm_type in {
        "gmm",
        "ptgmm",
        "nptgmm",
    }, f"'{gmm_type}' is an invalid GMM variant."
    dev = arch_info.get_arch()
    assert (
        dev in _CONFIGS
    ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}."
    variant_configs = _CONFIGS[dev][gmm_type]
    assert "default" in variant_configs, "Default configuration is absent."
    # Only "default" is guaranteed to exist (it is the entry asserted above);
    # an "accumulate" entry is optional per variant.
    if accumulate and "accumulate" in variant_configs:
        return variant_configs["accumulate"]
    return variant_configs["default"]
43
+
44
+
45
+ # Common code shared by GMM and TGMM kernels.
46
+ # ------------------------------------------------------------------------------
47
+
48
+
49
+ # XCD remapping followed by 1D PID to 2D grid mapping.
50
@triton.jit
def _remap_xcd_tile_grid(
    tile_in_mm,
    num_row_tiles,
    num_col_tiles,
    GROUP_SIZE: tl.constexpr = 1,
    NUM_XCDS: tl.constexpr = 8,
):
    # Two-step program-id transformation shared by the GMM / TGMM kernels:
    #   1. `remap_xcd` remaps the 1D tile index over the whole tile grid.
    #   2. `pid_grid` turns the remapped 1D index into a 2D tile coordinate
    #      (callers unpack the result as a (row_tile, col_tile) pair).
    total_tiles = num_row_tiles * num_col_tiles
    remapped_tile = remap_xcd(tile_in_mm, total_tiles, NUM_XCDS=NUM_XCDS)
    return pid_grid(
        remapped_tile,
        num_row_tiles,
        num_col_tiles,
        GROUP_SIZE_M=GROUP_SIZE,
    )
64
+
65
+
66
+ # GMM kernel.
67
+ # ------------------------------------------------------------------------------
68
+
69
+
70
@triton.heuristics(
    {
        "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"]
        == 0,
    }
)
@triton.jit
def gmm_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_RHS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    GRID_DIM: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """Grouped matrix multiply: for each group g, (m_g, K) x (K, N) -> (m_g, N).

    Each program starts at tile ``program_id(0)`` and advances by ``GRID_DIM``
    tiles at a time, walking the groups in order and computing every output
    tile that falls to it.

    Addressing used by the kernel (element offsets):
      * lhs:  ``(row) * K + k`` with ``row`` counted over all groups.
      * rhs:  ``g * K * N + k * N + n`` or, when ``TRANS_RHS`` is set,
              ``g * K * N + n * K + k``.
      * out:  ``(row) * N + n``.
      * bias: ``g * N + n`` (only read when ``USE_BIAS`` is set).

    ``group_sizes_ptr`` holds the G row counts ``m_g``; the kernel asserts
    that their running sum never exceeds ``M``.
    ``K_DIVISIBLE_BY_BLOCK_SIZE_K`` is filled in by the heuristic above and
    selects the unmasked load path in the K loop.
    """
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Current tile. Each program computes multiple tiles of each group.
    tile = tl.program_id(0)
    tl.device_assert(tile >= 0, "tile < 0 (at initialization)")

    # Global index of the first tile of the current MM problem, i.e. the total
    # number of tiles of all previous MM problems.
    last_mm_tile = 0

    # First row of lhs / out that belongs to the current group (running sum of
    # the previous group sizes). Each group reads some rows of lhs and writes
    # some rows to out.
    last_m = 0

    # Loop through all (m, K, N) MM problems:
    #     (m, K) x (K, N) = (m, N)
    #     sum(m) = M
    for g in range(G):
        # Get m dimension of current MM problem.
        m = tl.load(group_sizes_ptr + g)
        # m can be zero if group is empty
        tl.device_assert(m >= 0, "m < 0")

        num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        # num_m_tiles can be zero if group is empty
        tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0")

        num_tiles = num_m_tiles * num_n_tiles
        # num_tiles can be zero if group is empty
        tl.device_assert(num_tiles >= 0, "num_tiles < 0")

        # Loop through tiles of current MM problem. For an empty group
        # num_tiles is zero, so the body (including the `% m` below) is
        # never executed.
        while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
            # Figure out tile coordinates in current MM problem.
            tile_in_mm = tile - last_mm_tile
            tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

            tile_m, tile_n = _remap_xcd_tile_grid(
                tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
            )

            # Do regular MM:

            tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0")
            tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

            # Offsets are wrapped modulo m / N so that loads of a partial tile
            # stay inside the group's rows and the N columns. The wrapped
            # (duplicate) entries are dropped by the store mask at the end.
            offs_lhs_m = (
                tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            ) % m
            offs_rhs_n = (
                tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            ) % N
            offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)

            lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :]

            # Both branches produce a (BLOCK_SIZE_K, BLOCK_SIZE_N) pointer
            # block; they differ only in which index is the unit-stride one.
            if TRANS_RHS:
                rhs_ptrs = (
                    rhs_ptr
                    + g.to(tl.int64) * K * N
                    + offs_k[:, None]
                    + offs_rhs_n[None, :] * K
                )
            else:
                rhs_ptrs = (
                    rhs_ptr
                    + g.to(tl.int64) * K * N
                    + offs_k[:, None] * N
                    + offs_rhs_n[None, :]
                )

            # Accumulate in float32 regardless of the input / output dtype.
            acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

            for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
                if K_DIVISIBLE_BY_BLOCK_SIZE_K:
                    lhs = tl.load(lhs_ptrs)
                    rhs = tl.load(rhs_ptrs)
                else:
                    # Number of valid K elements left from this block's start;
                    # elements past it are loaded as zero.
                    k_mask_limit = K - k * BLOCK_SIZE_K
                    lhs = tl.load(
                        lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0
                    )
                    rhs = tl.load(
                        rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0
                    )

                acc = tl.dot(lhs, rhs, acc=acc)

                # Advance both operands by one K block.
                lhs_ptrs += BLOCK_SIZE_K

                if TRANS_RHS:
                    rhs_ptrs += BLOCK_SIZE_K
                else:
                    rhs_ptrs += BLOCK_SIZE_K * N

            # Add bias if enabled
            if USE_BIAS:
                offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
                    0, BLOCK_SIZE_N
                )
                bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
                bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
                # Convert bias to float32 to match accumulator precision
                bias = bias.to(tl.float32)
                # Broadcast bias across M dimension and add in float32
                acc += bias[None, :]

            # Convert to output dtype after all computations
            acc = acc.to(out_ptr.type.element_ty)

            # Unwrapped offsets (no modulo) so that the mask below can reject
            # rows >= m and columns >= N.
            offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            out_ptrs = (
                out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :]
            )

            tl.store(
                out_ptrs,
                acc,
                mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N),
            )

            # Go to the next tile by advancing number of programs.
            tile += GRID_DIM
            tl.device_assert(tile > 0, "tile <= 0 (at update)")

        # Get ready to go to the next MM problem.

        last_mm_tile += num_tiles
        # last_mm_tile can be zero if group 0 is skipped
        tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")

        last_m += m
        # last_m can be zero if group 0 is skipped
        tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
        tl.device_assert(last_m <= M, "last_m > M (at update)")
241
+
242
+
243
+ # Persistent TGMM kernel.
244
+ # ------------------------------------------------------------------------------
245
+
246
+
247
@triton.jit
def tgmm_persistent_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_grad_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_LHS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    GRID_DIM: tl.constexpr,
    COMPUTE_BIAS_GRAD: tl.constexpr,
    ACCUMULATE: tl.constexpr,
):
    """Persistent transposed grouped GEMM: per group, (K, m_g) x (m_g, N) -> (K, N).

    Every group produces the same number of output tiles
    (``cdiv(K, BLOCK_SIZE_K) * cdiv(N, BLOCK_SIZE_N)``). Each program starts
    at tile ``program_id(0)`` and advances by ``GRID_DIM``, walking the groups
    in order.

    Addressing used by the kernel (element offsets, ``mi`` counted over all
    groups):
      * lhs: ``k * M + mi`` or, when ``TRANS_LHS`` is set, ``mi * K + k``.
      * rhs: ``mi * N + n``.
      * out: ``g * K * N + k * N + n``; overwritten, or added to the existing
        values when ``ACCUMULATE`` is set.
      * bias_grad: ``g * K + k``; updated with ``atomic_add`` when
        ``COMPUTE_BIAS_GRAD`` is set.
        NOTE(review): it is always added to, independent of ``ACCUMULATE``,
        so presumably the caller zero-initializes it -- confirm.
    """
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Tiles per MM problem; identical for every group because the output of
    # each group is (K, N).
    num_tiles = num_k_tiles * num_n_tiles
    tl.device_assert(num_tiles > 0, "num_tiles <= 0")

    # Current tile. Each program computes multiple tiles of each group.
    tile = tl.program_id(0)
    tl.device_assert(tile >= 0, "tile < 0 (at initialization)")

    # Global index of the first tile of the current MM problem, i.e. the total
    # number of tiles of all previous MM problems.
    last_mm_tile = 0

    # First lhs column / rhs row that belongs to the current group (running
    # sum of the previous group sizes). Each group reads some columns of lhs
    # and some rows of rhs.
    last_m = 0

    # Loop through all (K, m, N) MM problems:
    #     (K, m) x (m, N) = (K, N)
    #     sum(m) = M
    for g in range(G):
        # Get m dimension of current MM problem.
        m = tl.load(group_sizes_ptr + g)
        # m can be zero if group is empty
        tl.device_assert(m >= 0, "m < 0")

        # Loop through tiles of current MM problem.
        while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
            # Figure out tile coordinates in current MM problem.
            tile_in_mm = tile - last_mm_tile
            tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

            tile_k, tile_n = _remap_xcd_tile_grid(
                tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
            )

            # Do regular MM:

            tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
            tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

            # Offsets are wrapped modulo K / N so that loads of a partial tile
            # stay in range; the wrapped entries are dropped by the store mask.
            offs_lhs_k = (
                tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            ) % K
            offs_rhs_n = (
                tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            ) % N
            offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)

            # Both branches produce a (BLOCK_SIZE_K, BLOCK_SIZE_M) pointer
            # block; they differ only in which index is the unit-stride one.
            if TRANS_LHS:
                lhs_ptrs = (
                    lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K
                )
            else:
                lhs_ptrs = (
                    lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :])
                )

            rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :]

            # Split the reduction over m into `loop_m` full blocks (unmasked
            # loads) plus, if m is not a multiple of BLOCK_SIZE_M, one masked
            # tail block handled after the loop.
            loop_m = tl.cdiv(m, BLOCK_SIZE_M)
            m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
            if not m_divisible_by_block_m:
                loop_m -= 1

            acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)

            # Initialize bias accumulator
            bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)

            for _ in range(0, loop_m):
                lhs = tl.load(lhs_ptrs)
                rhs = tl.load(rhs_ptrs)

                acc = tl.dot(lhs, rhs, acc=acc)

                # Accumulate for bias gradient: sum lhs across M dimension
                if COMPUTE_BIAS_GRAD and tile_n == 0:
                    bias_acc += tl.sum(
                        lhs, axis=1
                    )  # Sum across M dimension [K, M] -> [K]

                # Advance both operands by one M block.
                if TRANS_LHS:
                    lhs_ptrs += BLOCK_SIZE_M * K
                else:
                    lhs_ptrs += BLOCK_SIZE_M

                rhs_ptrs += BLOCK_SIZE_M * N

            if not m_divisible_by_block_m:
                # NOTE(review): offs_lhs_k and offs_rhs_n are recomputed here
                # but not used afterwards (lhs_ptrs / rhs_ptrs were already
                # advanced in the loop); only offs_m feeds the masks below.
                offs_lhs_k = (
                    tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                ) % K
                offs_rhs_n = (
                    tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                ) % N
                # Group-relative m indices of the tail block; entries >= m are
                # loaded as zero and therefore do not contribute.
                offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
                rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
                acc = tl.dot(lhs, rhs, acc=acc)

                # Accumulate last chunk for bias gradient
                if COMPUTE_BIAS_GRAD and tile_n == 0:
                    bias_acc += tl.sum(lhs, axis=1)

            acc = acc.to(out_ptr.type.element_ty)

            # Unwrapped offsets (no modulo) so that the mask below can reject
            # rows >= K and columns >= N.
            offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            out_ptrs = (
                out_ptr
                + g.to(tl.int64) * K * N
                + offs_out_k[:, None] * N
                + offs_out_n[None, :]
            )

            mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
            if ACCUMULATE:
                # Load existing values and add to them (like beta=1 in BLAS).
                # acc was already cast to the output dtype above, so this
                # addition is performed on the cast value.
                old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
                tl.store(out_ptrs, acc + old_vals, mask=mask)
            else:
                # Overwrite output (like beta=0 in BLAS)
                tl.store(out_ptrs, acc, mask=mask)

            # Store bias gradient (only for first N tile, sum across all M)
            if COMPUTE_BIAS_GRAD and tile_n == 0:
                # Keep as float32 for atomic_add (bf16 not supported for atomics)
                bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
                # Use atomic add since multiple K-tiles may write to same expert's bias
                tl.atomic_add(
                    bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
                )

            # Go to the next tile by advancing number of programs.
            tile += GRID_DIM
            tl.device_assert(tile > 0, "tile <= 0 (at update)")

        # Get ready to go to the next MM problem.

        last_mm_tile += num_tiles
        # last_mm_tile can be zero if group 0 is skipped
        tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")

        last_m += m
        # last_m can be zero if group 0 is skipped
        tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
        tl.device_assert(last_m <= M, "last_m > M (at update)")
427
+
428
+
429
+ # Regular non-persistent TGMM kernel.
430
+ # ------------------------------------------------------------------------------
431
+
432
+
433
@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])})
@triton.jit
def tgmm_non_persistent_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_grad_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_LHS: tl.constexpr,
    BLOCK_SIZE_G: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    COMPUTE_BIAS_GRAD: tl.constexpr,
    ACCUMULATE: tl.constexpr,
):
    """Non-persistent transposed grouped GEMM: one program per (group, tile).

    ``program_id(0)`` selects the group ``g`` and ``program_id(1)`` selects
    the output tile inside that group's (K, N) result. ``BLOCK_SIZE_G`` is
    filled in by the heuristic above (next power of two of ``G``) and sizes
    the vector used to sum the preceding group sizes.

    Addressing used by the kernel (element offsets, ``mi`` counted over all
    groups):
      * lhs: ``k * M + mi`` or, when ``TRANS_LHS`` is set, ``mi * K + k``.
      * rhs: ``mi * N + n``.
      * out: ``g * K * N + k * N + n``; overwritten, or added to the existing
        values when ``ACCUMULATE`` is set.
      * bias_grad: ``g * K + k``; updated with ``atomic_add`` when
        ``COMPUTE_BIAS_GRAD`` is set.

    Programs of an empty group return before storing anything, so that
    group's slice of ``out`` is left untouched.
    NOTE(review): presumably the caller pre-initializes ``out`` (and
    ``bias_grad``) for that case -- confirm against the wrapper.
    """
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    # Get group ID from grid.
    g = tl.program_id(0)
    tl.device_assert(g >= 0, "g < 0")
    tl.device_assert(g < G, "g >= G")

    # Get m dimension of current MM group.
    m = tl.load(group_sizes_ptr + g)
    # m can be zero if group is empty.
    tl.device_assert(m >= 0, "m < 0")

    # Skip empty groups.
    if m == 0:
        return

    # Compute sum(group_sizes) until current group g.
    # It's the starting column of lhs and starting row of rhs.
    # The mask `offs_g < g` both selects the preceding groups and keeps the
    # load within the first g entries of group_sizes.
    offs_g = tl.arange(0, BLOCK_SIZE_G)
    group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0)
    start_m = tl.sum(group_sizes)

    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Get MM tile from grid.
    tile_in_mm = tl.program_id(1)
    tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

    tile_k, tile_n = _remap_xcd_tile_grid(
        tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
    )

    tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
    tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

    # Offsets are wrapped modulo K / N so that loads of a partial tile stay in
    # range; the wrapped entries are dropped by the store mask.
    offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
    offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)

    # Both branches produce a (BLOCK_SIZE_K, BLOCK_SIZE_M) pointer block; they
    # differ only in which index is the unit-stride one.
    if TRANS_LHS:
        lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K
    else:
        lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :])

    rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :]

    # Split the reduction over m into `loop_m` full blocks (unmasked loads)
    # plus, if m is not a multiple of BLOCK_SIZE_M, one masked tail block.
    loop_m = tl.cdiv(m, BLOCK_SIZE_M)
    m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
    if not m_divisible_by_block_m:
        loop_m -= 1

    acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
    # Initialize bias accumulator
    bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)

    for _ in range(0, loop_m):
        lhs = tl.load(lhs_ptrs)
        rhs = tl.load(rhs_ptrs)

        acc = tl.dot(lhs, rhs, acc=acc)

        # Accumulate for bias gradient: sum lhs across M dimension
        if COMPUTE_BIAS_GRAD and tile_n == 0:
            bias_acc += tl.sum(lhs, axis=1)  # [K, M] -> [K]

        # Advance both operands by one M block.
        if TRANS_LHS:
            lhs_ptrs += BLOCK_SIZE_M * K
        else:
            lhs_ptrs += BLOCK_SIZE_M

        rhs_ptrs += BLOCK_SIZE_M * N

    if not m_divisible_by_block_m:
        # NOTE(review): offs_lhs_k and offs_rhs_n are recomputed here but not
        # used afterwards (lhs_ptrs / rhs_ptrs were already advanced in the
        # loop); only offs_m feeds the masks below.
        offs_lhs_k = (
            tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        ) % K
        offs_rhs_n = (
            tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        ) % N
        # Group-relative m indices of the tail block; entries >= m are loaded
        # as zero and therefore do not contribute.
        offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
        rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
        acc = tl.dot(lhs, rhs, acc=acc)
        # Accumulate last chunk for bias gradient
        if COMPUTE_BIAS_GRAD and tile_n == 0:
            bias_acc += tl.sum(lhs, axis=1)

    acc = acc.to(out_ptr.type.element_ty)

    # Unwrapped offsets (no modulo) so that the mask below can reject
    # rows >= K and columns >= N.
    offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    out_ptrs = (
        out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
    )

    mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
    if ACCUMULATE:
        # Load existing values and add to them (like beta=1 in BLAS).
        # acc was already cast to the output dtype above, so this addition is
        # performed on the cast value.
        old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
        tl.store(out_ptrs, acc + old_vals, mask=mask)
    else:
        # Overwrite output (like beta=0 in BLAS)
        tl.store(out_ptrs, acc, mask=mask)

    # Store bias gradient (only for first N tile, sum across all M)
    if COMPUTE_BIAS_GRAD and tile_n == 0:
        # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
        bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
        # Use atomic add since multiple K-tiles may write to same expert's bias
        tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/adapter.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention.
3
+
4
+ MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point
5
+ with ``trans_a`` / ``trans_b`` flags:
6
+
7
+ * ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N)
8
+ * ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad)
9
+ * ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad)
10
+
11
+ AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition
12
+ of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N),
13
+ transposition of the 2D operand inferred from strides).
14
+ """
15
+
16
+ import torch
17
+
18
+ from .gmm import gmm as _aiter_gmm
19
+ from .gmm import ptgmm as _aiter_ptgmm
20
+
21
+
22
def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False):
    """Run a grouped GEMM through the AITER Triton kernels, writing into ``c``.

    Dispatch:
      * ``trans_a`` false: ``_aiter_gmm`` with ``a`` and ``b`` (``b`` is passed
        as a ``transpose(1, 2)`` view when ``trans_b`` is set).
      * ``trans_a`` true: ``_aiter_ptgmm`` with ``a`` transposed and ``b``;
        ``trans_b`` is not consulted on this path.

    ``batch_sizes`` is converted to int32 on ``a``'s device before the call,
    ``c`` is handed to the kernels as ``existing_out`` and its dtype as
    ``preferred_element_type``. Returns ``c``.
    """
    # AITER requires group sizes to be int32 and to live on the compute device.
    sizes_i32 = batch_sizes.to(device=a.device, dtype=torch.int32)

    # AITER asserts exact strides, so every operand is made contiguous before
    # any transposed view is taken; `.contiguous()` is a no-op for tensors
    # that are already contiguous.
    if not trans_a:
        # Forward / dgrad. With trans_b the kernel contracts b's last dim, so
        # it receives a column-major (G, K, N) view of the contiguous tensor.
        rhs_operand = b.contiguous()
        if trans_b:
            rhs_operand = rhs_operand.transpose(1, 2)
        lhs_operand = a.contiguous()
        _aiter_gmm(
            lhs_operand,
            rhs_operand,
            sizes_i32,
            preferred_element_type=c.dtype,
            existing_out=c,
        )
        return c

    # Weight gradient: a(M,K), b(M,N) -> c(G,K,N).
    # a is passed transposed so AITER sees lhs(K,M) column-major (TRANS_LHS).
    lhs_operand = a.contiguous().transpose(0, 1)
    rhs_operand = b.contiguous()
    _aiter_ptgmm(
        lhs_operand,
        rhs_operand,
        sizes_i32,
        preferred_element_type=c.dtype,
        existing_out=c,
    )
    return c
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/configs.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/).
3
+ # Inlined as a Python module so packaging always includes them.
4
+
5
def _tile_config(block_m, block_k, block_n, grid_dim=None):
    """Build one fresh launch-config dict; ``GRID_DIM`` is omitted when None."""
    config = {
        'BLOCK_SIZE_M': block_m,
        'BLOCK_SIZE_K': block_k,
        'BLOCK_SIZE_N': block_n,
        'GROUP_SIZE': 1,
    }
    if grid_dim is not None:
        config['GRID_DIM'] = grid_dim
    config['num_warps'] = 8
    config['num_stages'] = 1
    return config


def _arch_configs(grid_dim):
    """Build the per-variant configs of one architecture.

    The vendored tables are identical across architectures except for the
    ``GRID_DIM`` of the persistent kernels (``gmm`` and ``ptgmm``); the
    non-persistent ``nptgmm`` entries carry no ``GRID_DIM`` at all.
    """
    return {
        'gmm': {
            'default': _tile_config(256, 64, 256, grid_dim),
        },
        'ptgmm': {
            'default': _tile_config(64, 256, 256, grid_dim),
            'accumulate': _tile_config(64, 128, 128, grid_dim),
        },
        'nptgmm': {
            'default': _tile_config(64, 256, 256),
            'accumulate': _tile_config(64, 128, 128),
        },
    }


# Keyed by GPU architecture name, then kernel variant, then "default" /
# "accumulate". Every leaf is a distinct dict object.
CONFIGS = {
    'gfx1250': _arch_configs(256),
    'gfx942': _arch_configs(304),
    'gfx950': _arch_configs(256),
}
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/gmm.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+
5
+ # Imports.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ # PyTorch
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ # Triton
13
+ import triton
14
+
15
+ # AITER: GMM utility functions
16
+ from .utils.gmm_common import (
17
+ DTYPE,
18
+ is_power_of_2,
19
+ check_input_device_dtype,
20
+ check_bias_shape_stride,
21
+ get_gmm_shape,
22
+ get_gmm_output,
23
+ get_gmm_transposition,
24
+ get_tgmm_shape,
25
+ get_tgmm_output,
26
+ get_tgmm_bias_grad,
27
+ get_tgmm_transposition,
28
+ )
29
+
30
+ # AITER: GMM Triton kernels
31
+ from ._triton_kernels.gmm import (
32
+ gmm_kernel,
33
+ tgmm_persistent_kernel,
34
+ tgmm_non_persistent_kernel,
35
+ get_config,
36
+ )
37
+
38
+ # GMM PyTorch wrapper.
39
+ # ------------------------------------------------------------------------------
40
+
41
+
42
def _gmm_grid(
    N: int,
    block_size_m: int,
    block_size_n: int,
    group_sizes: Tensor,
    grid_dim: int,
) -> tuple[int]:
    """
    Compute the 1D launch grid for the persistent GMM kernel.

    The number of output tiles is the sum, over all groups, of
    ceil(group_size / block_size_m) * ceil(N / block_size_n). The grid is that tile count capped
    at grid_dim.

    Note: the checks on group_sizes and the tile total call .item(), i.e. they read values
    back from the tensor on the host.

    Returns
    -------
    tuple[int]
        One-element tuple (num_programs,).
    """
    # Argument sanity checks (same order as the error messages are expected to surface).
    assert N > 0, f"N must be positive, it's {N}."
    assert is_power_of_2(
        block_size_m
    ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})."
    assert is_power_of_2(
        block_size_n
    ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    assert (group_sizes >= 0).all().item(), "All group_sizes must be non-negative."
    assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."

    # Per-group ceiling division along M (tensor arithmetic, one entry per group).
    m_tiles_per_group = (group_sizes + block_size_m - 1) // block_size_m
    assert (
        m_tiles_per_group >= 0
    ).all().item(), "All num_m_tiles must be non-negative."

    # Tiles along N are shared by every group.
    n_tiles = triton.cdiv(N, block_size_n)
    assert n_tiles > 0, f"num_n_tiles must be positive, it's {n_tiles}."

    total_tiles = (m_tiles_per_group * n_tiles).sum().item()
    assert total_tiles > 0, f"num_tiles must be positive, it's {total_tiles}."

    # Never launch more programs than there are tiles to process.
    programs = int(min(grid_dim, total_tiles))
    assert programs > 0, f"num_programs must be positive, it's {programs}."
    return (programs,)
67
+
68
+
69
def gmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias: Tensor | None = None,
) -> Tensor:
    """
    Grouped matrix multiplication (GMM): out = lhs @ rhs + bias.

    The M rows of lhs are partitioned into G consecutive groups whose lengths are listed in
    group_sizes. The rows of group g are multiplied by the g-th plane of the 3D rhs tensor and
    written to the matching rows of out; when bias is given, bias[g] is added:

        out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g]

    Example: group_sizes = [3, 2, 4, 1] describes 4 groups covering rows 0-2, 3-4, 5-8 and the
    single row 9 of a 10-row lhs.

    Parameters
    ----------
    lhs : torch.Tensor
        2D left operand, shape (M, K). torch.float16 or torch.bfloat16, same dtype and device as
        rhs.
    rhs : torch.Tensor
        3D right operand, shape (G, K, N). Same dtype and device as lhs.
    group_sizes : torch.Tensor
        1D tensor of shape (G,), dtype torch.int32, all entries non-negative, on the same device
        as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Output dtype (torch.float16 or torch.bfloat16). Defaults to torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated (M, N) output of dtype preferred_element_type on the inputs' device. When
        given, results are written into it; otherwise a new tensor is allocated.
    config : dict[str, int] or None, optional
        Kernel meta-parameters. When None, the config is looked up in the internal tuning
        database. Must contain BLOCK_SIZE_M / BLOCK_SIZE_K / BLOCK_SIZE_N (powers of 2) plus
        positive GROUP_SIZE and GRID_DIM.
    bias : torch.Tensor or None, optional
        Bias of shape (G, N), same dtype and device as lhs; bias[g] is added to group g.

    Returns
    -------
    torch.Tensor
        The (M, N) result with dtype preferred_element_type. It is existing_out itself when that
        argument was supplied.

    Implementation Notes
    --------------------
    - Implemented with a persistent Triton kernel.
    - lhs must be row-major (lhs.stride() == (K, 1)).
    - rhs may be row-major (rhs.stride() == (K * N, N, 1)), giving TRANS_RHS == False (forward
      pass), or column-major (rhs.stride() == (K * N, 1, K)), giving TRANS_RHS == True (lhs
      derivative in the backward pass with the transposition fused).
    - out must be row-major (out.stride() == (N, 1)).
    - bias, if provided, must be row-major (bias.stride() == (N, 1)).
    """
    # Validate devices / dtypes, then derive the problem dimensions.
    check_input_device_dtype(lhs, rhs, group_sizes, bias)
    M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes)

    has_bias = bias is not None
    if has_bias:
        check_bias_shape_stride(bias, G, N)

    out = get_gmm_output(
        M,
        N,
        device=lhs.device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    # Only the rhs transposition flag is consumed by the kernel launch.
    trans_rhs, _ = get_gmm_transposition(lhs, rhs, out)

    if config is None:
        config = get_config("gmm", M, K, N, G)

    def _entry_is_valid(key: str) -> bool:
        # Present, an int, and either a power-of-2 tile size or a positive value.
        if key not in config:
            return False
        value = config[key]
        if not isinstance(value, int):
            return False
        if key.startswith("BLOCK_SIZE_"):
            return is_power_of_2(value)
        return value > 0

    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_K",
        "BLOCK_SIZE_N",
        "GROUP_SIZE",
        "GRID_DIM",
    )
    assert all(
        _entry_is_valid(key) for key in required_keys
    ), "Invalid GMM kernel config."

    launch_grid = _gmm_grid(
        N,
        config["BLOCK_SIZE_M"],
        config["BLOCK_SIZE_N"],
        group_sizes,
        config["GRID_DIM"],
    )

    # fmt: off
    gmm_kernel[launch_grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_RHS=trans_rhs,
        USE_BIAS=has_bias,
        **config,
    )
    # fmt: on

    return out
202
+
203
+
204
+ # Persistent TGMM PyTorch wrapper.
205
+ # ------------------------------------------------------------------------------
206
+
207
+
208
def _ptgmm_grid(
    K: int,
    N: int,
    G: int,
    block_size_k: int,
    block_size_n: int,
    grid_dim: int,
) -> tuple[int]:
    """
    Compute the 1D launch grid for the persistent TGMM kernel.

    Each of the G output planes is tiled into ceil(K / block_size_k) * ceil(N / block_size_n)
    tiles; the grid is the total tile count capped at grid_dim.

    Returns
    -------
    tuple[int]
        One-element tuple (num_programs,).
    """
    # Argument sanity checks.
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}."
    assert G > 0, f"G must be positive, it's {G}."
    assert is_power_of_2(
        block_size_k
    ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
    assert is_power_of_2(
        block_size_n
    ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."

    # Tile counts along each output dimension.
    k_tiles = triton.cdiv(K, block_size_k)
    assert k_tiles > 0, f"num_k_tiles must be positive, it's {k_tiles}."
    n_tiles = triton.cdiv(N, block_size_n)
    assert n_tiles > 0, f"num_n_tiles must be positive, it's {n_tiles}."

    total_tiles = G * k_tiles * n_tiles
    assert total_tiles > 0, f"num_tiles must be positive, it's {total_tiles}."

    # Never launch more programs than there are tiles to process.
    programs = min(grid_dim, total_tiles)
    assert programs > 0, f"num_programs must be positive, it's {programs}."
    return (programs,)
235
+
236
+
237
def ptgmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias_grad: Tensor | None = None,
    accumulate: bool = False,
) -> Tensor:
    """
    Persistent transposed grouped matrix multiplication (TGMM): out[g] = lhs_g @ rhs_g.

    The columns of lhs and the rows of rhs are partitioned into G groups by group_sizes. Each
    lhs group is multiplied by the matching rhs group and stored in one plane of the 3D output:

        out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]

    The 't' in the name comes from the MaxText implementation
    (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
    the initial inspiration for this operator. GMM computes (M, K) @ (G, K, N) = (M, N) whereas
    TGMM computes (K, M) @ (M, N) = (G, K, N).

    The 'p' means the operator uses a persistent kernel. nptgmm performs the same computation
    with a regular (non-persistent) kernel; which one to use is purely a performance choice for
    the target workload.

    Parameters
    ----------
    lhs : torch.Tensor
        2D left operand, shape (K, M). torch.float16 or torch.bfloat16, same dtype and device as
        rhs.
    rhs : torch.Tensor
        2D right operand, shape (M, N). Same dtype and device as lhs.
    group_sizes : torch.Tensor
        1D tensor of shape (G,), dtype torch.int32, all entries non-negative, on the same device
        as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Output dtype (torch.float16 or torch.bfloat16). Defaults to torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated (G, K, N) output of dtype preferred_element_type on the inputs' device. When
        given, results are written into it; otherwise a new tensor is allocated.
    config : dict[str, int] or None, optional
        Kernel meta-parameters. When None, the config is looked up in the internal tuning
        database. Must contain BLOCK_SIZE_M / BLOCK_SIZE_K / BLOCK_SIZE_N (powers of 2) plus
        positive GROUP_SIZE and GRID_DIM.
    bias_grad : torch.Tensor or None, optional
        Bias-gradient output tensor, shape (G, K). When given, the kernel also computes the bias
        gradient and writes it there. Must be torch.float32 (the kernel uses atomic_add, which
        requires float32).
    accumulate : bool, optional
        If False (default) the output is overwritten with the fresh result; if True the result is
        added to the values already in the output tensor.

    Returns
    -------
    torch.Tensor
        The (G, K, N) result with dtype preferred_element_type. It is existing_out itself when
        that argument was supplied.

    Implementation Notes
    --------------------
    - Implemented with a persistent Triton kernel.
    - lhs may be row-major (lhs.stride() == (M, 1)), giving TRANS_LHS == False, or column-major
      (lhs.stride() == (1, K)), giving TRANS_LHS == True (rhs derivative in the backward pass
      with the transposition fused).
    - rhs must be row-major (rhs.stride() == (N, 1)).
    - out must be row-major (out.stride() == (K * N, N, 1)).
    """
    # Validate devices / dtypes, then derive the problem dimensions.
    check_input_device_dtype(lhs, rhs, group_sizes)
    M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)

    out = get_tgmm_output(
        K,
        N,
        G,
        device=lhs.device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    # Only the lhs transposition flag is consumed by the kernel launch.
    trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)

    if config is None:
        config = get_config("ptgmm", M, K, N, G, accumulate)

    def _entry_is_valid(key: str) -> bool:
        # Present, an int, and either a power-of-2 tile size or a positive value.
        if key not in config:
            return False
        value = config[key]
        if not isinstance(value, int):
            return False
        if key.startswith("BLOCK_SIZE_"):
            return is_power_of_2(value)
        return value > 0

    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_K",
        "BLOCK_SIZE_N",
        "GROUP_SIZE",
        "GRID_DIM",
    )
    assert all(
        _entry_is_valid(key) for key in required_keys
    ), "Invalid PTGMM kernel config."

    # Bias gradient handling.
    # -----------------------
    # The kernel computes the bias gradient only when the caller supplied a tensor for it; the
    # helper gets or validates the tensor that is handed to the kernel.
    wants_bias_grad = bias_grad is not None
    bias_grad_arg = get_tgmm_bias_grad(
        K,
        G,
        device=lhs.device,
        existing_bias_grad=bias_grad,
    )

    launch_grid = _ptgmm_grid(
        K,
        N,
        G,
        config["BLOCK_SIZE_K"],
        config["BLOCK_SIZE_N"],
        config["GRID_DIM"],
    )

    # fmt: off
    tgmm_persistent_kernel[launch_grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias_grad_arg,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_LHS=trans_lhs,
        COMPUTE_BIAS_GRAD=wants_bias_grad,
        ACCUMULATE=accumulate,
        **config,
    )
    # fmt: on

    return out
387
+
388
+
389
+ # Regular non-persistent TGMM PyTorch wrapper.
390
+ # ------------------------------------------------------------------------------
391
+
392
+
393
def _nptgmm_grid(
    K: int,
    N: int,
    G: int,
    block_size_k: int,
    block_size_n: int,
) -> tuple[int, int]:
    """
    Compute the 2D launch grid for the non-persistent TGMM kernel.

    Returns
    -------
    tuple[int, int]
        (G, tiles_per_mm), where tiles_per_mm is
        ceil(K / block_size_k) * ceil(N / block_size_n).
    """
    # Argument sanity checks.
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}."
    assert G > 0, f"G must be positive, it's {G}."
    assert is_power_of_2(
        block_size_k
    ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
    assert is_power_of_2(
        block_size_n
    ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."

    # Tile counts along each output dimension.
    k_tiles = triton.cdiv(K, block_size_k)
    assert k_tiles > 0, f"num_k_tiles must be positive, it's {k_tiles}."
    n_tiles = triton.cdiv(N, block_size_n)
    assert n_tiles > 0, f"num_n_tiles must be positive, it's {n_tiles}."

    tiles_per_mm = k_tiles * n_tiles
    assert (
        tiles_per_mm > 0
    ), f"num_tiles_per_mm must be positive, it's {tiles_per_mm}."
    return (G, tiles_per_mm)
418
+
419
+
420
def nptgmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias_grad: Tensor | None = None,
    accumulate: bool = False,
) -> Tensor:
    """
    Non-persistent transposed grouped matrix multiplication (TGMM): out[g] = lhs_g @ rhs_g.

    The columns of lhs and the rows of rhs are partitioned into G groups by group_sizes. Each
    lhs group is multiplied by the matching rhs group and stored in one plane of the 3D output:

        out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]

    The 't' in the name comes from the MaxText implementation
    (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
    the initial inspiration for this operator. GMM computes (M, K) @ (G, K, N) = (M, N) whereas
    TGMM computes (K, M) @ (M, N) = (G, K, N).

    The 'np' means the operator uses a non-persistent, i.e. regular, kernel. ptgmm performs the
    same computation with a persistent kernel; which one to use is purely a performance choice
    for the target workload.

    Parameters
    ----------
    lhs : torch.Tensor
        2D left operand, shape (K, M). torch.float16 or torch.bfloat16, same dtype and device as
        rhs.
    rhs : torch.Tensor
        2D right operand, shape (M, N). Same dtype and device as lhs.
    group_sizes : torch.Tensor
        1D tensor of shape (G,), dtype torch.int32, all entries non-negative, on the same device
        as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Output dtype (torch.float16 or torch.bfloat16). Defaults to torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated (G, K, N) output of dtype preferred_element_type on the inputs' device. When
        given, results are written into it; otherwise a new tensor is allocated.
    config : dict[str, int] or None, optional
        Kernel meta-parameters. When None, the config is looked up in the internal tuning
        database. Must contain BLOCK_SIZE_M / BLOCK_SIZE_K / BLOCK_SIZE_N (powers of 2) plus a
        positive GROUP_SIZE.
    bias_grad : torch.Tensor or None, optional
        Bias-gradient output tensor, shape (G, K). When given, the kernel also computes the bias
        gradient and writes it there. Must be torch.float32 (the kernel uses atomic_add, which
        requires float32).
    accumulate : bool, optional
        If False (default) the output is overwritten with the fresh result; if True the result is
        added to the values already in the output tensor.

    Returns
    -------
    torch.Tensor
        The (G, K, N) result with dtype preferred_element_type. It is existing_out itself when
        that argument was supplied.

    Implementation Notes
    --------------------
    - Implemented with a non-persistent regular Triton kernel.
    - lhs may be row-major (lhs.stride() == (M, 1)), giving TRANS_LHS == False, or column-major
      (lhs.stride() == (1, K)), giving TRANS_LHS == True (rhs derivative in the backward pass
      with the transposition fused).
    - rhs must be row-major (rhs.stride() == (N, 1)).
    - out must be row-major (out.stride() == (K * N, N, 1)).
    """
    # Validate devices / dtypes, then derive the problem dimensions.
    check_input_device_dtype(lhs, rhs, group_sizes)
    M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)

    out = get_tgmm_output(
        K,
        N,
        G,
        device=lhs.device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    # Only the lhs transposition flag is consumed by the kernel launch.
    trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)

    # Bias gradient handling.
    # -----------------------
    # The kernel computes the bias gradient only when the caller supplied a tensor for it; the
    # helper gets or validates the tensor that is handed to the kernel.
    wants_bias_grad = bias_grad is not None
    bias_grad_arg = get_tgmm_bias_grad(
        K,
        G,
        device=lhs.device,
        existing_bias_grad=bias_grad,
    )

    if config is None:
        config = get_config("nptgmm", M, K, N, G, accumulate)

    def _entry_is_valid(key: str) -> bool:
        # Present, an int, and either a power-of-2 tile size or a positive value.
        if key not in config:
            return False
        value = config[key]
        if not isinstance(value, int):
            return False
        if key.startswith("BLOCK_SIZE_"):
            return is_power_of_2(value)
        return value > 0

    # No GRID_DIM here: the grid is derived from the problem shape alone.
    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_K",
        "BLOCK_SIZE_N",
        "GROUP_SIZE",
    )
    assert all(
        _entry_is_valid(key) for key in required_keys
    ), "Invalid NPTGMM kernel config."

    launch_grid = _nptgmm_grid(
        K,
        N,
        G,
        config["BLOCK_SIZE_K"],
        config["BLOCK_SIZE_N"],
    )

    # fmt: off
    tgmm_non_persistent_kernel[launch_grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias_grad_arg,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_LHS=trans_lhs,
        COMPUTE_BIAS_GRAD=wants_bias_grad,
        ACCUMULATE=accumulate,
        **config,
    )
    # fmt: on

    return out
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/__init__.py ADDED
File without changes
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py ADDED
File without changes
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+
3
# The GPU arch is resolved lazily and cached on first use. Querying the triton
# driver at import time fails in headless environments (e.g. the kernel-builder
# ABI check sandbox has no GPU), and the arch is only needed once a GMM kernel
# is actually dispatched. JAX is imported only as a fallback, inside get_arch().
_CACHED_ARCH = None


def get_arch():
    """
    Return the GPU architecture name, resolving and caching it on first call.

    The active triton driver target is queried first. If that raises RuntimeError, the JAX
    gpu_triton library is tried instead (the arch is the text before the first ':' of its
    arch details for device "0").

    Raises
    ------
    RuntimeError
        If the triton driver is inactive and the JAX fallback cannot be imported.
    """
    global _CACHED_ARCH
    if _CACHED_ARCH is None:
        try:
            detected = triton.runtime.driver.active.get_current_target().arch
        except RuntimeError:
            try:
                from jax._src.lib import gpu_triton as triton_kernel_call_lib

                arch_details = triton_kernel_call_lib.get_arch_details("0")
                detected = arch_details.split(":")[0]
            except ImportError as e:
                raise RuntimeError(
                    "Cannot determine GPU arch: triton driver is inactive and "
                    "JAX is not available. A GPU is required for grouped GEMM."
                ) from e
        _CACHED_ARCH = detected
    return _CACHED_ARCH
27
+
28
+
29
def is_gluon_avail():
    """Return True when the detected GPU arch is one of gfx950 / gfx1250."""
    arch = get_arch()
    return arch in ("gfx950", "gfx1250")
31
+
32
+
33
def is_fp4_avail():
    """Return True when the detected GPU arch is one of gfx950 / gfx1250."""
    arch = get_arch()
    return arch in ("gfx950", "gfx1250")
35
+
36
+
37
def is_fp8_avail():
    """Return True when the detected GPU arch is gfx942, gfx950, gfx1250, gfx1200 or gfx1201."""
    arch = get_arch()
    return arch in ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201")
39
+
40
+
41
def is_mx_scale_preshuffling_avail():
    """Return True when the detected GPU arch is one of gfx950 / gfx1250."""
    arch = get_arch()
    return arch in ("gfx950", "gfx1250")
43
+
44
+
45
def is_tdm_avail():
    """Return True when the detected GPU arch is gfx1250."""
    arch = get_arch()
    return arch in ("gfx1250",)
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
@triton.jit
def remap_xcd_chunked(
    pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2
):
    """
    Remap a 1D program id so that each XCD receives chunks of CHUNK_SIZE consecutive ids.

    Within every block of NUM_XCDS * CHUNK_SIZE ids, the pids sharing the same
    `pid % NUM_XCDS` are mapped to one contiguous run of CHUNK_SIZE new ids. Pids strictly above
    the largest multiple of NUM_XCDS * CHUNK_SIZE that fits in GRID_MN are returned unchanged.

    NOTE(review): this presumes programs are assigned to XCDs round-robin by pid
    (xcd = pid % NUM_XCDS) - confirm against the target hardware.

    Args:
    - pid: 1D pid
    - GRID_MN: total number of pids in the grid
    - NUM_XCDS: tl.constexpr: number of XCDs, default is 8
    - CHUNK_SIZE: tl.constexpr: consecutive new pids given to one XCD, default is 2
    """
    # Compute current XCD and local PID
    xcd = pid % NUM_XCDS
    # Pids in the tail that does not fill a whole NUM_XCDS * CHUNK_SIZE block keep their
    # original (round-robin) id.
    if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE):
        return pid
    local_pid = pid // NUM_XCDS
    # Calculate chunk index and position within chunk
    chunk_idx = local_pid // CHUNK_SIZE
    pos_in_chunk = local_pid % CHUNK_SIZE
    # Calculate new PID: block base + this XCD's chunk offset + position inside the chunk
    new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk
    return new_pid
25
+
26
+
27
@triton.jit
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8):
    """
    Remap a 1D program id so that all pids of one XCD form a contiguous range of new ids.

    Pids sharing the same `pid % NUM_XCDS` are packed together: XCD 0's pids come first, then
    XCD 1's, and so on.

    NOTE(review): this presumes programs are assigned to XCDs round-robin by pid
    (xcd = pid % NUM_XCDS) - confirm against the target hardware.

    Args:
    - pid: 1D pid
    - GRID_MN: total number of pids in the grid
    - NUM_XCDS: tl.constexpr: number of XCDs, default is 8
    """
    ## pid remapping on xcds
    # Number of pids per XCD in the new arrangement
    pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
    # When GRID_MN is not divisible by NUM_XCDS, some xcds will have
    # pids_per_xcd pids, the others will have pids_per_xcd - 1 pids.
    # We calculate the number of xcds that have pids_per_xcd pids as
    # tall_xcds
    tall_xcds = GRID_MN % NUM_XCDS
    tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
    # Compute current XCD and local pid within the XCD
    xcd = pid % NUM_XCDS
    local_pid = pid // NUM_XCDS
    # Calculate new pid based on the new grouping
    # Note that we need to consider the following two cases:
    # 1. the current pid is on a tall xcd
    # 2. the current pid is on a short xcd
    if xcd < tall_xcds:
        pid = xcd * pids_per_xcd + local_pid
    else:
        # Skip over all tall XCDs, then over the short XCDs that precede this one.
        pid = (
            tall_xcds * pids_per_xcd
            + (xcd - tall_xcds) * (pids_per_xcd - 1)
            + local_pid
        )

    return pid
55
+
56
+
57
@triton.jit
def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1):
    """
    Maps 1D pid to 2D grid coords (pid_m, pid_n).

    With GROUP_SIZE_M == 1 the mapping is row-major (pid_n varies fastest). Otherwise rows are
    processed in groups of up to GROUP_SIZE_M: all tiles of one row group are enumerated before
    moving to the next group.

    Args:
    - pid: 1D pid
    - num_pid_m: grid m size
    - num_pid_n: grid n size
    - GROUP_SIZE_M: tl.constexpr: default is 1

    Returns:
    - pid_m, pid_n: 2D grid coordinates
    """
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        # Number of tiles in one full group of GROUP_SIZE_M rows.
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        # The last group may hold fewer than GROUP_SIZE_M rows.
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        # Compiler hint only; it does not validate the value at runtime.
        tl.assume(group_size_m >= 0)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    return pid_m, pid_n
81
+
82
+
83
@triton.jit
def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k):
    """
    Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k).

    pid_m varies fastest, then pid_n, then pid_k.

    Args:
    - pid: 1D pid
    - num_pid_m: grid m size
    - num_pid_n: grid n size
    - num_pid_k: grid k size

    Returns:
    - pid_m, pid_n, pid_k: 3D grid coordinates
    """
    pid_m = pid % num_pid_m
    pid_n = (pid // num_pid_m) % num_pid_n
    # The trailing modulo wraps pid_k when pid exceeds num_pid_m * num_pid_n * num_pid_k.
    pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k

    return pid_m, pid_n, pid_k
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ # Imports.
5
+ # ------------------------------------------------------------------------------
6
+
7
+ # PyTorch
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ # AITER: logging
12
+ from .logger import AiterTritonLogger
13
+
14
# Module-level logger instance, created once when this module is imported.
_LOGGER: AiterTritonLogger = AiterTritonLogger()
15
+
16
+
17
+ # Supported data types.
18
+ # ------------------------------------------------------------------------------
19
+
20
# Supported data types, as strings.
SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"}


# Convert string data type to PyTorch data type.
def dtype_from_str(dtype_str: str) -> torch.dtype:
    """
    Convert a string data type ("fp16" / "bf16") to the matching PyTorch dtype.

    The string is stripped and lower-cased first. A single leading "i" or "o" is dropped, so
    e.g. "ibf16" and "ofp16" are accepted as well.

    Raises
    ------
    AssertionError
        If the normalized string is not a supported data type. This includes empty and
        whitespace-only strings.
    """
    dtype_str = dtype_str.strip().lower()
    # Slice instead of indexing: an empty string must reach the assertion below rather than
    # fail with IndexError on dtype_str[0].
    dtype_str = dtype_str[1:] if dtype_str[:1] in {"i", "o"} else dtype_str
    assert (
        dtype_str in SUPPORTED_DTYPES_STR
    ), "String data type isn't in set of supported string data types."
    return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]


# Supported data types, as PyTorch types.
SUPPORTED_DTYPES: set[torch.dtype] = {
    dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR
}


# Convert PyTorch data type to string data type.
def str_from_dtype(dtype: torch.dtype) -> str:
    """
    Convert a supported PyTorch dtype to its string name ("fp16" / "bf16").

    Raises
    ------
    AssertionError
        If dtype is not one of the supported PyTorch data types.
    """
    assert (
        dtype in SUPPORTED_DTYPES
    ), "PyTorch data type isn't in set of supported PyTorch data types."
    return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]


# Default data type, as string.
DTYPE_STR: str = "bf16"
assert (
    DTYPE_STR in SUPPORTED_DTYPES_STR
), "Default string data type isn't in set of supported string data types."


# Default data type, as PyTorch type.
DTYPE: torch.dtype = dtype_from_str(DTYPE_STR)
57
+
58
+
59
+ # Other defaults.
60
+ # ------------------------------------------------------------------------------
61
+
62
# Default device on which tensors are created (used as the default `device` argument of the
# group-size generators in this module).
DEVICE: torch.device | str = "cuda"

# Default RNG seed for input generation.
RNG_SEED: int = 0

# Default number of group sizes.
NUM_GROUP_SIZES: int = 1

# Default transposition (NN): neither lhs nor rhs is transposed.
TRANS_LHS: bool = False
TRANS_RHS: bool = False
74
+
75
+
76
+ # Parameter checking functions.
77
+ # ------------------------------------------------------------------------------
78
+
79
+
80
def is_power_of_2(x: int) -> bool:
    """Return True if x is a positive integer with exactly one bit set (1, 2, 4, 8, ...)."""
    if not x > 0:
        return False
    # Clearing the lowest set bit leaves zero only when a single bit was set.
    return (x & (x - 1)) == 0
82
+
83
+
84
+ def check_input_device_dtype(
85
+ lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None
86
+ ) -> None:
87
+ assert (
88
+ lhs.device == rhs.device == group_sizes.device
89
+ ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})."
90
+ assert (
91
+ lhs.dtype == rhs.dtype
92
+ ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})."
93
+ assert group_sizes.dtype == torch.int32, "group_sizes type must be int32."
94
+
95
+ if bias is not None:
96
+ assert (
97
+ bias.device == lhs.device
98
+ ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})."
99
+ assert (
100
+ bias.dtype == lhs.dtype
101
+ ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})."
102
+
103
+
104
def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None:
    """Assert that ``bias`` has shape ``(G, N)`` and row-major strides ``(N, 1)``.

    Raises:
        AssertionError: if the shape or the strides differ from the expectation.
    """
    expected_shape = (G, N)
    assert (
        bias.shape == expected_shape
    ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}."

    expected_stride = (N, 1)
    assert (
        bias.stride() == expected_stride
    ), "bias must be row-major (bias.stride() == (N, 1))."
110
+
111
+
112
+ # Generation of group sizes.
113
+ # ------------------------------------------------------------------------------
114
+
115
+
116
+ # Probabilities for generating random group sizes.
117
+ UNUSED_TOKENS_PROB: float = 0.0
118
+ UNUSED_EXPERTS_PROB: float = 0.1
119
+
120
+
121
def gen_uniform_group_sizes(
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
) -> Tensor:
    """Split ``M`` tokens over ``G`` groups as evenly as possible.

    Every group receives ``M // G`` tokens and the first ``M % G`` groups
    receive one extra token, so the sizes always add up to ``M``.

    Returns:
        An int32 tensor of ``G`` group sizes on ``device``.
    """
    assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
    assert G > 0, f"Number of experts G must be positive (it's {G})."

    tokens_per_group, leftover = divmod(M, G)
    group_sizes = torch.full(
        (G,), tokens_per_group, dtype=torch.int32, device=device
    )
    # Hand the leftover tokens, one each, to the leading groups.
    if leftover > 0:
        group_sizes[:leftover] += 1

    # Post-conditions on the generated tensor.
    assert (
        len(group_sizes) == G
    ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
    assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
    assert (
        torch.sum(group_sizes).item() == M
    ), f"Group sizes don't add up to total tokens {M}."
    assert group_sizes.dtype == torch.int32, "Group sizes must be int32."

    return group_sizes
145
+
146
+
147
def gen_group_sizes(
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
    rng_seed: int | None = RNG_SEED,
    unused_tokens_prob: float = UNUSED_TOKENS_PROB,
    unused_experts_prob: float = UNUSED_EXPERTS_PROB,
) -> Tensor:
    """Generate random group sizes for ``M`` tokens routed to ``G`` experts.

    Tokens may be dropped with probability ``unused_tokens_prob`` and experts
    may be left empty with probability ``unused_experts_prob``; at least one
    token and one expert are always kept. The remaining tokens are assigned
    uniformly at random to the remaining experts.

    Args:
        M: Total number of tokens.
        G: Number of experts (groups).
        device: Device of the sampled tensors and of the result.
        rng_seed: Seed passed to ``torch.manual_seed``; ``None`` leaves the
            RNG state untouched.
        unused_tokens_prob: Per-token drop probability, in ``[0, 1)`` when
            positive.
        unused_experts_prob: Per-expert drop probability, in ``[0, 1)`` when
            positive.

    Returns:
        An int32 tensor with ``G`` entries whose sum is the number of used
        tokens.

    Raises:
        AssertionError: on invalid arguments, including combinations for which
            no token or no expert could ever be kept.
    """
    assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
    assert G > 0, f"Number of experts G must be positive (it's {G})."
    assert (
        0 <= unused_tokens_prob <= 1
    ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})."
    assert (
        0 <= unused_experts_prob <= 1
    ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    if unused_tokens_prob > 0:
        # The rejection loop below retries until at least one token is kept.
        # That can never happen with no tokens or with a drop probability of
        # one, so fail instead of spinning forever.
        assert (
            M > 0 and unused_tokens_prob < 1
        ), f"Cannot keep at least one token (M = {M}, unused_tokens_prob = {unused_tokens_prob})."
        # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed.
        num_unused_tokens = M
        while num_unused_tokens == M:
            num_unused_tokens = int(
                torch.binomial(
                    torch.tensor(float(M), device=device),
                    torch.tensor(unused_tokens_prob, device=device),
                ).item()
            )
    else:
        num_unused_tokens = 0
    num_used_tokens = M - num_unused_tokens
    assert (
        num_unused_tokens >= 0
    ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})."
    assert (
        num_used_tokens > 0
    ), f"Number of used tokens must be positive (it's {num_used_tokens})."
    assert (
        num_used_tokens + num_unused_tokens == M
    ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})."

    if num_unused_tokens > 0:
        _LOGGER.debug(
            f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.",
        )

    if unused_experts_prob > 0:
        # torch.rand samples from [0, 1), so a drop probability of one would
        # never keep an expert and the loop below would not terminate.
        assert (
            unused_experts_prob < 1
        ), f"Cannot keep at least one expert (unused_experts_prob = {unused_experts_prob})."
        # Some experts may have zero tokens assigned to them.
        num_used_experts = 0
        while num_used_experts == 0:
            # Squeeze only the trailing dimension of the (n, 1) nonzero result:
            # a full squeeze would yield a 0-dim tensor when n == 1, which
            # cannot be indexed with the index tensor used further down.
            used_experts = torch.nonzero(
                torch.rand((G,), device=device) >= unused_experts_prob
            ).squeeze(-1)
            num_used_experts = used_experts.numel()
    else:
        used_experts = torch.arange(0, G, device=device)
        num_used_experts = G
    num_unused_experts = G - num_used_experts
    assert (
        num_unused_experts >= 0
    ), f"Number of unused experts must be non-negative (it's {num_unused_experts})."
    assert (
        num_used_experts >= 1
    ), f"At least one expert must be used (it's {num_used_experts})."
    assert (
        num_unused_experts + num_used_experts == G
    ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})."

    if num_unused_experts > 0:
        _LOGGER.debug(
            f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.",
        )

    # Pick a used expert for every used token, then count tokens per expert.
    # minlength=G keeps the trailing unused experts in the result as zeros.
    group_sizes = torch.bincount(
        used_experts[
            torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,))
        ],
        minlength=G,
    ).to(torch.int32)

    assert (
        len(group_sizes) == G
    ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
    assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
    assert (
        torch.sum(group_sizes).item() == num_used_tokens
    ), f"Group sizes don't add up to used tokens {num_used_tokens}."
    assert group_sizes.dtype == torch.int32, "Group sizes must be int32."

    return group_sizes
239
+
240
+
241
def gen_multiple_group_sizes(
    num_group_sizes: int,
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
    rng_seed: int | None = RNG_SEED,
    unused_tokens_prob: float = UNUSED_TOKENS_PROB,
    unused_experts_prob: float = UNUSED_EXPERTS_PROB,
    group_sizes_0: Tensor | None = None,
) -> list[Tensor]:
    """Build a list of ``num_group_sizes`` group-size tensors.

    When ``group_sizes_0`` is supplied it becomes the first entry and one fewer
    tensor is generated. Only the first generated tensor receives ``rng_seed``;
    later ones are generated with ``rng_seed=None``.
    """
    assert (
        num_group_sizes > 0
    ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}."

    if group_sizes_0 is None:
        multiple_group_sizes = []
        num_to_generate = num_group_sizes
    else:
        multiple_group_sizes = [group_sizes_0]
        num_to_generate = num_group_sizes - 1

    for idx in range(num_to_generate):
        seed = rng_seed if idx == 0 else None
        multiple_group_sizes.append(
            gen_group_sizes(
                M,
                G,
                device=device,
                rng_seed=seed,
                unused_tokens_prob=unused_tokens_prob,
                unused_experts_prob=unused_experts_prob,
            )
        )

    assert (
        len(multiple_group_sizes) == num_group_sizes
    ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})."
    return multiple_group_sizes
273
+
274
+
275
+ # GMM helpers: tensor generation.
276
+ # ------------------------------------------------------------------------------
277
+
278
+
279
def gen_gmm_input(
    M: int,
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    trans_rhs: bool = TRANS_RHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Generate random GMM inputs.

    Returns:
        ``(lhs, rhs, group_sizes)`` where ``lhs`` has shape ``(M, K)`` and
        ``rhs`` has shape ``(G, K, N)``. With ``trans_rhs`` the rhs is sampled
        as ``(G, N, K)`` and permuted to ``(G, K, N)``. Group sizes are uniform
        when ``unif_group_sizes`` is set, random otherwise.
    """
    assert M > 0, f"Number of lhs rows M must be positive (M = {M})."
    assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})."
    assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    # Values are sampled in float32 and then cast to the requested type.
    lhs = torch.randn((M, K), dtype=torch.float32, device=device).to(
        preferred_element_type
    )

    rhs_sample_shape = (G, N, K) if trans_rhs else (G, K, N)
    rhs = torch.randn(rhs_sample_shape, dtype=torch.float32, device=device)
    if trans_rhs:
        rhs = rhs.permute(0, 2, 1)
    rhs = rhs.to(preferred_element_type)

    # The seed, if any, was already applied above, so none is forwarded here.
    if unif_group_sizes:
        group_sizes = gen_uniform_group_sizes(M, G, device=device)
    else:
        group_sizes = gen_group_sizes(M, G, device=device, rng_seed=None)

    return lhs, rhs, group_sizes
316
+
317
+
318
def gen_gmm_output(
    M: int,
    N: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
) -> Tensor:
    """Allocate an uninitialized ``(M, N)`` GMM output tensor."""
    assert M > 0, f"Number of out rows M must be positive (M = {M})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."

    return torch.empty((M, N), dtype=preferred_element_type, device=device)
330
+
331
+
332
def gen_gmm_tensors(
    M: int,
    K: int,
    N: int,
    G: int,
    num_group_sizes: int,
    device: torch.device | str = DEVICE,
    input_type: torch.dtype = DTYPE,
    output_type: torch.dtype = DTYPE,
    trans_lhs: bool = False,
    trans_rhs: bool = TRANS_RHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
    use_bias: bool = False,
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
    """Generate every tensor needed to run a GMM.

    Args:
        M, K, N, G: GMM problem dimensions.
        num_group_sizes: Number of group-size tensors to return.
        device: Device of all generated tensors.
        input_type: Data type of ``lhs``, ``rhs`` and ``bias``.
        output_type: Data type of ``out``.
        trans_lhs: Accepted but not used by this function.
        trans_rhs: Whether ``rhs`` is generated in transposed layout.
        rng_seed: Seed for input generation; ``None`` leaves the RNG unseeded.
        unif_group_sizes: Use uniform instead of random first group sizes.
        use_bias: Whether to generate a ``(G, N)`` bias tensor.

    Returns:
        ``(lhs, rhs, multiple_group_sizes, out, bias)``; ``bias`` is ``None``
        unless ``use_bias`` is set.
    """
    lhs, rhs, group_sizes_0 = gen_gmm_input(
        M,
        K,
        N,
        G,
        device=device,
        preferred_element_type=input_type,
        trans_rhs=trans_rhs,
        rng_seed=rng_seed,
        unif_group_sizes=unif_group_sizes,
    )
    multiple_group_sizes = gen_multiple_group_sizes(
        num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
    )
    out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type)
    bias = None
    if use_bias:
        # Reseed only when a seed was requested: rng_seed may be None, in which
        # case the bias simply continues the current RNG stream.
        if rng_seed is not None:
            torch.manual_seed(rng_seed + 1000)  # Different seed for bias
        bias = torch.randn(G, N, dtype=input_type, device=device)

    return lhs, rhs, multiple_group_sizes, out, bias
368
+
369
+
370
+ # GMM helpers: get information from tensors.
371
+ # ------------------------------------------------------------------------------
372
+
373
+
374
def get_gmm_shape(
    lhs: Tensor, rhs: Tensor, group_sizes: Tensor
) -> tuple[int, int, int, int]:
    """Extract and validate the ``(M, K, N, G)`` dimensions of a GMM.

    ``lhs`` is read as ``(M, K)``, ``rhs`` as ``(G, K, N)`` and ``group_sizes``
    as ``(G,)``.

    Raises:
        AssertionError: on wrong ranks, mismatched K or G, or a dimension that
            is not positive.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
    assert (
        group_sizes.dim() == 1
    ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."

    M, K = lhs.shape
    G, rhs_k, N = rhs.shape
    (group_sizes_g,) = group_sizes.shape

    # K comes from lhs and G from rhs; the other operand must agree.
    assert (
        K == rhs_k
    ), f"K dimension of lhs and rhs don't match (lhs = {K}, rhs = {rhs_k})."
    assert (
        G == group_sizes_g
    ), f"G dimension of rhs and group_sizes don't match (rhs = {G}, group_sizes = {group_sizes_g})."

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    return M, K, N, G
402
+
403
+
404
def get_gmm_output(
    M: int,
    N: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
) -> Tensor:
    """Return a GMM output tensor, reusing ``existing_out`` when provided.

    Args:
        M: Number of output rows.
        N: Number of output columns.
        device: Expected device of the output. May be a string; a device
            without an index (e.g. ``"cuda"``) matches any index of that type.
        preferred_element_type: Expected data type of the output.
        existing_out: Optional pre-allocated output to validate and return.

    Returns:
        ``existing_out`` after validation, or a newly allocated ``(M, N)``
        tensor.

    Raises:
        AssertionError: if the dimensions are not positive or ``existing_out``
            doesn't match the requested device, dtype or shape.
    """
    assert M > 0, f"Number of out rows M must be positive (M = {M})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."

    if existing_out is not None:
        # Normalize before comparing: a torch.device never compares equal to a
        # plain string, so `tensor.device == "cuda"` would always be False.
        expected_device = torch.device(device)
        actual_device = existing_out.device
        device_matches = actual_device.type == expected_device.type and (
            expected_device.index is None
            or actual_device.index == expected_device.index
        )
        assert (
            device_matches
        ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
        assert (
            existing_out.dtype == preferred_element_type
        ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
        assert existing_out.shape == (
            M,
            N,
        ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})."
        return existing_out

    return gen_gmm_output(
        M,
        N,
        device=device,
        preferred_element_type=preferred_element_type,
    )
433
+
434
+
435
def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
    """Detect the rhs layout of a GMM from the tensors' strides.

    ``lhs`` is ``(M, K)``, ``rhs`` is ``(G, K, N)`` and ``out`` is ``(M, N)``.
    ``lhs`` and ``out`` must be row-major; ``rhs`` may be row-major (strides
    ``(K * N, N, 1)``) or column-major (strides ``(K * N, 1, K)``).

    Returns:
        ``(is_rhs_col_major, ld_rhs)`` where ``ld_rhs`` is ``N`` for a
        row-major rhs and ``K`` for a column-major one. When both stride
        patterns coincide (``K == N == 1``) rhs is reported as row-major.

    Raises:
        AssertionError: on wrong ranks, mismatched dimensions, dimensions that
            are not positive, or unsupported layouts.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
    assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})."

    lhs_m, lhs_k = lhs.shape
    G, rhs_k, rhs_n = rhs.shape
    out_m, out_n = out.shape

    assert (
        lhs_m == out_m
    ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, out = {out_m})."
    M = lhs_m
    assert (
        lhs_k == rhs_k
    ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
    K = lhs_k
    assert (
        rhs_n == out_n
    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
    N = rhs_n

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    is_lhs_row_major = lhs.stride() == (K, 1)
    assert is_lhs_row_major, "lhs must be row-major."
    is_rhs_row_major = rhs.stride() == (K * N, N, 1)
    is_rhs_col_major = rhs.stride() == (K * N, 1, K)
    # Accept either layout. An exclusive check would reject K == N == 1, where
    # both stride patterns are identical.
    assert (
        is_rhs_row_major or is_rhs_col_major
    ), "rhs must be row-major or column-major."
    is_out_row_major = out.stride() == (N, 1)
    assert is_out_row_major, "out must be row-major."

    # Row-major takes precedence when the layout is ambiguous.
    trans_rhs = not is_rhs_row_major

    # Get rhs leading dimension according to transposition configuration.
    ld_rhs = K if trans_rhs else N

    return trans_rhs, ld_rhs
476
+
477
+
478
+ # TGMM helpers: tensor generation.
479
+ # ------------------------------------------------------------------------------
480
+
481
+
482
def gen_tgmm_input(
    M: int,
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    trans_lhs: bool = TRANS_LHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Generate random TGMM inputs.

    Returns:
        ``(lhs, rhs, group_sizes)`` where ``lhs`` has shape ``(K, M)`` and
        ``rhs`` has shape ``(M, N)``. With ``trans_lhs`` the lhs is sampled as
        ``(M, K)`` and transposed. Group sizes are uniform when
        ``unif_group_sizes`` is set, random otherwise.

    Raises:
        AssertionError: if any dimension is not positive.
    """
    # Each message reports the dimension under its own name.
    assert K > 0, f"Number of lhs rows K must be positive (K = {K})."
    assert M > 0, f"Number of lhs columns / rhs rows M must be positive (M = {M})."
    assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    # Values are sampled in float32 and then cast to the requested type.
    if trans_lhs:
        lhs = torch.randn((M, K), dtype=torch.float32, device=device).T
    else:
        lhs = torch.randn((K, M), dtype=torch.float32, device=device)
    lhs = lhs.to(preferred_element_type)

    rhs = torch.randn((M, N), dtype=torch.float32, device=device)
    rhs = rhs.to(preferred_element_type)

    # The seed, if any, was already applied above, so none is forwarded here.
    group_sizes = (
        gen_uniform_group_sizes(M, G, device=device)
        if unif_group_sizes
        else gen_group_sizes(M, G, device=device, rng_seed=None)
    )

    return lhs, rhs, group_sizes
517
+
518
+
519
def gen_tgmm_output(
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
) -> Tensor:
    """Allocate an uninitialized ``(G, K, N)`` TGMM output tensor."""
    assert K > 0, f"Number of out rows K must be positive (K = {K})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    return torch.empty((G, K, N), dtype=preferred_element_type, device=device)
533
+
534
+
535
def gen_tgmm_bias_grad(
    K: int,
    G: int,
    device: torch.device | str = DEVICE,
    with_bias_grad: bool = False,
) -> Tensor:
    """Allocate the TGMM bias gradient, or an empty placeholder.

    Returns an uninitialized float32 ``(G, K)`` tensor when ``with_bias_grad``
    is set, and an empty float32 tensor otherwise.
    """
    if not with_bias_grad:
        # Return dummy pointer when bias_grad is not needed.
        # Must be float32 because atomic_add does not support bf16/fp16,
        # and Triton validates the pointer dtype even in dead branches.
        return torch.tensor([], device=device, dtype=torch.float32)

    assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."
    return torch.empty((G, K), device=device, dtype=torch.float32)
550
+
551
+
552
def gen_tgmm_tensors(
    M: int,
    K: int,
    N: int,
    G: int,
    num_group_sizes: int,
    device: torch.device | str = DEVICE,
    input_type: torch.dtype = DTYPE,
    output_type: torch.dtype = DTYPE,
    trans_lhs: bool = TRANS_LHS,
    trans_rhs: bool = False,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
    use_bias: bool = False,
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
    """Generate every tensor needed to run a TGMM.

    Returns ``(lhs, rhs, multiple_group_sizes, out, bias_grad)``; ``bias_grad``
    is ``None`` unless ``use_bias`` is set. ``trans_rhs`` is accepted but not
    used by this function.
    """
    lhs, rhs, first_group_sizes = gen_tgmm_input(
        M,
        K,
        N,
        G,
        device=device,
        preferred_element_type=input_type,
        trans_lhs=trans_lhs,
        rng_seed=rng_seed,
        unif_group_sizes=unif_group_sizes,
    )
    # The group sizes produced with the inputs become the first list entry.
    multiple_group_sizes = gen_multiple_group_sizes(
        num_group_sizes,
        M,
        G,
        device=device,
        rng_seed=None,
        group_sizes_0=first_group_sizes,
    )
    out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type)
    bias_grad = (
        gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True)
        if use_bias
        else None
    )
    return lhs, rhs, multiple_group_sizes, out, bias_grad
587
+
588
+
589
+ # TGMM helpers: get information from tensors.
590
+ # ------------------------------------------------------------------------------
591
+
592
+
593
def get_tgmm_shape(
    lhs: Tensor, rhs: Tensor, group_sizes: Tensor
) -> tuple[int, int, int, int]:
    """Extract and validate the ``(M, K, N, G)`` dimensions of a TGMM.

    ``lhs`` is read as ``(K, M)``, ``rhs`` as ``(M, N)`` and ``group_sizes``
    as ``(G,)``.

    Raises:
        AssertionError: on wrong ranks, mismatched M, or a dimension that is
            not positive.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
    assert (
        group_sizes.dim() == 1
    ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."

    K, M = lhs.shape
    rhs_m, N = rhs.shape
    (G,) = group_sizes.shape

    # M comes from lhs; rhs must agree on it.
    assert (
        M == rhs_m
    ), f"M dimension of lhs and rhs don't match (lhs = {M}, rhs = {rhs_m})."

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    return M, K, N, G
617
+
618
+
619
def get_tgmm_output(
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
) -> Tensor:
    """Return a TGMM output tensor, reusing ``existing_out`` when provided.

    Args:
        K: Number of output rows per group.
        N: Number of output columns per group.
        G: Number of groups.
        device: Expected device of the output. May be a string; a device
            without an index (e.g. ``"cuda"``) matches any index of that type.
        preferred_element_type: Expected data type of the output.
        existing_out: Optional pre-allocated output to validate and return.

    Returns:
        ``existing_out`` after validation, or a newly allocated ``(G, K, N)``
        tensor.

    Raises:
        AssertionError: if the dimensions are not positive or ``existing_out``
            doesn't match the requested device, dtype or shape.
    """
    assert K > 0, f"Number of out rows K must be positive (K = {K})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if existing_out is not None:
        # Normalize before comparing: a torch.device never compares equal to a
        # plain string, so `tensor.device == "cuda"` would always be False.
        expected_device = torch.device(device)
        actual_device = existing_out.device
        device_matches = actual_device.type == expected_device.type and (
            expected_device.index is None
            or actual_device.index == expected_device.index
        )
        assert (
            device_matches
        ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
        assert (
            existing_out.dtype == preferred_element_type
        ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
        assert existing_out.shape == (
            G,
            K,
            N,
        ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})."
        return existing_out

    return gen_tgmm_output(
        K,
        N,
        G,
        device=device,
        preferred_element_type=preferred_element_type,
    )
652
+
653
+
654
def get_tgmm_bias_grad(
    K: int,
    G: int,
    device: torch.device | str = DEVICE,
    existing_bias_grad: Tensor | None = None,
) -> Tensor:
    """
    Get or validate bias gradient tensor for TGMM.

    If existing_bias_grad is provided, validates its shape, device, dtype, and stride,
    and always zeros it before returning (since the kernel uses atomic_add).
    If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False).

    Parameters
    ----------
    K : int
        Number of columns of the (G, K) bias gradient tensor.
    G : int
        Number of groups, i.e. rows of the bias gradient tensor.
    device : torch.device or str
        Expected device of the tensor. A device without an index
        (e.g. "cuda") matches any index of that device type.
    existing_bias_grad : torch.Tensor or None
        Existing bias gradient tensor to validate and use.

    Returns
    -------
    torch.Tensor
        Valid (zeroed) bias gradient tensor or dummy tensor.

    Raises
    ------
    AssertionError
        If K or G is not positive, or existing_bias_grad fails validation.
    """
    assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if existing_bias_grad is None:
        return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)

    # Validate existing bias_grad tensor.
    expected_shape = (G, K)
    assert (
        tuple(existing_bias_grad.shape) == expected_shape
    ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}."

    # Normalize before comparing: a torch.device never compares equal to a
    # plain string, so `tensor.device == "cuda"` would always be False.
    expected_device = torch.device(device)
    actual_device = existing_bias_grad.device
    device_matches = actual_device.type == expected_device.type and (
        expected_device.index is None
        or actual_device.index == expected_device.index
    )
    assert (
        device_matches
    ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})."
    assert (
        existing_bias_grad.dtype == torch.float32
    ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}."
    assert existing_bias_grad.stride() == (
        K,
        1,
    ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."

    # Always zero the tensor since bias_grad represents gradients for the current
    # computation and should start fresh. The kernel uses atomic_add which adds to
    # existing values, so we must zero before the kernel runs.
    existing_bias_grad.zero_()

    return existing_bias_grad
710
+
711
+
712
def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
    """Detect the lhs layout of a TGMM from the tensors' strides.

    ``lhs`` is ``(K, M)``, ``rhs`` is ``(M, N)`` and ``out`` is ``(G, K, N)``.
    ``rhs`` and ``out`` must be row-major; ``lhs`` may be row-major (strides
    ``(M, 1)``) or column-major (strides ``(1, K)``).

    Returns:
        ``(is_lhs_col_major, ld_lhs)`` where ``ld_lhs`` is ``M`` for a
        row-major lhs and ``K`` for a column-major one. When both stride
        patterns coincide (``M == K == 1``) lhs is reported as row-major.

    Raises:
        AssertionError: on wrong ranks, mismatched dimensions, dimensions that
            are not positive, or unsupported layouts.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
    assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."

    lhs_k, lhs_m = lhs.shape
    rhs_m, rhs_n = rhs.shape
    G, out_k, out_n = out.shape

    assert (
        lhs_m == rhs_m
    ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
    M = lhs_m
    assert (
        lhs_k == out_k
    ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, out = {out_k})."
    K = lhs_k
    assert (
        rhs_n == out_n
    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
    N = rhs_n

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    is_lhs_row_major = lhs.stride() == (M, 1)
    is_lhs_col_major = lhs.stride() == (1, K)
    # Accept either layout. An exclusive check would reject M == K == 1, where
    # both stride patterns are identical.
    assert (
        is_lhs_row_major or is_lhs_col_major
    ), "lhs must be row-major or column-major."
    is_rhs_row_major = rhs.stride() == (N, 1)
    assert is_rhs_row_major, "rhs must be row-major."
    is_out_row_major = out.stride() == (K * N, N, 1)
    assert is_out_row_major, "out must be row-major."

    # Row-major takes precedence when the layout is ambiguous.
    trans_lhs = not is_lhs_row_major

    # Get lhs leading dimension according to transposition configuration.
    ld_lhs = K if trans_lhs else M

    return trans_lhs, ld_lhs
build/torch211-cxx11-cu126-x86_64-linux/_grouped_gemm_triton/utils/logger.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+
5
# AITER Triton Logger which is a singleton object around python logging.
# Note: Python logging is also a singleton object, but we want to read the
# env var AITER_TRITON_LOG_LEVEL once at the beginning. Another alternative is
# to do this in __init__.py. In fact, that's how the CK logger is set up. We
# can look at switching to that at some point.
#
# AITER_TRITON_LOG_LEVEL follows python logging levels
# DEBUG
# INFO
# WARNING
# ERROR
# CRITICAL
#
class AiterTritonLogger(object):
    """Singleton wrapper around the ``"AITER_TRITON"`` python logger.

    The level is read from the ``AITER_TRITON_LOG_LEVEL`` environment variable
    the first time the class is instantiated; unknown or missing values fall
    back to ``WARNING``.
    """

    # The one shared instance, created on first instantiation.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super(AiterTritonLogger, cls).__new__(cls)
            log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper()
            numeric_level = getattr(logging, log_level_str, logging.WARNING)
            # getattr can resolve to a non-level attribute of the logging
            # module (e.g. BASIC_FORMAT is a string); only accept integers.
            if not isinstance(numeric_level, int):
                numeric_level = logging.WARNING
            instance._logger = logging.getLogger("AITER_TRITON")
            instance._logger.setLevel(numeric_level)
            # Publish the singleton only once it is fully initialized, so a
            # failure above can't leave a half-built instance behind.
            cls._instance = instance

        return cls._instance

    def get_logger(self):
        """Return the underlying ``logging.Logger``."""
        return self._logger

    def debug(self, msg):
        """Log ``msg`` at DEBUG level."""
        self._logger.debug(msg)

    def info(self, msg):
        """Log ``msg`` at INFO level."""
        self._logger.info(msg)

    def warning(self, msg):
        """Log ``msg`` at WARNING level."""
        self._logger.warning(msg)

    def error(self, msg):
        """Log ``msg`` at ERROR level."""
        self._logger.error(msg)

    def critical(self, msg):
        """Log ``msg`` at CRITICAL level."""
        self._logger.critical(msg)
build/torch211-cxx11-cu126-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:04357ebe4748e32fc898f2b6b3c4310beda29692b0ac34b78bd1c031efdee1bb
3
- size 13179696
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78a283cd033d5770287d652455033307d26b1896681abbeb5ed4d1cba4dbc1fe
3
+ size 13822768
build/torch211-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_ae601bb
3
- ops = torch.ops._megablocks_cuda_ae601bb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_ae601bb::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_f8f8b50
3
+ ops = torch.ops._megablocks_cuda_f8f8b50
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_f8f8b50::{op_name}"
build/torch211-cxx11-cu126-x86_64-linux/grouped_gemm/backend.py CHANGED
@@ -2,16 +2,16 @@
2
  # extensions. Otherwise libc10.so cannot be found.
3
  import torch
4
 
5
- # # TODO(tgale): Wrap this in a try-block with better
6
- # # error message and instructions for building the
7
- # # c++ operations.
8
- # import grouped_gemm_backend as backend
9
 
10
- # We import the backend operations from the megablocks package as
11
- # grouped_gemm is vendored in megablocks in this repository.
12
- # from ... import _ops as backend
13
- # from megablocks._ops import ops as backend # type: ignore
14
- from .._ops import ops as backend # type: ignore
 
15
 
16
  def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
  assert not (trans_a and trans_b)
 
2
  # extensions. Otherwise libc10.so cannot be found.
3
  import torch
4
 
5
+ # On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER
6
+ # Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op.
7
+ _IS_ROCM = torch.version.hip is not None
 
8
 
9
+ if _IS_ROCM:
10
+ from .._grouped_gemm_triton import adapter as backend
11
+ else:
12
+ # We import the backend operations from the megablocks package as
13
+ # grouped_gemm is vendored in megablocks in this repository.
14
+ from .._ops import ops as backend # type: ignore
15
 
16
  def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
  assert not (trans_a and trans_b)
build/torch211-cxx11-cu126-x86_64-linux/metadata.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "name": "megablocks",
3
- "id": "_megablocks_cuda_ae601bb",
4
  "version": 1,
5
  "license": "Apache-2.0",
6
  "python-depends": [],
@@ -14,7 +14,8 @@
14
  "8.6",
15
  "8.7",
16
  "8.9",
17
- "9.0"
 
18
  ]
19
  }
20
  }
 
1
  {
2
  "name": "megablocks",
3
+ "id": "_megablocks_cuda_f8f8b50",
4
  "version": 1,
5
  "license": "Apache-2.0",
6
  "python-depends": [],
 
14
  "8.6",
15
  "8.7",
16
  "8.9",
17
+ "9.0",
18
+ "9.0+PTX"
19
  ]
20
  }
21
  }
build/torch211-cxx11-cu128-x86_64-linux/__init__.py CHANGED
@@ -3,7 +3,9 @@
3
 
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
  from .grouped_gemm import backend as gg_backend
9
  from .grouped_gemm import ops as gg_ops
@@ -136,7 +138,8 @@ def sort(
136
  Returns:
137
  The sorted values tensor
138
  """
139
- return ops.sort(x, end_bit, x_out, iota_out)
 
140
 
141
 
142
  # Convenience functions for common use cases
 
3
 
4
  import torch
5
 
6
+ # Stable alias: bare `ops` is shadowed by `from . import layers` below.
7
+ from ._ops import ops as _compiled_ops
8
+ from . import ops
9
 
10
  from .grouped_gemm import backend as gg_backend
11
  from .grouped_gemm import ops as gg_ops
 
138
  Returns:
139
  The sorted values tensor
140
  """
141
+ _compiled_ops.sort(x, end_bit, x_out, iota_out)
142
+ return x_out
143
 
144
 
145
  # Convenience functions for common use cases
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/__init__.py ADDED
File without changes
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py ADDED
File without changes
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+
5
+ # Imports.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ # Python standard library
9
+ import functools
10
+
11
+ # Triton
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ # AITER
16
+ from ..configs import CONFIGS as _CONFIGS
17
+ from ..utils._triton import arch_info
18
+ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
19
+
20
+ # Kernel config.
21
+ # ------------------------------------------------------------------------------
22
+
23
+
24
@functools.lru_cache()
def get_config(
    gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
) -> dict[str, int]:
    """Look up the kernel configuration for a GMM variant on the current arch.

    ``gmm_type`` must be one of ``"gmm"``, ``"ptgmm"`` or ``"nptgmm"``. The
    ``"accumulate"`` entry is returned when ``accumulate`` is set, the
    ``"default"`` entry otherwise. ``M``, ``K``, ``N`` and ``G`` are not used
    for the lookup itself.
    """
    valid_variants = {"gmm", "ptgmm", "nptgmm"}
    assert gmm_type in valid_variants, f"'{gmm_type}' is an invalid GMM variant."

    dev = arch_info.get_arch()
    assert (
        dev in _CONFIGS
    ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}."

    variant_configs = _CONFIGS[dev][gmm_type]
    assert "default" in variant_configs, "Default configuration is absent."

    if accumulate:
        return variant_configs["accumulate"]
    return variant_configs["default"]
43
+
44
+
45
+ # Common code shared by GMM and TGMM kernels.
46
+ # ------------------------------------------------------------------------------
47
+
48
+
49
+ # XCD remapping followed by 1D PID to 2D grid mapping.
50
@triton.jit
def _remap_xcd_tile_grid(
    tile_in_mm,
    num_row_tiles,
    num_col_tiles,
    GROUP_SIZE: tl.constexpr = 1,
    NUM_XCDS: tl.constexpr = 8,
):
    # Step 1: remap the 1D tile index across XCDs over the whole tile count.
    total_tiles = num_row_tiles * num_col_tiles
    remapped_tile = remap_xcd(tile_in_mm, total_tiles, NUM_XCDS=NUM_XCDS)
    # Step 2: map the remapped 1D index onto the 2D tile grid.
    return pid_grid(
        remapped_tile,
        num_row_tiles,
        num_col_tiles,
        GROUP_SIZE_M=GROUP_SIZE,
    )
64
+
65
+
66
+ # GMM kernel.
67
+ # ------------------------------------------------------------------------------
68
+
69
+
70
# Persistent grouped matmul kernel.
#
# For every group g with m = group_sizes[g] rows, computes
#     out[rows of g, :] = lhs[rows of g, :] @ rhs[g] (+ bias[g] when USE_BIAS)
# where the rows of g are the m rows that follow the rows of groups 0..g-1.
#
# Addressing used below:
#   * lhs is read as row-major (M, K): element (r, k) at r * K + k.
#   * rhs plane g starts at g * K * N; element (k, n) is at k * N + n, or at
#     k + n * K when TRANS_RHS is set.
#   * out is written as row-major (M, N); bias row g starts at g * N.
#
# Tiles of all groups are numbered consecutively. A program starts at tile
# program_id(0) and advances by GRID_DIM until it runs past the last tile.
#
# K_DIVISIBLE_BY_BLOCK_SIZE_K is derived by the heuristic below and lets the
# K loop use unmasked loads when K is a multiple of BLOCK_SIZE_K.
@triton.heuristics(
    {
        "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"]
        == 0,
    }
)
@triton.jit
def gmm_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_RHS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    GRID_DIM: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    # Number of tiles along N; identical for every group.
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Current tile. Each program computes multiple tiles of each group.
    tile = tl.program_id(0)
    tl.device_assert(tile >= 0, "tile < 0 (at initialization)")

    # Tile limit of last MM problem (inclusive).
    last_mm_tile = 0

    # Last input row of lhs and output row of out. Each group reads some rows of
    # lhs and writes some rows to out.
    last_m = 0

    # Loop through all (m, K, N) MM problems:
    #     (m, K) x (K, N) = (m, N)
    #     sum(m) = M
    for g in range(G):
        # Get m dimension of current MM problem.
        m = tl.load(group_sizes_ptr + g)
        # m can be zero if group is empty
        tl.device_assert(m >= 0, "m < 0")

        num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        # num_m_tiles can be zero if group is empty
        tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0")

        num_tiles = num_m_tiles * num_n_tiles
        # num_tiles can be zero if group is empty
        tl.device_assert(num_tiles >= 0, "num_tiles < 0")

        # Loop through tiles of current MM problem.
        # The loop body only runs when num_tiles > 0, i.e. when m > 0, so the
        # "% m" below never divides by zero.
        while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
            # Figure out tile coordinates in current MM problem.
            tile_in_mm = tile - last_mm_tile
            tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

            tile_m, tile_n = _remap_xcd_tile_grid(
                tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
            )

            # Do regular MM:

            tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0")
            tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

            # Row/column offsets are wrapped with a modulo so that the loads
            # below need no mask along M and N; rows/columns that wrapped
            # around are discarded later by the store mask.
            offs_lhs_m = (
                tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            ) % m
            offs_rhs_n = (
                tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            ) % N
            offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)

            # last_m shifts the row offsets to the first row of group g.
            lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :]

            # Both branches build a (BLOCK_SIZE_K, BLOCK_SIZE_N) pointer block
            # into plane g of rhs; they differ only in which index is
            # contiguous.
            if TRANS_RHS:
                rhs_ptrs = (
                    rhs_ptr
                    + g.to(tl.int64) * K * N
                    + offs_k[:, None]
                    + offs_rhs_n[None, :] * K
                )
            else:
                rhs_ptrs = (
                    rhs_ptr
                    + g.to(tl.int64) * K * N
                    + offs_k[:, None] * N
                    + offs_rhs_n[None, :]
                )

            # Accumulate in float32 regardless of the input/output dtype.
            acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

            for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
                if K_DIVISIBLE_BY_BLOCK_SIZE_K:
                    lhs = tl.load(lhs_ptrs)
                    rhs = tl.load(rhs_ptrs)
                else:
                    # Number of valid K elements left from this step onwards;
                    # elements past it are loaded as zero.
                    k_mask_limit = K - k * BLOCK_SIZE_K
                    lhs = tl.load(
                        lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0
                    )
                    rhs = tl.load(
                        rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0
                    )

                acc = tl.dot(lhs, rhs, acc=acc)

                # Advance both pointer blocks by one K step.
                lhs_ptrs += BLOCK_SIZE_K

                if TRANS_RHS:
                    rhs_ptrs += BLOCK_SIZE_K
                else:
                    rhs_ptrs += BLOCK_SIZE_K * N

            # Add bias if enabled
            if USE_BIAS:
                # Unlike offs_rhs_n these offsets are not wrapped, so the load
                # is masked against N instead.
                offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
                    0, BLOCK_SIZE_N
                )
                bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
                bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
                # Convert bias to float32 to match accumulator precision
                bias = bias.to(tl.float32)
                # Broadcast bias across M dimension and add in float32
                acc += bias[None, :]

            # Convert to output dtype after all computations
            acc = acc.to(out_ptr.type.element_ty)

            # Unwrapped output offsets: the mask on the store keeps only rows
            # inside the group and columns inside N.
            offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            out_ptrs = (
                out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :]
            )

            tl.store(
                out_ptrs,
                acc,
                mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N),
            )

            # Go to the next tile by advancing number of programs.
            tile += GRID_DIM
            tl.device_assert(tile > 0, "tile <= 0 (at update)")

        # Get ready to go to the next MM problem.

        last_mm_tile += num_tiles
        # last_mm_tile can be zero if group 0 is skipped
        tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")

        last_m += m
        # last_m can be zero if group 0 is skipped
        tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
        tl.device_assert(last_m <= M, "last_m > M (at update)")
241
+
242
+
243
+ # Persistent TGMM kernel.
244
+ # ------------------------------------------------------------------------------
245
+
246
+
247
# Persistent TGMM kernel.
#
# For every group g with m = group_sizes[g], computes the (K, N) plane
#     out[g] = lhs[:, columns of g] @ rhs[rows of g, :]
# where the columns/rows of g are the m entries that follow those of groups
# 0..g-1. With ACCUMULATE the result is added to the values already in out.
# With COMPUTE_BIAS_GRAD the programs whose tile_n is 0 also add the sum of
# lhs over the group's m entries into bias_grad[g, :].
#
# Addressing used below:
#   * lhs element (k, r) is at k + r * K when TRANS_LHS is set, else k * M + r.
#   * rhs is read as row-major (M, N): element (r, n) at r * N + n.
#   * out plane g starts at g * K * N and is row-major (K, N).
#
# Every group owns the same number of tiles (num_k_tiles * num_n_tiles), even
# when it is empty. A program starts at tile program_id(0) and advances by
# GRID_DIM.
@triton.jit
def tgmm_persistent_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_grad_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_LHS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    GRID_DIM: tl.constexpr,
    COMPUTE_BIAS_GRAD: tl.constexpr,
    ACCUMULATE: tl.constexpr,
):
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Tiles per group; does not depend on the group size.
    num_tiles = num_k_tiles * num_n_tiles
    tl.device_assert(num_tiles > 0, "num_tiles <= 0")

    # Current tile. Each program computes multiple tiles of each group.
    tile = tl.program_id(0)
    tl.device_assert(tile >= 0, "tile < 0 (at initialization)")

    # Tile limit of last MM problem (inclusive).
    last_mm_tile = 0

    # Last input column of lhs and input row of rhs. Each group reads some
    # columns of lhs and some rows of rhs.
    last_m = 0

    # Loop through all (K, m, N) MM problems:
    #     (K, m) x (m, N) = (K, N)
    #     sum(m) = M
    for g in range(G):
        # Get m dimension of current MM problem.
        m = tl.load(group_sizes_ptr + g)
        # m can be zero if group is empty
        tl.device_assert(m >= 0, "m < 0")

        # Loop through tiles of current MM problem.
        while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
            # Figure out tile coordinates in current MM problem.
            tile_in_mm = tile - last_mm_tile
            tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

            tile_k, tile_n = _remap_xcd_tile_grid(
                tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
            )

            # Do regular MM:

            tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
            tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

            # K and N offsets are wrapped with a modulo so the loads need no
            # mask along K and N; the store mask discards wrapped entries.
            offs_lhs_k = (
                tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            ) % K
            offs_rhs_n = (
                tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            ) % N
            offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)

            # (BLOCK_SIZE_K, BLOCK_SIZE_M) pointer block; last_m shifts to the
            # first lhs column of group g.
            if TRANS_LHS:
                lhs_ptrs = (
                    lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K
                )
            else:
                lhs_ptrs = (
                    lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :])
                )

            # (BLOCK_SIZE_M, BLOCK_SIZE_N) pointer block starting at the first
            # rhs row of group g.
            rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :]

            # loop_m counts the full BLOCK_SIZE_M chunks handled by the
            # unmasked loop; a partial last chunk is handled separately below.
            loop_m = tl.cdiv(m, BLOCK_SIZE_M)
            m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
            if not m_divisible_by_block_m:
                loop_m -= 1

            acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)

            # Initialize bias accumulator
            bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)

            for _ in range(0, loop_m):
                lhs = tl.load(lhs_ptrs)
                rhs = tl.load(rhs_ptrs)

                acc = tl.dot(lhs, rhs, acc=acc)

                # Accumulate for bias gradient: sum lhs across M dimension
                if COMPUTE_BIAS_GRAD and tile_n == 0:
                    bias_acc += tl.sum(
                        lhs, axis=1
                    )  # Sum across M dimension [K, M] -> [K]

                # Advance both pointer blocks by one M chunk.
                if TRANS_LHS:
                    lhs_ptrs += BLOCK_SIZE_M * K
                else:
                    lhs_ptrs += BLOCK_SIZE_M

                rhs_ptrs += BLOCK_SIZE_M * N

            # Partial last chunk: same computation with loads masked against m.
            if not m_divisible_by_block_m:
                # NOTE(review): offs_lhs_k and offs_rhs_n are recomputed here
                # but not read again before being reassigned - confirm whether
                # this is intentional.
                offs_lhs_k = (
                    tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                ) % K
                offs_rhs_n = (
                    tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                ) % N
                # Offsets of this chunk relative to the start of the group.
                offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
                rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
                acc = tl.dot(lhs, rhs, acc=acc)

                # Accumulate last chunk for bias gradient
                if COMPUTE_BIAS_GRAD and tile_n == 0:
                    bias_acc += tl.sum(lhs, axis=1)

            acc = acc.to(out_ptr.type.element_ty)

            # Unwrapped output offsets, masked against K and N on store.
            offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            out_ptrs = (
                out_ptr
                + g.to(tl.int64) * K * N
                + offs_out_k[:, None] * N
                + offs_out_n[None, :]
            )

            mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
            if ACCUMULATE:
                # Load existing values and add to them (like beta=1 in BLAS)
                old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
                tl.store(out_ptrs, acc + old_vals, mask=mask)
            else:
                # Overwrite output (like beta=0 in BLAS)
                tl.store(out_ptrs, acc, mask=mask)

            # Store bias gradient (only for first N tile, sum across all M)
            if COMPUTE_BIAS_GRAD and tile_n == 0:
                # Keep as float32 for atomic_add (bf16 not supported for atomics)
                bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
                # Use atomic add since multiple K-tiles may write to same expert's bias
                tl.atomic_add(
                    bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
                )

            # Go to the next tile by advancing number of programs.
            tile += GRID_DIM
            tl.device_assert(tile > 0, "tile <= 0 (at update)")

        # Get ready to go to the next MM problem.

        last_mm_tile += num_tiles
        # last_mm_tile can be zero if group 0 is skipped
        tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")

        last_m += m
        # last_m can be zero if group 0 is skipped
        tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
        tl.device_assert(last_m <= M, "last_m > M (at update)")
427
+
428
+
429
+ # Regular non-persistent TGMM kernel.
430
+ # ------------------------------------------------------------------------------
431
+
432
+
433
# Regular (non-persistent) TGMM kernel.
#
# Same computation as the persistent variant, one output tile per program:
#     out[g] = lhs[:, columns of g] @ rhs[rows of g, :]
# program_id(0) selects the group g and program_id(1) selects the tile inside
# the (K, N) plane of that group.
#
# BLOCK_SIZE_G is derived by the heuristic below as the next power of two of
# G; it sizes the vector used to sum the sizes of the groups that precede g.
#
# NOTE(review): programs of an empty group return before any store, so out[g]
# keeps whatever the caller put there - confirm that the caller initialises
# out when ACCUMULATE is False.
@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])})
@triton.jit
def tgmm_non_persistent_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_grad_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_LHS: tl.constexpr,
    BLOCK_SIZE_G: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    COMPUTE_BIAS_GRAD: tl.constexpr,
    ACCUMULATE: tl.constexpr,
):
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    # Get group ID from grid.
    g = tl.program_id(0)
    tl.device_assert(g >= 0, "g < 0")
    tl.device_assert(g < G, "g >= G")

    # Get m dimension of current MM group.
    m = tl.load(group_sizes_ptr + g)
    # m can be zero if group is empty.
    tl.device_assert(m >= 0, "m < 0")

    # Skip empty groups.
    if m == 0:
        return

    # Compute sum(group_sizes) until current group g.
    # It's the starting column of lhs and starting row of rhs.
    # The mask keeps only the sizes of groups 0..g-1; everything else loads 0.
    offs_g = tl.arange(0, BLOCK_SIZE_G)
    group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0)
    start_m = tl.sum(group_sizes)

    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Get MM tile from grid.
    tile_in_mm = tl.program_id(1)
    tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

    tile_k, tile_n = _remap_xcd_tile_grid(
        tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
    )

    tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
    tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

    # K and N offsets are wrapped with a modulo so the loads need no mask along
    # K and N; the store mask discards wrapped entries.
    offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
    offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)

    # lhs element (k, r) is at k + r * K when TRANS_LHS is set, else k * M + r.
    if TRANS_LHS:
        lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K
    else:
        lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :])

    # rhs is read as row-major (M, N), starting at the group's first row.
    rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :]

    # loop_m counts the full BLOCK_SIZE_M chunks handled by the unmasked loop;
    # a partial last chunk is handled separately below.
    loop_m = tl.cdiv(m, BLOCK_SIZE_M)
    m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
    if not m_divisible_by_block_m:
        loop_m -= 1

    acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
    # Initialize bias accumulator
    bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)

    for _ in range(0, loop_m):
        lhs = tl.load(lhs_ptrs)
        rhs = tl.load(rhs_ptrs)

        acc = tl.dot(lhs, rhs, acc=acc)

        # Accumulate for bias gradient: sum lhs across M dimension
        if COMPUTE_BIAS_GRAD and tile_n == 0:
            bias_acc += tl.sum(lhs, axis=1)  # [K, M] -> [K]

        # Advance both pointer blocks by one M chunk.
        if TRANS_LHS:
            lhs_ptrs += BLOCK_SIZE_M * K
        else:
            lhs_ptrs += BLOCK_SIZE_M

        rhs_ptrs += BLOCK_SIZE_M * N

    # Partial last chunk: same computation with loads masked against m.
    if not m_divisible_by_block_m:
        # NOTE(review): offs_lhs_k and offs_rhs_n are recomputed here but not
        # read again afterwards - confirm whether this is intentional.
        offs_lhs_k = (
            tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        ) % K
        offs_rhs_n = (
            tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        ) % N
        # Offsets of this chunk relative to the start of the group.
        offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
        rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
        acc = tl.dot(lhs, rhs, acc=acc)
        # Accumulate last chunk for bias gradient
        if COMPUTE_BIAS_GRAD and tile_n == 0:
            bias_acc += tl.sum(lhs, axis=1)

    acc = acc.to(out_ptr.type.element_ty)

    # Unwrapped output offsets, masked against K and N on store.
    offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Plane g of out starts at g * K * N and is row-major (K, N).
    out_ptrs = (
        out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
    )

    mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
    if ACCUMULATE:
        # Load existing values and add to them (like beta=1 in BLAS)
        old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
        tl.store(out_ptrs, acc + old_vals, mask=mask)
    else:
        # Overwrite output (like beta=0 in BLAS)
        tl.store(out_ptrs, acc, mask=mask)

    # Store bias gradient (only for first N tile, sum across all M)
    if COMPUTE_BIAS_GRAD and tile_n == 0:
        # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
        bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
        # Use atomic add since multiple K-tiles may write to same expert's bias
        tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/adapter.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention.
3
+
4
+ MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point
5
+ with ``trans_a`` / ``trans_b`` flags:
6
+
7
+ * ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N)
8
+ * ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad)
9
+ * ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad)
10
+
11
+ AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition
12
+ of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N),
13
+ transposition of the 2D operand inferred from strides).
14
+ """
15
+
16
+ import torch
17
+
18
+ from .gmm import gmm as _aiter_gmm
19
+ from .gmm import ptgmm as _aiter_ptgmm
20
+
21
+
22
def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False):
    """Run a grouped GEMM through the AITER Triton kernels and write it into ``c``.

    Dispatch:

    * ``trans_a=True``: weight gradient, handled by ``ptgmm``. ``trans_b`` is
      not read on this path.
    * ``trans_a=False``: handled by ``gmm``; with ``trans_b=True`` the 3D
      operand is passed as a transposed view of its last two dimensions.

    ``c`` is handed to the kernels as the preallocated output and is returned.
    """
    # AITER requires group sizes to be int32 and to live on the compute device.
    group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32)

    # AITER asserts exact strides, so start from contiguous operands; the
    # transposed views taken below then have precisely the strides the kernels
    # expect. `.contiguous()` is a no-op for tensors that already are.
    lhs = a.contiguous()
    rhs = b.contiguous()

    if trans_a:
        # Weight gradient: a(M,K), b(M,N) -> c(G,K,N).
        # The transposed view makes AITER see lhs(K,M) column-major (TRANS_LHS).
        _aiter_ptgmm(
            lhs.transpose(0, 1),
            rhs,
            group_sizes,
            preferred_element_type=c.dtype,
            existing_out=c,
        )
        return c

    if trans_b:
        # trans_b contracts b's last dim: pass a column-major (G,K,N) view.
        rhs = rhs.transpose(1, 2)

    _aiter_gmm(
        lhs,
        rhs,
        group_sizes,
        preferred_element_type=c.dtype,
        existing_out=c,
    )
    return c
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/configs.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/).
3
+ # Inlined as a Python module so packaging always includes them.
4
+
5
# Layout: CONFIGS[arch][variant][mode] -> kernel meta-parameters.
# "gmm" only has a "default" mode; "ptgmm" and "nptgmm" also have "accumulate".
# The "nptgmm" entries carry no GRID_DIM.
CONFIGS = {
    "gfx1250": {
        "gmm": {
            "default": {
                "BLOCK_SIZE_M": 256, "BLOCK_SIZE_K": 64, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "GRID_DIM": 256, "num_warps": 8, "num_stages": 1,
            },
        },
        "ptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "GRID_DIM": 256, "num_warps": 8, "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1, "GRID_DIM": 256, "num_warps": 8, "num_stages": 1,
            },
        },
        "nptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1,
            },
        },
    },
    "gfx942": {
        "gmm": {
            "default": {
                "BLOCK_SIZE_M": 256, "BLOCK_SIZE_K": 64, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "GRID_DIM": 304, "num_warps": 8, "num_stages": 1,
            },
        },
        "ptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "GRID_DIM": 304, "num_warps": 8, "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1, "GRID_DIM": 304, "num_warps": 8, "num_stages": 1,
            },
        },
        "nptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1,
            },
        },
    },
    "gfx950": {
        "gmm": {
            "default": {
                "BLOCK_SIZE_M": 256, "BLOCK_SIZE_K": 64, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "GRID_DIM": 256, "num_warps": 8, "num_stages": 1,
            },
        },
        "ptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "GRID_DIM": 256, "num_warps": 8, "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1, "GRID_DIM": 256, "num_warps": 8, "num_stages": 1,
            },
        },
        "nptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 256, "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1,
            },
        },
    },
}
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/gmm.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+
5
+ # Imports.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ # PyTorch
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ # Triton
13
+ import triton
14
+
15
+ # AITER: GMM utility functions
16
+ from .utils.gmm_common import (
17
+ DTYPE,
18
+ is_power_of_2,
19
+ check_input_device_dtype,
20
+ check_bias_shape_stride,
21
+ get_gmm_shape,
22
+ get_gmm_output,
23
+ get_gmm_transposition,
24
+ get_tgmm_shape,
25
+ get_tgmm_output,
26
+ get_tgmm_bias_grad,
27
+ get_tgmm_transposition,
28
+ )
29
+
30
+ # AITER: GMM Triton kernels
31
+ from ._triton_kernels.gmm import (
32
+ gmm_kernel,
33
+ tgmm_persistent_kernel,
34
+ tgmm_non_persistent_kernel,
35
+ get_config,
36
+ )
37
+
38
+ # GMM PyTorch wrapper.
39
+ # ------------------------------------------------------------------------------
40
+
41
+
42
def _gmm_grid(
    N: int,
    block_size_m: int,
    block_size_n: int,
    group_sizes: Tensor,
    grid_dim: int,
) -> tuple[int]:
    """Compute the 1D launch grid of the persistent GMM kernel.

    The total tile count is the sum over groups of
    ``cdiv(group_size, block_size_m) * cdiv(N, block_size_n)``. The grid has
    one program per tile, capped at ``grid_dim``.

    Returns a one-element tuple ``(num_programs,)``.
    """
    # Scalar and tensor argument checks.
    assert N > 0, f"N must be positive, it's {N}."
    assert is_power_of_2(
        block_size_m
    ), f"M-dimension tile size must be a power of 2 (it's {block_size_m})."
    assert is_power_of_2(
        block_size_n
    ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    assert (group_sizes >= 0).all().item(), "All group_sizes must be non-negative."
    assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."

    # Per-group tile count along M (ceiling division, elementwise).
    m_tiles_per_group = (group_sizes + block_size_m - 1) // block_size_m
    assert (
        m_tiles_per_group >= 0
    ).all().item(), "All num_m_tiles must be non-negative."

    # Tile count along N, shared by all groups.
    n_tiles = triton.cdiv(N, block_size_n)
    assert n_tiles > 0, f"num_n_tiles must be positive, it's {n_tiles}."

    total_tiles = (m_tiles_per_group * n_tiles).sum().item()
    assert total_tiles > 0, f"num_tiles must be positive, it's {total_tiles}."

    # Never launch more programs than there are tiles.
    programs = int(min(grid_dim, total_tiles))
    assert programs > 0, f"num_programs must be positive, it's {programs}."
    return (programs,)
67
+
68
+
69
def gmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias: Tensor | None = None,
) -> Tensor:
    """
    Group Matrix Multiplication (GMM): out = lhs @ rhs + bias

    The rows of lhs are split into G consecutive groups. Group g is multiplied by plane g of the
    3D rhs tensor and the result is written to the matching rows of out. For one group g, the
    PyTorch equivalent is:
        out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g]

    group_sizes gives the number of rows of each group, which also fixes where every group starts
    and ends. Example: group_sizes = [3, 2, 4, 1] describes 4 groups covering rows 0-2, rows 3-4,
    rows 5-8 and, for the last group, only row 9 (the 10th and last row of lhs).

    Parameters
    ----------
    lhs : torch.Tensor
        Left-hand side 2D input tensor. Shape: (M, K).
        Data type must be torch.float16 or torch.bfloat16 and must match rhs.
        Must be on the same device as rhs and group_sizes.
    rhs : torch.Tensor
        Right-hand side 3D input tensor. Shape: (G, K, N).
        Data type must be torch.float16 or torch.bfloat16 and must match lhs.
        Must be on the same device as lhs and group_sizes.
    group_sizes : torch.Tensor
        1D input tensor with the group sizes. Shape: (G,).
        Data type must be torch.int32 and every element must be non-negative.
        Must be on the same device as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Data type of the output tensor. Default is torch.bfloat16.
        Supported output types are torch.float16 and torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated output tensor. Default is None, in which case a new output tensor is
        allocated.
        When given, results are written into it; it must have shape (M, N), its data type must
        match preferred_element_type and it must be on the same device as the inputs.
    config : dict[str, int] or None, optional
        Kernel metaparameters. When None, the config is taken from the internal tuning database.
        It must provide BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N (powers of 2), GROUP_SIZE and
        GRID_DIM (positive), all integers.
    bias : torch.Tensor or None, optional
        Optional bias tensor. Shape: (G, N).
        When given, its data type must match lhs and rhs and it must be on the same device as the
        other inputs. bias[g] is added to the output rows of group g.

    Returns
    -------
    torch.Tensor
        The 2D output tensor. Shape: (M, N).
        Its data type is preferred_element_type.
        When existing_out is given, that same tensor is returned.

    Implementation Notes
    --------------------
    - GMM is implemented with a persistent Triton kernel.
    - lhs must be row-major (lhs.stride() == (K, 1)).
    - rhs may be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() ==
      (K * N, 1, K)). Row-major rhs runs the kernel with TRANS_RHS == False, which suits the
      forward pass. Column-major rhs runs it with TRANS_RHS == True, which suits computing the
      lhs derivative in the backward pass with the transposition fused into the kernel.
    - out must be row-major (out.stride() == (N, 1)).
    - bias, when given, must be row-major (bias.stride() == (N, 1)).
    """
    # Validate inputs and derive the problem shape.
    check_input_device_dtype(lhs, rhs, group_sizes, bias)
    M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes)

    use_bias = bias is not None
    if use_bias:
        check_bias_shape_stride(bias, G, N)

    # Allocate the output, or validate and reuse the caller's tensor.
    out = get_gmm_output(
        M,
        N,
        device=lhs.device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    # Only the rhs transposition flag is needed by the kernel.
    trans_rhs, _ = get_gmm_transposition(lhs, rhs, out)

    if config is None:
        config = get_config("gmm", M, K, N, G)

    # Every required key must be present and hold an int: block sizes must be
    # powers of 2, the remaining entries must be positive.
    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_K",
        "BLOCK_SIZE_N",
        "GROUP_SIZE",
        "GRID_DIM",
    )
    config_is_valid = True
    for key in required_keys:
        if key not in config or not isinstance(config[key], int):
            config_is_valid = False
        elif key.startswith("BLOCK_SIZE_"):
            config_is_valid = bool(is_power_of_2(config[key]))
        else:
            config_is_valid = config[key] > 0
        if not config_is_valid:
            break
    assert config_is_valid, "Invalid GMM kernel config."

    grid = _gmm_grid(
        N,
        config["BLOCK_SIZE_M"],
        config["BLOCK_SIZE_N"],
        group_sizes,
        config["GRID_DIM"],
    )

    # fmt: off
    gmm_kernel[grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_RHS=trans_rhs,
        USE_BIAS=use_bias,
        **config,
    )
    # fmt: on

    return out
202
+
203
+
204
+ # Persistent TGMM PyTorch wrapper.
205
+ # ------------------------------------------------------------------------------
206
+
207
+
208
def _ptgmm_grid(
    K: int,
    N: int,
    G: int,
    block_size_k: int,
    block_size_n: int,
    grid_dim: int,
) -> tuple[int]:
    """Compute the 1D launch grid of the persistent TGMM kernel.

    Every group owns ``cdiv(K, block_size_k) * cdiv(N, block_size_n)`` tiles.
    The grid has one program per tile across all G groups, capped at
    ``grid_dim``.

    Returns a one-element tuple ``(num_programs,)``.
    """
    # Argument checks.
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}."
    assert G > 0, f"G must be positive, it's {G}."
    assert is_power_of_2(
        block_size_k
    ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
    assert is_power_of_2(
        block_size_n
    ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."

    # Tile counts along K and N of one output plane.
    k_tiles = triton.cdiv(K, block_size_k)
    assert k_tiles > 0, f"num_k_tiles must be positive, it's {k_tiles}."
    n_tiles = triton.cdiv(N, block_size_n)
    assert n_tiles > 0, f"num_n_tiles must be positive, it's {n_tiles}."

    total_tiles = G * k_tiles * n_tiles
    assert total_tiles > 0, f"num_tiles must be positive, it's {total_tiles}."

    # Never launch more programs than there are tiles.
    programs = min(grid_dim, total_tiles)
    assert programs > 0, f"num_programs must be positive, it's {programs}."
    return (programs,)
235
+
236
+
237
def ptgmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias_grad: Tensor | None = None,
    accumulate: bool = False,
) -> Tensor:
    """
    Persistent-kernel variant of the transposed Group Matrix Multiplication (TGMM): out = lhs @ rhs

    The columns of lhs and the rows of rhs are split into G groups. Every lhs group is multiplied
    by the matching rhs group and the product is stored in one plane of the 3D output. For a
    given group g, the PyTorch equivalent is:
        out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]

    The 't' in the name comes from the MaxText implementation
    (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
    which inspired this one. TGMM and GMM differ in tensor shapes: GMM does
    (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).

    The 'p' in the name stands for persistent kernel. nptgmm is the non-persistent counterpart;
    both compute the same result and picking one is purely a performance decision for the target
    workload.

    Parameters
    ----------
    lhs : torch.Tensor
        Left-hand side 2D input tensor. Shape: (K, M).
        Data type must be torch.float16 or torch.bfloat16 and must match rhs.
        Must be on the same device as rhs and group_sizes.
    rhs : torch.Tensor
        Right-hand side 2D input tensor. Shape: (M, N).
        Data type must be torch.float16 or torch.bfloat16 and must match lhs.
        Must be on the same device as lhs and group_sizes.
    group_sizes : torch.Tensor
        1D input tensor describing group sizes. Shape: (G,).
        Data type must be torch.int32 and all elements must be non-negative.
        Must be on the same device as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Desired data type of the output tensor. Default is torch.bfloat16.
        Supported output types are torch.float16 and torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated output tensor. Default is None.
        When given, results are written into it; otherwise a new output tensor is allocated.
        When given, it must have shape (G, K, N), its data type must match
        preferred_element_type and it must be on the same device as the other inputs.
    config : dict[str, int] or None, optional
        Kernel metaparameters. When absent, the config is queried from the internal tuning
        database.
    bias_grad : torch.Tensor or None, optional
        Optional bias gradient output tensor. Shape: (G, K).
        When given, the kernel computes the bias gradient and writes it to this tensor.
        bias_grad must be torch.float32 (kernel uses atomic_add which requires float32).
    accumulate : bool, optional
        Whether to accumulate into the existing output tensor values. Default is False.
        If False, the output is overwritten with a fresh computation.
        If True, results are added to the existing output tensor values.

    Returns
    -------
    torch.Tensor
        The computed output 3D tensor. Shape: (G, K, N).
        Its data type is given by preferred_element_type.
        If existing_out is provided, that same tensor is returned.

    Implementation Notes
    --------------------
    - PTGMM is implemented with a persistent Triton kernel.
    - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)).
      Row-major lhs means kernel parameter TRANS_LHS == False, column-major lhs means
      TRANS_LHS == True. The latter is useful for computing the rhs derivative in the backward
      pass while fusing the transposition.
    - rhs must be row-major (rhs.stride() == (N, 1)).
    - out must be row-major (out.stride() == (K * N, N, 1)).
    """
    check_input_device_dtype(lhs, rhs, group_sizes)

    M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
    target_device = lhs.device

    out = get_tgmm_output(
        K,
        N,
        G,
        device=target_device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)

    if config is None:
        config = get_config("ptgmm", M, K, N, G, accumulate)

    # Every required metaparameter must be an int; tile sizes must additionally be powers of 2,
    # the remaining ones just have to be positive.
    for key in ("BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_N", "GROUP_SIZE", "GRID_DIM"):
        key_ok = key in config and isinstance(config[key], int)
        if key_ok:
            value = config[key]
            key_ok = is_power_of_2(value) if key.startswith("BLOCK_SIZE_") else value > 0
        assert key_ok, "Invalid PTGMM kernel config."

    # Bias gradient handling.
    # -----------------------
    # The kernel always receives a bias gradient pointer; the flag tells it whether to use it.
    compute_bias_grad = bias_grad is not None
    bias_grad_ptr = get_tgmm_bias_grad(
        K,
        G,
        device=target_device,
        existing_bias_grad=bias_grad,
    )

    grid = _ptgmm_grid(
        K,
        N,
        G,
        config["BLOCK_SIZE_K"],
        config["BLOCK_SIZE_N"],
        config["GRID_DIM"],
    )

    # fmt: off
    tgmm_persistent_kernel[grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias_grad_ptr,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_LHS=trans_lhs,
        COMPUTE_BIAS_GRAD=compute_bias_grad,
        ACCUMULATE=accumulate,
        **config,
    )
    # fmt: on

    return out
387
+
388
+
389
+ # Regular non-persistent TGMM PyTorch wrapper.
390
+ # ------------------------------------------------------------------------------
391
+
392
+
393
def _nptgmm_grid(
    K: int,
    N: int,
    G: int,
    block_size_k: int,
    block_size_n: int,
) -> tuple[int, int]:
    """
    Compute the 2D launch grid of the non-persistent TGMM kernel.

    Returns
    -------
    tuple[int, int]
        (G, number of output tiles per group), where the tile count per group is
        cdiv(K, block_size_k) * cdiv(N, block_size_n).
    """
    # Problem shape validation.
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}."
    assert G > 0, f"G must be positive, it's {G}."

    # Tile shape validation.
    assert is_power_of_2(
        block_size_k
    ), f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
    assert is_power_of_2(
        block_size_n
    ), f"N-dimension tile size must be a power of 2 (it's {block_size_n})."

    tiles_along_k = triton.cdiv(K, block_size_k)
    assert tiles_along_k > 0, f"num_k_tiles must be positive, it's {tiles_along_k}."
    tiles_along_n = triton.cdiv(N, block_size_n)
    assert tiles_along_n > 0, f"num_n_tiles must be positive, it's {tiles_along_n}."

    tiles_per_group = tiles_along_k * tiles_along_n
    assert (
        tiles_per_group > 0
    ), f"num_tiles_per_mm must be positive, it's {tiles_per_group}."
    return (G, tiles_per_group)
418
+
419
+
420
def nptgmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias_grad: Tensor | None = None,
    accumulate: bool = False,
) -> Tensor:
    """
    Non-persistent variant of the transposed Group Matrix Multiplication (TGMM): out = lhs @ rhs

    The columns of lhs and the rows of rhs are split into G groups. Every lhs group is multiplied
    by the matching rhs group and the product is stored in one plane of the 3D output. For a
    given group g, the PyTorch equivalent is:
        out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]

    The 't' in the name comes from the MaxText implementation
    (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
    which inspired this one. TGMM and GMM differ in tensor shapes: GMM does
    (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).

    The 'np' in the name stands for non-persistent, i.e. a regular kernel. ptgmm is the
    persistent counterpart; both compute the same result and picking one is purely a
    performance decision for the target workload.

    Parameters
    ----------
    lhs : torch.Tensor
        Left-hand side 2D input tensor. Shape: (K, M).
        Data type must be torch.float16 or torch.bfloat16 and must match rhs.
        Must be on the same device as rhs and group_sizes.
    rhs : torch.Tensor
        Right-hand side 2D input tensor. Shape: (M, N).
        Data type must be torch.float16 or torch.bfloat16 and must match lhs.
        Must be on the same device as lhs and group_sizes.
    group_sizes : torch.Tensor
        1D input tensor describing group sizes. Shape: (G,).
        Data type must be torch.int32 and all elements must be non-negative.
        Must be on the same device as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Desired data type of the output tensor. Default is torch.bfloat16.
        Supported output types are torch.float16 and torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated output tensor. Default is None.
        When given, results are written into it; otherwise a new output tensor is allocated.
        When given, it must have shape (G, K, N), its data type must match
        preferred_element_type and it must be on the same device as the other inputs.
    config : dict[str, int] or None, optional
        Kernel metaparameters. When absent, the config is queried from the internal tuning
        database.
    bias_grad : torch.Tensor or None, optional
        Optional bias gradient output tensor. Shape: (G, K).
        When given, the kernel computes the bias gradient and writes it to this tensor.
        bias_grad must be torch.float32 (kernel uses atomic_add which requires float32).
    accumulate : bool, optional
        Whether to accumulate into the existing output tensor values. Default is False.
        If False, the output is overwritten with a fresh computation.
        If True, results are added to the existing output tensor values.

    Returns
    -------
    torch.Tensor
        The computed output 3D tensor. Shape: (G, K, N).
        Its data type is given by preferred_element_type.
        If existing_out is provided, that same tensor is returned.

    Implementation Notes
    --------------------
    - NPTGMM is implemented with a non-persistent regular Triton kernel.
    - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)).
      Row-major lhs means kernel parameter TRANS_LHS == False, column-major lhs means
      TRANS_LHS == True. The latter is useful for computing the rhs derivative in the backward
      pass while fusing the transposition.
    - rhs must be row-major (rhs.stride() == (N, 1)).
    - out must be row-major (out.stride() == (K * N, N, 1)).
    """
    check_input_device_dtype(lhs, rhs, group_sizes)

    M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
    target_device = lhs.device

    out = get_tgmm_output(
        K,
        N,
        G,
        device=target_device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)

    # Bias gradient handling.
    # -----------------------
    # The kernel always receives a bias gradient pointer; the flag tells it whether to use it.
    compute_bias_grad = bias_grad is not None
    bias_grad_ptr = get_tgmm_bias_grad(
        K,
        G,
        device=target_device,
        existing_bias_grad=bias_grad,
    )

    if config is None:
        config = get_config("nptgmm", M, K, N, G, accumulate)

    # Every required metaparameter must be an int; tile sizes must additionally be powers of 2,
    # the remaining ones just have to be positive.
    for key in ("BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_N", "GROUP_SIZE"):
        key_ok = key in config and isinstance(config[key], int)
        if key_ok:
            value = config[key]
            key_ok = is_power_of_2(value) if key.startswith("BLOCK_SIZE_") else value > 0
        assert key_ok, "Invalid NPTGMM kernel config."

    grid = _nptgmm_grid(
        K,
        N,
        G,
        config["BLOCK_SIZE_K"],
        config["BLOCK_SIZE_N"],
    )

    # fmt: off
    tgmm_non_persistent_kernel[grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias_grad_ptr,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_LHS=trans_lhs,
        COMPUTE_BIAS_GRAD=compute_bias_grad,
        ACCUMULATE=accumulate,
        **config,
    )
    # fmt: on

    return out
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/__init__.py ADDED
File without changes
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py ADDED
File without changes
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+
3
# Detect the GPU arch lazily: querying the triton driver at import time fails
# in headless environments (e.g. the kernel-builder ABI check sandbox has no
# GPU). The arch is only actually needed when a GMM kernel is dispatched, so
# get_arch() resolves it on first call and caches the result here.
# NOTE(review): an earlier version of this comment said the JAX fallback was
# dropped because it pulled in an unrelated runtime dep, but get_arch() still
# tries JAX when the triton driver raises RuntimeError -- confirm the intent.
# None means "not resolved yet".
_CACHED_ARCH = None
9
+
10
+
11
def get_arch():
    """
    Return the GPU architecture name, resolving and caching it on first use.

    The active triton driver is asked first. If that raises RuntimeError, JAX's bundled
    triton bindings are tried instead.

    Raises
    ------
    RuntimeError
        If the triton driver is inactive and JAX cannot be imported.
    """
    global _CACHED_ARCH
    if _CACHED_ARCH is None:
        try:
            _CACHED_ARCH = triton.runtime.driver.active.get_current_target().arch
        except RuntimeError:
            try:
                from jax._src.lib import gpu_triton as jax_gpu_triton

                # Arch details look like "<arch>:<features...>"; keep the arch name only.
                arch_details = jax_gpu_triton.get_arch_details("0")
                _CACHED_ARCH = arch_details.split(":")[0]
            except ImportError as e:
                raise RuntimeError(
                    "Cannot determine GPU arch: triton driver is inactive and "
                    "JAX is not available. A GPU is required for grouped GEMM."
                ) from e
    return _CACHED_ARCH
27
+
28
+
29
def is_gluon_avail():
    """Return True when the detected GPU arch is one of the archs listed for Gluon."""
    gluon_archs = ("gfx950", "gfx1250")
    return get_arch() in gluon_archs
31
+
32
+
33
def is_fp4_avail():
    """Return True when the detected GPU arch is one of the archs listed for FP4."""
    fp4_archs = ("gfx950", "gfx1250")
    return get_arch() in fp4_archs
35
+
36
+
37
def is_fp8_avail():
    """Return True when the detected GPU arch is one of the archs listed for FP8."""
    fp8_archs = ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201")
    return get_arch() in fp8_archs
39
+
40
+
41
def is_mx_scale_preshuffling_avail():
    """Return True when the detected GPU arch is listed for MX scale preshuffling."""
    preshuffling_archs = ("gfx950", "gfx1250")
    return get_arch() in preshuffling_archs
43
+
44
+
45
def is_tdm_avail():
    """Return True when the detected GPU arch is one of the archs listed for TDM."""
    tdm_archs = ("gfx1250",)
    return get_arch() in tdm_archs
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
@triton.jit
def remap_xcd_chunked(
    pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2
):
    """
    Remap a 1D pid so that pids sharing the same `pid % NUM_XCDS` get runs of CHUNK_SIZE
    consecutive new pids.

    Args:
    - pid: 1D pid
    - GRID_MN: total number of pids in the grid
    - NUM_XCDS: tl.constexpr: number of XCDs, default is 8
    - CHUNK_SIZE: tl.constexpr: consecutive new pids given to one XCD, default is 2

    Returns:
    - the remapped pid, or the input pid unchanged when it lies past the last full
      block of NUM_XCDS * CHUNK_SIZE pids.
    """
    # Compute current XCD and local PID
    # NOTE(review): assumes the hardware dispatches consecutive pids to XCDs round-robin,
    # so `pid % NUM_XCDS` identifies the XCD -- confirm for the target device.
    xcd = pid % NUM_XCDS
    # distribute the modulo pids in round robin
    # pids past the last full block of NUM_XCDS * CHUNK_SIZE keep their original value. The
    # comparison is strict, but a pid sitting exactly on the boundary maps to itself below.
    if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE):
        return pid
    local_pid = pid // NUM_XCDS
    # Calculate chunk index and position within chunk
    chunk_idx = local_pid // CHUNK_SIZE
    pos_in_chunk = local_pid % CHUNK_SIZE
    # Calculate new PID
    # Layout of new pids: [chunk][xcd][position within chunk].
    new_pid = chunk_idx * NUM_XCDS * CHUNK_SIZE + xcd * CHUNK_SIZE + pos_in_chunk
    return new_pid
25
+
26
+
27
@triton.jit
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8):
    """
    Remap a 1D pid so that all pids sharing the same `pid % NUM_XCDS` form one contiguous
    range of new pids.

    Args:
    - pid: 1D pid
    - GRID_MN: total number of pids in the grid
    - NUM_XCDS: tl.constexpr: number of XCDs, default is 8

    Returns:
    - the remapped pid
    """
    ## pid remapping on xcds
    # Number of pids per XCD in the new arrangement
    # (ceiling division of GRID_MN by NUM_XCDS)
    pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
    # When GRID_MN cannot divide NUM_XCDS, some xcds will have
    # pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
    # We calculate the number of xcds that have pids_per_xcd pids as
    # tall_xcds
    tall_xcds = GRID_MN % NUM_XCDS
    # A zero remainder means every XCD is "tall".
    tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
    # Compute current XCD and local pid within the XCD
    xcd = pid % NUM_XCDS
    local_pid = pid // NUM_XCDS
    # Calculate new pid based on the new grouping
    # Note that we need to consider the following two cases:
    # 1. the current pid is on a tall xcd
    # 2. the current pid is on a short xcd
    if xcd < tall_xcds:
        pid = xcd * pids_per_xcd + local_pid
    else:
        # Skip all tall XCD ranges first, then the short XCD ranges before this one.
        pid = (
            tall_xcds * pids_per_xcd
            + (xcd - tall_xcds) * (pids_per_xcd - 1)
            + local_pid
        )

    return pid
55
+
56
+
57
@triton.jit
def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1):
    """
    Maps 1D pid to 2D grid coords (pid_m, pid_n).

    Args:
        - pid: 1D pid
        - num_pid_m: grid m size
        - num_pid_n: grid n size
        - GROUP_SIZE_M: tl.constexpr: default is 1

    Returns:
        - pid_m, pid_n: 2D grid coordinates
    """
    if GROUP_SIZE_M == 1:
        # Plain row-major order: pid_n varies fastest.
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        # Grouped order: consecutive pids stay within a band of up to GROUP_SIZE_M rows
        # before moving on to the next band.
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        # The last band may hold fewer than GROUP_SIZE_M rows.
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        tl.assume(group_size_m >= 0)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    return pid_m, pid_n
81
+
82
+
83
@triton.jit
def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k):
    """
    Maps 1D pid to 3D grid coords (pid_m, pid_n, pid_k).

    Args:
        - pid: 1D pid
        - num_pid_m: grid m size
        - num_pid_n: grid n size
        - num_pid_k: grid k size

    Returns:
        - pid_m, pid_n, pid_k: 3D grid coordinates
    """
    # pid_m varies fastest, then pid_n, then pid_k.
    pid_m = pid % num_pid_m
    pid_n = (pid // num_pid_m) % num_pid_n
    # The final modulo wraps pid_k for pids beyond num_pid_m * num_pid_n * num_pid_k.
    pid_k = pid // (num_pid_m * num_pid_n) % num_pid_k

    return pid_m, pid_n, pid_k
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ # Imports.
5
+ # ------------------------------------------------------------------------------
6
+
7
+ # PyTorch
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ # AITER: logging
12
+ from .logger import AiterTritonLogger
13
+
14
# Module-level logger shared by the helpers in this file.
_LOGGER: AiterTritonLogger = AiterTritonLogger()
15
+
16
+
17
+ # Supported data types.
18
+ # ------------------------------------------------------------------------------
19
+
20
# Supported data types, as strings.
# Each entry must also be a key of the mapping in dtype_from_str.
SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"}
22
+
23
+
24
+ # Convert string data type to PyTorch data type.
25
def dtype_from_str(dtype_str: str) -> torch.dtype:
    """
    Convert a string data type to the matching PyTorch data type.

    Surrounding whitespace and letter case are ignored. A single leading "i" or "o" is
    stripped before lookup, so e.g. "ibf16" and "obf16" both resolve like "bf16".

    Raises
    ------
    AssertionError
        If the normalized string isn't a supported data type. This includes empty and
        whitespace-only strings.
    """
    dtype_str = dtype_str.strip().lower()
    # Slice instead of indexing: an empty string must reach the assertion below rather than
    # fail with IndexError.
    dtype_str = dtype_str[1:] if dtype_str[:1] in {"i", "o"} else dtype_str
    assert (
        dtype_str in SUPPORTED_DTYPES_STR
    ), "String data type isn't in set of supported string data types."
    return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]
32
+
33
+
34
# Supported data types, as PyTorch types.
# Derived from SUPPORTED_DTYPES_STR so that both sets always stay in sync.
SUPPORTED_DTYPES: set[torch.dtype] = {
    dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR
}
38
+
39
+
40
+ # Convert PyTorch data type to string data type.
41
def str_from_dtype(dtype: torch.dtype) -> str:
    """
    Convert a PyTorch data type to its string name ("fp16" or "bf16").

    Raises
    ------
    AssertionError
        If dtype isn't one of the supported PyTorch data types.
    """
    is_supported = dtype in SUPPORTED_DTYPES
    assert (
        is_supported
    ), "PyTorch data type isn't in set of supported PyTorch data types."
    names_by_dtype = {torch.float16: "fp16", torch.bfloat16: "bf16"}
    return names_by_dtype[dtype]
46
+
47
+
48
# Default data type, as string.
DTYPE_STR: str = "bf16"
# Import-time sanity check: the default must be one of the supported types.
assert (
    DTYPE_STR in SUPPORTED_DTYPES_STR
), "Default string data type isn't in set of supported string data types."


# Default data type, as PyTorch type.
DTYPE: torch.dtype = dtype_from_str(DTYPE_STR)
57
+
58
+
59
+ # Other defaults.
60
+ # ------------------------------------------------------------------------------
61
+
62
# Default device.
DEVICE: torch.device | str = "cuda"

# Default RNG seed for input generation.
RNG_SEED: int = 0

# Default number of group sizes.
NUM_GROUP_SIZES: int = 1

# Default transposition (NN), i.e. neither lhs nor rhs is transposed.
TRANS_LHS: bool = False
TRANS_RHS: bool = False
74
+
75
+
76
+ # Parameter checking functions.
77
+ # ------------------------------------------------------------------------------
78
+
79
+
80
def is_power_of_2(x: int) -> bool:
    """Return True if x is a positive integer with exactly one bit set."""
    if not x > 0:
        return False
    # Clearing the lowest set bit of a power of 2 leaves nothing behind.
    return (x & (x - 1)) == 0
82
+
83
+
84
+ def check_input_device_dtype(
85
+ lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None
86
+ ) -> None:
87
+ assert (
88
+ lhs.device == rhs.device == group_sizes.device
89
+ ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})."
90
+ assert (
91
+ lhs.dtype == rhs.dtype
92
+ ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})."
93
+ assert group_sizes.dtype == torch.int32, "group_sizes type must be int32."
94
+
95
+ if bias is not None:
96
+ assert (
97
+ bias.device == lhs.device
98
+ ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})."
99
+ assert (
100
+ bias.dtype == lhs.dtype
101
+ ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})."
102
+
103
+
104
def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None:
    """
    Validate that bias has shape (G, N) and is laid out row-major.

    Raises
    ------
    AssertionError
        If the shape or the stride doesn't match.
    """
    expected_shape = (G, N)
    assert (
        bias.shape == expected_shape
    ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}."

    row_major_stride = (N, 1)
    assert (
        bias.stride() == row_major_stride
    ), "bias must be row-major (bias.stride() == (N, 1))."
110
+
111
+
112
+ # Generation of group sizes.
113
+ # ------------------------------------------------------------------------------
114
+
115
+
116
# Probabilities for generating random group sizes.
# UNUSED_TOKENS_PROB: probability that a token is dropped (not routed to any expert).
# UNUSED_EXPERTS_PROB: probability that an expert receives no tokens at all.
UNUSED_TOKENS_PROB: float = 0.0
UNUSED_EXPERTS_PROB: float = 0.1
119
+
120
+
121
def gen_uniform_group_sizes(
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
) -> Tensor:
    """
    Split M tokens across G experts as evenly as possible.

    Every expert gets M // G tokens and the first M % G experts get one extra token.

    Returns
    -------
    torch.Tensor
        int32 tensor of shape (G,) whose elements add up to M.
    """
    assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
    assert G > 0, f"Number of experts G must be positive (it's {G})."

    tokens_per_expert, leftover = divmod(M, G)
    group_sizes = torch.full(
        (G,), tokens_per_expert, dtype=torch.int32, device=device
    )
    if leftover:
        # Hand the leftover tokens to the first experts, one each.
        group_sizes[:leftover] += 1

    # Post-conditions.
    assert (
        len(group_sizes) == G
    ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
    assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
    assert (
        torch.sum(group_sizes).item() == M
    ), f"Group sizes don't add up to total tokens {M}."
    assert group_sizes.dtype == torch.int32, "Group sizes must be int32."

    return group_sizes
145
+
146
+
147
def gen_group_sizes(
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
    rng_seed: int | None = RNG_SEED,
    unused_tokens_prob: float = UNUSED_TOKENS_PROB,
    unused_experts_prob: float = UNUSED_EXPERTS_PROB,
) -> Tensor:
    """
    Generate random group sizes, i.e. how many of M tokens are routed to each of G experts.

    Parameters
    ----------
    M : int
        Total number of tokens. At least one token must end up being used.
    G : int
        Number of experts (groups). Must be positive.
    device : torch.device or str, optional
        Device used for the random draws of dropped tokens and experts.
    rng_seed : int or None, optional
        Seed passed to torch.manual_seed. None leaves the RNG state untouched.
    unused_tokens_prob : float, optional
        Probability in [0, 1) that a token is dropped.
    unused_experts_prob : float, optional
        Probability in [0, 1) that an expert gets no tokens.

    Returns
    -------
    torch.Tensor
        int32 tensor of shape (G,) whose elements add up to the number of used tokens.

    Raises
    ------
    AssertionError
        If an argument is out of range or a post-condition doesn't hold.
    """
    assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
    assert G > 0, f"Number of experts G must be positive (it's {G})."
    assert (
        0 <= unused_tokens_prob <= 1
    ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})."
    assert (
        0 <= unused_experts_prob <= 1
    ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})."
    # A probability of exactly 1 makes the rejection loops below spin forever, since every
    # draw drops all tokens / all experts. Fail fast instead of hanging.
    assert (
        unused_tokens_prob < 1
    ), "Probability of unused tokens must be less than 1, otherwise no token is ever used."
    assert (
        unused_experts_prob < 1
    ), "Probability of unused experts must be less than 1, otherwise no expert is ever used."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    # With M == 0 the draw below always equals M, so skip the loop and let the
    # "used tokens must be positive" assertion report the problem.
    if unused_tokens_prob > 0 and M > 0:
        # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed.
        # Redraw until at least one token is kept.
        num_unused_tokens = M
        while num_unused_tokens == M:
            num_unused_tokens = int(
                torch.binomial(
                    torch.tensor(float(M), device=device),
                    torch.tensor(unused_tokens_prob, device=device),
                ).item()
            )
    else:
        num_unused_tokens = 0
    num_used_tokens = M - num_unused_tokens
    assert (
        num_unused_tokens >= 0
    ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})."
    assert (
        num_used_tokens > 0
    ), f"Number of used tokens must be positive (it's {num_used_tokens})."
    assert (
        num_used_tokens + num_unused_tokens == M
    ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})."

    if num_unused_tokens > 0:
        _LOGGER.debug(
            f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.",
        )

    if unused_experts_prob > 0:
        # Some experts may have zero tokens assigned to them.
        # Redraw until at least one expert is kept.
        num_used_experts = 0
        while num_used_experts == 0:
            # nonzero() on a 1D tensor yields shape (n, 1). Only drop that trailing dimension:
            # a bare squeeze() would also collapse n == 1 to a 0-dim tensor, which can't be
            # indexed below.
            used_experts = torch.nonzero(
                torch.rand((G,), device=device) >= unused_experts_prob
            ).squeeze(-1)
            num_used_experts = used_experts.numel()
    else:
        used_experts = torch.arange(0, G, device=device)
        num_used_experts = G
    num_unused_experts = G - num_used_experts
    assert (
        num_unused_experts >= 0
    ), f"Number of unused experts must be non-negative (it's {num_unused_experts})."
    assert (
        num_used_experts >= 1
    ), f"At least one expert must be used (it's {num_used_experts})."
    assert (
        num_unused_experts + num_used_experts == G
    ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})."

    if num_unused_experts > 0:
        _LOGGER.debug(
            f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.",
        )

    # Route every used token to a uniformly chosen used expert and count tokens per expert.
    # The index draw is intentionally left on the default device so that seeded results
    # don't change.
    group_sizes = torch.bincount(
        used_experts[
            torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,))
        ],
        minlength=G,
    ).to(torch.int32)

    assert (
        len(group_sizes) == G
    ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
    assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
    assert (
        torch.sum(group_sizes).item() == num_used_tokens
    ), f"Group sizes don't add up to used tokens {num_used_tokens}."
    assert group_sizes.dtype == torch.int32, "Group sizes must be int32."

    return group_sizes
239
+
240
+
241
def gen_multiple_group_sizes(
    num_group_sizes: int,
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
    rng_seed: int | None = RNG_SEED,
    unused_tokens_prob: float = UNUSED_TOKENS_PROB,
    unused_experts_prob: float = UNUSED_EXPERTS_PROB,
    group_sizes_0: Tensor | None = None,
) -> list[Tensor]:
    """
    Generate a list with num_group_sizes group size tensors.

    When group_sizes_0 is given it becomes the first element and only num_group_sizes - 1
    tensors are generated. The RNG seed is applied to the first generated tensor only, the
    following ones continue from the resulting RNG state.

    Returns
    -------
    list[torch.Tensor]
        List with exactly num_group_sizes group size tensors.
    """
    assert (
        num_group_sizes > 0
    ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}."

    if group_sizes_0 is None:
        multiple_group_sizes = []
        num_to_generate = num_group_sizes
    else:
        multiple_group_sizes = [group_sizes_0]
        num_to_generate = num_group_sizes - 1

    for idx in range(num_to_generate):
        # Seed once, on the first draw.
        seed = rng_seed if idx == 0 else None
        multiple_group_sizes.append(
            gen_group_sizes(
                M,
                G,
                device=device,
                rng_seed=seed,
                unused_tokens_prob=unused_tokens_prob,
                unused_experts_prob=unused_experts_prob,
            )
        )

    assert (
        len(multiple_group_sizes) == num_group_sizes
    ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})."
    return multiple_group_sizes
273
+
274
+
275
+ # GMM helpers: tensor generation.
276
+ # ------------------------------------------------------------------------------
277
+
278
+
279
def gen_gmm_input(
    M: int,
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    trans_rhs: bool = TRANS_RHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Generate random GMM inputs.

    Values are drawn in float32 and then converted to preferred_element_type.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        lhs with shape (M, K), rhs with shape (G, K, N) and group_sizes with shape (G,).
        When trans_rhs is True, rhs is generated as (G, N, K) and permuted to (G, K, N).
    """
    assert M > 0, f"Number of lhs rows M must be positive (M = {M})."
    assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})."
    assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    lhs_fp32 = torch.randn((M, K), dtype=torch.float32, device=device)
    lhs = lhs_fp32.to(preferred_element_type)

    rhs_fp32 = (
        torch.randn((G, N, K), dtype=torch.float32, device=device).permute(0, 2, 1)
        if trans_rhs
        else torch.randn((G, K, N), dtype=torch.float32, device=device)
    )
    rhs = rhs_fp32.to(preferred_element_type)

    # The RNG was already seeded above, so group sizes continue from the current RNG state.
    if unif_group_sizes:
        group_sizes = gen_uniform_group_sizes(M, G, device=device)
    else:
        group_sizes = gen_group_sizes(M, G, device=device, rng_seed=None)

    return lhs, rhs, group_sizes
316
+
317
+
318
def gen_gmm_output(
    M: int,
    N: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
) -> Tensor:
    """
    Allocate an uninitialized GMM output tensor.

    Returns
    -------
    torch.Tensor
        Tensor with shape (M, N), data type preferred_element_type, placed on device.
    """
    assert M > 0, f"Number of out rows M must be positive (M = {M})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."

    out_shape = (M, N)
    return torch.empty(out_shape, dtype=preferred_element_type, device=device)
330
+
331
+
332
def gen_gmm_tensors(
    M: int,
    K: int,
    N: int,
    G: int,
    num_group_sizes: int,
    device: torch.device | str = DEVICE,
    input_type: torch.dtype = DTYPE,
    output_type: torch.dtype = DTYPE,
    trans_lhs: bool = False,
    trans_rhs: bool = TRANS_RHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
    use_bias: bool = False,
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
    """
    Generate all tensors needed to run a GMM: inputs, group sizes, output and optional bias.

    Parameters
    ----------
    M, K, N, G : int
        GMM problem shape.
    num_group_sizes : int
        Number of group size tensors to return. The first one is generated along with
        lhs and rhs.
    device : torch.device or str, optional
        Device of all generated tensors.
    input_type, output_type : torch.dtype, optional
        Data types of the inputs (lhs, rhs, bias) and of the output.
    trans_lhs : bool, optional
        Currently unused by this function.
    trans_rhs : bool, optional
        Forwarded to gen_gmm_input.
    rng_seed : int or None, optional
        RNG seed. None leaves the RNG state untouched, including for bias generation.
    unif_group_sizes : bool, optional
        Forwarded to gen_gmm_input.
    use_bias : bool, optional
        Whether to also generate a random bias tensor with shape (G, N).

    Returns
    -------
    tuple
        (lhs, rhs, multiple_group_sizes, out, bias); bias is None unless use_bias is True.
    """
    lhs, rhs, group_sizes_0 = gen_gmm_input(
        M,
        K,
        N,
        G,
        device=device,
        preferred_element_type=input_type,
        trans_rhs=trans_rhs,
        rng_seed=rng_seed,
        unif_group_sizes=unif_group_sizes,
    )
    multiple_group_sizes = gen_multiple_group_sizes(
        num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
    )
    out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type)
    bias = None
    if use_bias:
        # rng_seed may legitimately be None ("don't reseed"); adding 1000 to it would raise
        # TypeError, so only reseed when a seed was actually given.
        if rng_seed is not None:
            torch.manual_seed(rng_seed + 1000)  # Different seed for bias
        bias = torch.randn(G, N, dtype=input_type, device=device)

    return lhs, rhs, multiple_group_sizes, out, bias
368
+
369
+
370
+ # GMM helpers: get information from tensors.
371
+ # ------------------------------------------------------------------------------
372
+
373
+
374
def get_gmm_shape(
    lhs: Tensor, rhs: Tensor, group_sizes: Tensor
) -> tuple[int, int, int, int]:
    """Validate GMM operand shapes and return the problem size.

    ``lhs`` is read as (M, K), ``rhs`` as (G, K, N) and ``group_sizes`` as
    (G,). The K and G extents must agree across operands and every dimension
    must be positive.

    Returns
    -------
    tuple[int, int, int, int]
        ``(M, K, N, G)``.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
    assert (
        group_sizes.dim() == 1
    ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."

    (M, lhs_k), (rhs_g, rhs_k, N) = lhs.shape, rhs.shape
    group_sizes_g = group_sizes.size(0)

    # Cross-operand consistency checks.
    assert (
        lhs_k == rhs_k
    ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
    assert (
        rhs_g == group_sizes_g
    ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})."
    K, G = lhs_k, rhs_g

    # Empty dimensions are rejected.
    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    return M, K, N, G
402
+
403
+
404
def get_gmm_output(
    M: int,
    N: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
) -> Tensor:
    """Return a GMM output tensor, reusing ``existing_out`` when provided.

    Parameters
    ----------
    M : int
        Number of output rows; must be positive.
    N : int
        Number of output columns; must be positive.
    device : torch.device or str
        Expected device of the output.
    preferred_element_type : torch.dtype
        Expected element type of the output.
    existing_out : torch.Tensor or None
        Pre-allocated output. When given, its device, dtype and shape are
        validated and the same tensor is returned; otherwise a new tensor is
        allocated through ``gen_gmm_output``.

    Returns
    -------
    torch.Tensor
        Tensor of shape (M, N).
    """
    assert M > 0, f"Number of out rows M must be positive (M = {M})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."

    if existing_out is None:
        return gen_gmm_output(
            M,
            N,
            device=device,
            preferred_element_type=preferred_element_type,
        )

    # `device` may be a plain string, and `torch.device == str` is never true,
    # so normalize before comparing. A device given without an index (e.g.
    # "cuda") matches any index of that device type.
    expected_device = torch.device(device)
    actual_device = existing_out.device
    devices_match = actual_device.type == expected_device.type and (
        expected_device.index is None or actual_device.index == expected_device.index
    )
    assert (
        devices_match
    ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
    assert (
        existing_out.dtype == preferred_element_type
    ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
    assert existing_out.shape == (
        M,
        N,
    ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})."
    return existing_out
433
+
434
+
435
def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
    """Infer the rhs memory layout of a GMM problem from tensor strides.

    ``lhs`` is read as (M, K), ``rhs`` as (G, K, N) and ``out`` as (M, N).
    ``lhs`` and ``out`` must be row-major; ``rhs`` may be row-major
    (strides ``(K*N, N, 1)``) or column-major (strides ``(K*N, 1, K)``).

    Returns
    -------
    tuple[bool, int]
        ``(trans_rhs, ld_rhs)``: whether rhs is column-major, and its leading
        dimension (``K`` when column-major, ``N`` otherwise).
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
    assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})."

    lhs_m, lhs_k = lhs.shape
    G, rhs_k, rhs_n = rhs.shape
    out_m, out_n = out.shape

    assert (
        lhs_m == out_m
    ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, out = {out_m})."
    M = lhs_m
    assert (
        lhs_k == rhs_k
    ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
    K = lhs_k
    assert (
        rhs_n == out_n
    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
    N = rhs_n

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    is_lhs_row_major = lhs.stride() == (K, 1)
    assert is_lhs_row_major, "lhs must be row-major."
    rhs_stride = rhs.stride()
    is_rhs_row_major = rhs_stride == (K * N, N, 1)
    is_rhs_col_major = rhs_stride == (K * N, 1, K)
    # When K == N == 1 both stride patterns coincide; that layout is valid
    # and is treated as row-major rather than rejected.
    assert (
        is_rhs_row_major or is_rhs_col_major
    ), "rhs must be row-major or column-major."
    is_out_row_major = out.stride() == (N, 1)
    assert is_out_row_major, "out must be row-major."

    trans_rhs = is_rhs_col_major and not is_rhs_row_major

    # Get rhs leading dimension according to transposition configuration.
    ld_rhs = K if trans_rhs else N

    return trans_rhs, ld_rhs
476
+
477
+
478
+ # TGMM helpers: tensor generation.
479
+ # ------------------------------------------------------------------------------
480
+
481
+
482
def gen_tgmm_input(
    M: int,
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    trans_lhs: bool = TRANS_LHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Generate random TGMM inputs.

    Parameters
    ----------
    M, K, N, G : int
        Problem dimensions; all must be positive. ``lhs`` has shape (K, M)
        and ``rhs`` has shape (M, N).
    device : torch.device or str
        Device for the generated tensors.
    preferred_element_type : torch.dtype
        Element type of ``lhs`` and ``rhs`` (values are drawn in float32 and
        then converted).
    trans_lhs : bool
        When true, ``lhs`` is created as (M, K) and transposed, so the
        returned (K, M) tensor is a transposed view; otherwise it is created
        directly as (K, M).
    rng_seed : int or None
        Seed passed to ``torch.manual_seed``; ``None`` skips seeding.
    unif_group_sizes : bool
        Selects ``gen_uniform_group_sizes`` instead of ``gen_group_sizes``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ``(lhs, rhs, group_sizes)``.
    """
    assert K > 0, f"Number of lhs rows K must be positive (K = {K})."
    assert M > 0, f"Number of lhs columns / rhs rows M must be positive (M = {M})."
    assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    if trans_lhs:
        lhs = torch.randn((M, K), dtype=torch.float32, device=device).T
    else:
        lhs = torch.randn((K, M), dtype=torch.float32, device=device)
    lhs = lhs.to(preferred_element_type)

    rhs = torch.randn((M, N), dtype=torch.float32, device=device)
    rhs = rhs.to(preferred_element_type)

    group_sizes = (
        gen_uniform_group_sizes(M, G, device=device)
        if unif_group_sizes
        else gen_group_sizes(M, G, device=device, rng_seed=None)
    )

    return lhs, rhs, group_sizes
517
+
518
+
519
def gen_tgmm_output(
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
) -> Tensor:
    """Allocate an uninitialized TGMM output tensor of shape (G, K, N).

    Parameters
    ----------
    K : int
        Rows of each per-group output matrix; must be positive.
    N : int
        Columns of each per-group output matrix; must be positive.
    G : int
        Number of groups; must be positive.
    device : torch.device or str
        Device on which the tensor is allocated.
    preferred_element_type : torch.dtype
        Element type of the allocated tensor.

    Returns
    -------
    torch.Tensor
        A tensor created with ``torch.empty`` (contents are undefined).
    """
    assert K > 0, f"Number of out rows K must be positive (K = {K})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    out_shape = (G, K, N)
    return torch.empty(out_shape, device=device, dtype=preferred_element_type)
533
+
534
+
535
def gen_tgmm_bias_grad(
    K: int,
    G: int,
    device: torch.device | str = DEVICE,
    with_bias_grad: bool = False,
) -> Tensor:
    """Allocate the TGMM bias-gradient buffer, or a placeholder for it.

    Parameters
    ----------
    K : int
        Second dimension of the (G, K) buffer; checked only when
        ``with_bias_grad`` is true.
    G : int
        Number of groups; checked only when ``with_bias_grad`` is true.
    device : torch.device or str
        Device on which the tensor is allocated.
    with_bias_grad : bool
        When true, return an uninitialized float32 (G, K) tensor; otherwise
        return an empty float32 tensor used as a dummy pointer.

    Returns
    -------
    torch.Tensor
        float32 tensor, either (G, K) or empty.
    """
    if not with_bias_grad:
        # Return dummy pointer when bias_grad is not needed.
        # Must be float32 because atomic_add does not support bf16/fp16,
        # and Triton validates the pointer dtype even in dead branches.
        return torch.tensor([], device=device, dtype=torch.float32)

    assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."
    return torch.empty((G, K), device=device, dtype=torch.float32)
550
+
551
+
552
def gen_tgmm_tensors(
    M: int,
    K: int,
    N: int,
    G: int,
    num_group_sizes: int,
    device: torch.device | str = DEVICE,
    input_type: torch.dtype = DTYPE,
    output_type: torch.dtype = DTYPE,
    trans_lhs: bool = TRANS_LHS,
    trans_rhs: bool = False,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
    use_bias: bool = False,
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
    """Generate every tensor needed to exercise TGMM.

    Inputs come from ``gen_tgmm_input``, the group-size list from
    ``gen_multiple_group_sizes`` and the output from ``gen_tgmm_output``.
    ``trans_rhs`` is accepted but not used by this function.

    Returns
    -------
    tuple
        ``(lhs, rhs, multiple_group_sizes, out, bias_grad)`` where
        ``bias_grad`` is a (G, K) buffer when ``use_bias`` is true and
        ``None`` otherwise.
    """
    input_kwargs = dict(
        device=device,
        preferred_element_type=input_type,
        trans_lhs=trans_lhs,
        rng_seed=rng_seed,
        unif_group_sizes=unif_group_sizes,
    )
    lhs, rhs, first_group_sizes = gen_tgmm_input(M, K, N, G, **input_kwargs)

    all_group_sizes = gen_multiple_group_sizes(
        num_group_sizes,
        M,
        G,
        device=device,
        rng_seed=None,
        group_sizes_0=first_group_sizes,
    )
    out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type)
    bias_grad = (
        gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True)
        if use_bias
        else None
    )
    return lhs, rhs, all_group_sizes, out, bias_grad
587
+
588
+
589
+ # TGMM helpers: get information from tensors.
590
+ # ------------------------------------------------------------------------------
591
+
592
+
593
def get_tgmm_shape(
    lhs: Tensor, rhs: Tensor, group_sizes: Tensor
) -> tuple[int, int, int, int]:
    """Validate TGMM operand shapes and return the problem size.

    ``lhs`` is read as (K, M), ``rhs`` as (M, N) and ``group_sizes`` as (G,).
    The M extents must agree and every dimension must be positive.

    Returns
    -------
    tuple[int, int, int, int]
        ``(M, K, N, G)``.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
    assert (
        group_sizes.dim() == 1
    ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."

    (K, lhs_m), (rhs_m, N) = lhs.shape, rhs.shape
    G = group_sizes.size(0)

    # The shared (reduction) dimension must agree between operands.
    assert (
        lhs_m == rhs_m
    ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
    M = lhs_m

    # Empty dimensions are rejected.
    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    return M, K, N, G
617
+
618
+
619
def get_tgmm_output(
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
) -> Tensor:
    """Return a TGMM output tensor, reusing ``existing_out`` when provided.

    Parameters
    ----------
    K : int
        Rows of each per-group output matrix; must be positive.
    N : int
        Columns of each per-group output matrix; must be positive.
    G : int
        Number of groups; must be positive.
    device : torch.device or str
        Expected device of the output.
    preferred_element_type : torch.dtype
        Expected element type of the output.
    existing_out : torch.Tensor or None
        Pre-allocated output. When given, its device, dtype and shape are
        validated and the same tensor is returned; otherwise a new tensor is
        allocated through ``gen_tgmm_output``.

    Returns
    -------
    torch.Tensor
        Tensor of shape (G, K, N).
    """
    assert K > 0, f"Number of out rows K must be positive (K = {K})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if existing_out is None:
        return gen_tgmm_output(
            K,
            N,
            G,
            device=device,
            preferred_element_type=preferred_element_type,
        )

    # `device` may be a plain string, and `torch.device == str` is never true,
    # so normalize before comparing. A device given without an index (e.g.
    # "cuda") matches any index of that device type.
    expected_device = torch.device(device)
    actual_device = existing_out.device
    devices_match = actual_device.type == expected_device.type and (
        expected_device.index is None or actual_device.index == expected_device.index
    )
    assert (
        devices_match
    ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
    assert (
        existing_out.dtype == preferred_element_type
    ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
    assert existing_out.shape == (
        G,
        K,
        N,
    ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})."
    return existing_out
652
+
653
+
654
def get_tgmm_bias_grad(
    K: int,
    G: int,
    device: torch.device | str = DEVICE,
    existing_bias_grad: Tensor | None = None,
) -> Tensor:
    """
    Get or validate bias gradient tensor for TGMM.

    If existing_bias_grad is provided, validates its shape, device, dtype, and stride,
    and always zeros it in place before returning (since the kernel uses atomic_add).
    If existing_bias_grad is None, returns a dummy tensor (for use when COMPUTE_BIAS_GRAD=False).

    Parameters
    ----------
    K : int
        Number of columns of the (G, K) bias gradient tensor.
    G : int
        Number of groups (rows of the bias gradient tensor).
    device : torch.device or str
        Expected device of the tensor. A device given without an index
        (e.g. "cuda") matches any index of that device type.
    existing_bias_grad : torch.Tensor or None
        Existing bias gradient tensor to validate and use.

    Returns
    -------
    torch.Tensor
        Valid (zeroed) bias gradient tensor or dummy tensor.
    """
    assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if existing_bias_grad is None:
        return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)

    # Validate existing bias_grad tensor.
    expected_shape = (G, K)
    assert (
        tuple(existing_bias_grad.shape) == expected_shape
    ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}."

    # `device` may be a plain string, and `torch.device == str` is never true,
    # so normalize before comparing.
    expected_device = torch.device(device)
    actual_device = existing_bias_grad.device
    devices_match = actual_device.type == expected_device.type and (
        expected_device.index is None or actual_device.index == expected_device.index
    )
    assert (
        devices_match
    ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})."
    assert (
        existing_bias_grad.dtype == torch.float32
    ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}."
    assert existing_bias_grad.stride() == (
        K,
        1,
    ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."

    # Always zero the tensor since bias_grad represents gradients for the current
    # computation and should start fresh. The kernel uses atomic_add which adds to
    # existing values, so we must zero before the kernel runs.
    existing_bias_grad.zero_()

    return existing_bias_grad
710
+
711
+
712
def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
    """Infer the lhs memory layout of a TGMM problem from tensor strides.

    ``lhs`` is read as (K, M), ``rhs`` as (M, N) and ``out`` as (G, K, N).
    ``rhs`` and ``out`` must be row-major; ``lhs`` may be row-major
    (strides ``(M, 1)``) or column-major (strides ``(1, K)``).

    Returns
    -------
    tuple[bool, int]
        ``(trans_lhs, ld_lhs)``: whether lhs is column-major, and its leading
        dimension (``K`` when column-major, ``M`` otherwise).
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
    assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."

    lhs_k, lhs_m = lhs.shape
    rhs_m, rhs_n = rhs.shape
    G, out_k, out_n = out.shape

    assert (
        lhs_m == rhs_m
    ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
    M = lhs_m
    assert (
        lhs_k == out_k
    ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, out = {out_k})."
    K = lhs_k
    assert (
        rhs_n == out_n
    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
    N = rhs_n

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    lhs_stride = lhs.stride()
    is_lhs_row_major = lhs_stride == (M, 1)
    is_lhs_col_major = lhs_stride == (1, K)
    # When M == K == 1 both stride patterns coincide; that layout is valid
    # and is treated as row-major rather than rejected.
    assert (
        is_lhs_row_major or is_lhs_col_major
    ), "lhs must be row-major or column-major."
    is_rhs_row_major = rhs.stride() == (N, 1)
    assert is_rhs_row_major, "rhs must be row-major."
    is_out_row_major = out.stride() == (K * N, N, 1)
    assert is_out_row_major, "out must be row-major."

    trans_lhs = is_lhs_col_major and not is_lhs_row_major

    # Get lhs leading dimension according to transposition configuration.
    ld_lhs = K if trans_lhs else M

    return trans_lhs, ld_lhs
build/torch211-cxx11-cu128-x86_64-linux/_grouped_gemm_triton/utils/logger.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+
5
# AITER Triton Logger which is singleton object around python logging.
# Note: Python logging is also a singleton object, but we want to read the
# env var AITER_TRITON_LOG_LEVEL once at the beginning. Another alternative is
# to do this in __init__.py. In fact, that's how CK logger is setup. We can
# look at switching to that at some point
#
# AITER_TRITON_LOG_LEVEL follows python logging levels
# DEBUG
# INFO
# WARNING
# ERROR
# CRITICAL
#
# Unset or unrecognized values fall back to WARNING.
class AiterTritonLogger(object):
    """Singleton wrapper around the ``"AITER_TRITON"`` ``logging`` logger.

    The first instantiation reads ``AITER_TRITON_LOG_LEVEL`` (case-insensitive)
    and sets the logger level; every later instantiation returns the same
    object without re-reading the environment.
    """

    # The single shared instance, created lazily by __new__.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super(AiterTritonLogger, cls).__new__(cls)
            log_level_str = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").upper()
            numeric_level = getattr(logging, log_level_str, None)
            # getattr can resolve to a non-level attribute of the logging
            # module (e.g. the string BASIC_FORMAT), which would make
            # setLevel raise; only genuine integer levels are accepted.
            if isinstance(numeric_level, bool) or not isinstance(numeric_level, int):
                numeric_level = logging.WARNING
            instance._logger = logging.getLogger("AITER_TRITON")
            instance._logger.setLevel(numeric_level)
            # Publish the instance only once it is fully configured.
            cls._instance = instance

        return cls._instance

    def get_logger(self):
        """Return the underlying ``logging.Logger``."""
        return self._logger

    def debug(self, msg):
        """Log ``msg`` at DEBUG level."""
        self._logger.debug(msg)

    def info(self, msg):
        """Log ``msg`` at INFO level."""
        self._logger.info(msg)

    def warning(self, msg):
        """Log ``msg`` at WARNING level."""
        self._logger.warning(msg)

    def error(self, msg):
        """Log ``msg`` at ERROR level."""
        self._logger.error(msg)

    def critical(self, msg):
        """Log ``msg`` at CRITICAL level."""
        self._logger.critical(msg)
build/torch211-cxx11-cu128-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce5790e025e92878a33c9a766bca1cda450c920f68f49549525413c7e754c100
3
- size 19082856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ea3f6a68cbc730572a4a4c8d3814a2075cc775bffcf3082c9dbd6291e888555
3
+ size 19750504
build/torch211-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_ae601bb
3
- ops = torch.ops._megablocks_cuda_ae601bb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_ae601bb::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_f8f8b50
3
+ ops = torch.ops._megablocks_cuda_f8f8b50
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_f8f8b50::{op_name}"
build/torch211-cxx11-cu128-x86_64-linux/grouped_gemm/backend.py CHANGED
@@ -2,16 +2,16 @@
2
  # extensions. Otherwise libc10.so cannot be found.
3
  import torch
4
 
5
- # # TODO(tgale): Wrap this in a try-block with better
6
- # # error message and instructions for building the
7
- # # c++ operations.
8
- # import grouped_gemm_backend as backend
9
 
10
- # We import the backend operations from the megablocks package as
11
- # grouped_gemm is vendored in megablocks in this repository.
12
- # from ... import _ops as backend
13
- # from megablocks._ops import ops as backend # type: ignore
14
- from .._ops import ops as backend # type: ignore
 
15
 
16
  def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
  assert not (trans_a and trans_b)
 
2
  # extensions. Otherwise libc10.so cannot be found.
3
  import torch
4
 
5
+ # On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER
6
+ # Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op.
7
+ _IS_ROCM = torch.version.hip is not None
 
8
 
9
+ if _IS_ROCM:
10
+ from .._grouped_gemm_triton import adapter as backend
11
+ else:
12
+ # We import the backend operations from the megablocks package as
13
+ # grouped_gemm is vendored in megablocks in this repository.
14
+ from .._ops import ops as backend # type: ignore
15
 
16
  def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
  assert not (trans_a and trans_b)
build/torch211-cxx11-cu128-x86_64-linux/metadata.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "name": "megablocks",
3
- "id": "_megablocks_cuda_ae601bb",
4
  "version": 1,
5
  "license": "Apache-2.0",
6
  "python-depends": [],
@@ -10,6 +10,7 @@
10
  "10.0",
11
  "10.1",
12
  "12.0",
 
13
  "7.0",
14
  "7.2",
15
  "7.5",
 
1
  {
2
  "name": "megablocks",
3
+ "id": "_megablocks_cuda_f8f8b50",
4
  "version": 1,
5
  "license": "Apache-2.0",
6
  "python-depends": [],
 
10
  "10.0",
11
  "10.1",
12
  "12.0",
13
+ "12.0+PTX",
14
  "7.0",
15
  "7.2",
16
  "7.5",
build/torch211-cxx11-cu130-x86_64-linux/__init__.py CHANGED
@@ -3,7 +3,9 @@
3
 
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
  from .grouped_gemm import backend as gg_backend
9
  from .grouped_gemm import ops as gg_ops
@@ -136,7 +138,8 @@ def sort(
136
  Returns:
137
  The sorted values tensor
138
  """
139
- return ops.sort(x, end_bit, x_out, iota_out)
 
140
 
141
 
142
  # Convenience functions for common use cases
 
3
 
4
  import torch
5
 
6
+ # Stable alias: bare `ops` is shadowed by `from . import layers` below.
7
+ from ._ops import ops as _compiled_ops
8
+ from . import ops
9
 
10
  from .grouped_gemm import backend as gg_backend
11
  from .grouped_gemm import ops as gg_ops
 
138
  Returns:
139
  The sorted values tensor
140
  """
141
+ _compiled_ops.sort(x, end_bit, x_out, iota_out)
142
+ return x_out
143
 
144
 
145
  # Convenience functions for common use cases
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/__init__.py ADDED
File without changes
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/__init__.py ADDED
File without changes
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/_triton_kernels/gmm.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+
5
+ # Imports.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ # Python standard library
9
+ import functools
10
+
11
+ # Triton
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ # AITER
16
+ from ..configs import CONFIGS as _CONFIGS
17
+ from ..utils._triton import arch_info
18
+ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
19
+
20
+ # Kernel config.
21
+ # ------------------------------------------------------------------------------
22
+
23
+
24
@functools.lru_cache()
def get_config(
    gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
) -> dict[str, int]:
    """Look up the kernel configuration for a GMM variant on the current arch.

    Parameters
    ----------
    gmm_type : str
        One of ``"gmm"``, ``"ptgmm"`` or ``"nptgmm"``.
    M, K, N, G : int
        Problem dimensions. They do not influence the lookup here, but they
        are part of the ``lru_cache`` key.
    accumulate : bool
        When true, prefer the ``"accumulate"`` configuration; if the arch has
        no such entry the ``"default"`` configuration is used instead.

    Returns
    -------
    dict[str, int]
        The configuration dict stored in ``CONFIGS`` (the same object is
        returned on every call, so callers should treat it as read-only).
    """
    assert gmm_type in {
        "gmm",
        "ptgmm",
        "nptgmm",
    }, f"'{gmm_type}' is an invalid GMM variant."
    dev = arch_info.get_arch()
    assert (
        dev in _CONFIGS
    ), f"No GMM configuration tuned for arch '{dev}'. Supported: {sorted(_CONFIGS)}."
    variant_configs = _CONFIGS[dev][gmm_type]
    assert "default" in variant_configs, "Default configuration is absent."
    default_config = variant_configs["default"]
    if accumulate:
        # Only "default" is guaranteed to exist; fall back to it rather than
        # raising KeyError when no accumulate-specific tuning is available.
        return variant_configs.get("accumulate", default_config)
    return default_config
43
+
44
+
45
+ # Common code shared by GMM and TGMM kernels.
46
+ # ------------------------------------------------------------------------------
47
+
48
+
49
+ # XCD remapping followed by 1D PID to 2D grid mapping.
50
@triton.jit
def _remap_xcd_tile_grid(
    tile_in_mm,
    num_row_tiles,
    num_col_tiles,
    GROUP_SIZE: tl.constexpr = 1,
    NUM_XCDS: tl.constexpr = 8,
):
    # Two-step mapping of a 1D tile index to 2D tile coordinates:
    #   1. `remap_xcd` remaps the index over the whole tile grid
    #      (num_row_tiles * num_col_tiles entries) for NUM_XCDS XCDs;
    #   2. `pid_grid` converts the remapped index into a (row, col) pair,
    #      with GROUP_SIZE forwarded as its GROUP_SIZE_M argument.
    total_tiles = num_row_tiles * num_col_tiles
    remapped_tile = remap_xcd(tile_in_mm, total_tiles, NUM_XCDS=NUM_XCDS)
    return pid_grid(
        remapped_tile,
        num_row_tiles,
        num_col_tiles,
        GROUP_SIZE_M=GROUP_SIZE,
    )
64
+
65
+
66
+ # GMM kernel.
67
+ # ------------------------------------------------------------------------------
68
+
69
+
70
@triton.heuristics(
    {
        "K_DIVISIBLE_BY_BLOCK_SIZE_K": lambda META: META["K"] % META["BLOCK_SIZE_K"]
        == 0,
    }
)
@triton.jit
def gmm_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_RHS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    GRID_DIM: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    # Grouped matrix multiplication kernel.
    #
    # For each group g, the rows of lhs belonging to that group (their count
    # is read from group_sizes_ptr[g]) are multiplied by the g-th (K, N) slab
    # of rhs and written to the same rows of out. When USE_BIAS is set, row g
    # of bias (N values) is added to every output row of group g.
    #
    # Addressing, as used below: lhs and out are indexed with row pitch K and
    # N respectively; rhs slab g starts at g * K * N and is indexed as
    # k * N + n, or as k + n * K when TRANS_RHS is set.
    #
    # Each program starts at tile `program_id(0)` and advances by GRID_DIM
    # tiles at a time, so GRID_DIM is presumably the launch grid size
    # (NOTE(review): confirm against the host-side launcher).
    #
    # K_DIVISIBLE_BY_BLOCK_SIZE_K is filled in by the heuristic above and
    # selects the unmasked load path along K.
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Current tile. Each program computes multiple tiles of each group.
    tile = tl.program_id(0)
    tl.device_assert(tile >= 0, "tile < 0 (at initialization)")

    # Tile limit of last MM problem (inclusive).
    last_mm_tile = 0

    # Last input row of lhs and output row of out. Each group reads some rows of
    # lhs and writes some rows to out.
    last_m = 0

    # Loop through all (m, K, N) MM problems:
    #     (m, K) x (K, N) = (m, N)
    #     sum(m) = M
    for g in range(G):
        # Get m dimension of current MM problem.
        m = tl.load(group_sizes_ptr + g)
        # m can be zero if group is empty
        tl.device_assert(m >= 0, "m < 0")

        num_m_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        # num_m_tiles can be zero if group is empty
        tl.device_assert(num_m_tiles >= 0, "num_m_tiles < 0")

        num_tiles = num_m_tiles * num_n_tiles
        # num_tiles can be zero if group is empty
        tl.device_assert(num_tiles >= 0, "num_tiles < 0")

        # Loop through tiles of current MM problem.
        # (For an empty group num_tiles is 0, so the body never runs.)
        while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
            # Figure out tile coordinates in current MM problem.
            tile_in_mm = tile - last_mm_tile
            tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

            tile_m, tile_n = _remap_xcd_tile_grid(
                tile_in_mm, num_m_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
            )

            # Do regular MM:

            tl.device_assert(tile_m * BLOCK_SIZE_M >= 0, "tile_m * BLOCK_SIZE_M < 0")
            tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

            # Row/column offsets are wrapped with `% m` / `% N`, so the loads
            # below need no mask along these axes; rows/columns beyond the
            # problem edge are discarded by the store mask further down.
            offs_lhs_m = (
                tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            ) % m
            offs_rhs_n = (
                tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            ) % N
            offs_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)

            # last_m is the first lhs/out row of the current group.
            lhs_ptrs = lhs_ptr + (last_m + offs_lhs_m[:, None]) * K + offs_k[None, :]

            if TRANS_RHS:
                rhs_ptrs = (
                    rhs_ptr
                    + g.to(tl.int64) * K * N
                    + offs_k[:, None]
                    + offs_rhs_n[None, :] * K
                )
            else:
                rhs_ptrs = (
                    rhs_ptr
                    + g.to(tl.int64) * K * N
                    + offs_k[:, None] * N
                    + offs_rhs_n[None, :]
                )

            # Accumulate in float32 regardless of the input/output dtype.
            acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

            for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
                if K_DIVISIBLE_BY_BLOCK_SIZE_K:
                    lhs = tl.load(lhs_ptrs)
                    rhs = tl.load(rhs_ptrs)
                else:
                    # Number of valid K elements remaining from this step on;
                    # elements at or past it are loaded as 0.
                    k_mask_limit = K - k * BLOCK_SIZE_K
                    lhs = tl.load(
                        lhs_ptrs, mask=offs_k[None, :] < k_mask_limit, other=0
                    )
                    rhs = tl.load(
                        rhs_ptrs, mask=offs_k[:, None] < k_mask_limit, other=0
                    )

                acc = tl.dot(lhs, rhs, acc=acc)

                # Advance both operands by one K block.
                lhs_ptrs += BLOCK_SIZE_K

                if TRANS_RHS:
                    rhs_ptrs += BLOCK_SIZE_K
                else:
                    rhs_ptrs += BLOCK_SIZE_K * N

            # Add bias if enabled
            if USE_BIAS:
                offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
                    0, BLOCK_SIZE_N
                )
                bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
                bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
                # Convert bias to float32 to match accumulator precision
                bias = bias.to(tl.float32)
                # Broadcast bias across M dimension and add in float32
                acc += bias[None, :]

            # Convert to output dtype after all computations
            acc = acc.to(out_ptr.type.element_ty)

            # Unwrapped offsets (no modulo) so the store mask can reject the
            # rows/columns that fall outside the (m, N) problem.
            offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            out_ptrs = (
                out_ptr + (last_m + offs_out_m[:, None]) * N + offs_out_n[None, :]
            )

            tl.store(
                out_ptrs,
                acc,
                mask=(offs_out_m[:, None] < m) & (offs_out_n[None, :] < N),
            )

            # Go to the next tile by advancing number of programs.
            tile += GRID_DIM
            tl.device_assert(tile > 0, "tile <= 0 (at update)")

        # Get ready to go to the next MM problem.

        last_mm_tile += num_tiles
        # last_mm_tile can be zero if group 0 is skipped
        tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")

        last_m += m
        # last_m can be zero if group 0 is skipped
        tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
        tl.device_assert(last_m <= M, "last_m > M (at update)")
241
+
242
+
243
+ # Persistent TGMM kernel.
244
+ # ------------------------------------------------------------------------------
245
+
246
+
247
+ @triton.jit
248
+ def tgmm_persistent_kernel(
249
+ # Tensor pointers:
250
+ lhs_ptr,
251
+ rhs_ptr,
252
+ group_sizes_ptr,
253
+ out_ptr,
254
+ bias_grad_ptr,
255
+ # Tensor shapes:
256
+ M: int,
257
+ K: int,
258
+ N: int,
259
+ G: int,
260
+ # Meta-parameters:
261
+ TRANS_LHS: tl.constexpr,
262
+ BLOCK_SIZE_M: tl.constexpr,
263
+ BLOCK_SIZE_K: tl.constexpr,
264
+ BLOCK_SIZE_N: tl.constexpr,
265
+ GROUP_SIZE: tl.constexpr,
266
+ GRID_DIM: tl.constexpr,
267
+ COMPUTE_BIAS_GRAD: tl.constexpr,
268
+ ACCUMULATE: tl.constexpr,
269
+ ):
270
+ tl.assume(M > 0)
271
+ tl.assume(K > 0)
272
+ tl.assume(N > 0)
273
+ tl.assume(G > 0)
274
+
275
+ num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
276
+ tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")
277
+
278
+ num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
279
+ tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")
280
+
281
+ num_tiles = num_k_tiles * num_n_tiles
282
+ tl.device_assert(num_tiles > 0, "num_tiles <= 0")
283
+
284
+ # Current tile. Each program computes multiple tiles of each group.
285
+ tile = tl.program_id(0)
286
+ tl.device_assert(tile >= 0, "tile < 0 (at initialization)")
287
+
288
+ # Tile limit of last MM problem (inclusive).
289
+ last_mm_tile = 0
290
+
291
+ # Last input column of lhs and input row of rhs. Each group reads some
292
+ # columns of lhs and some rows of rhs.
293
+ last_m = 0
294
+
295
+ # Loop through all (K, m, N) MM problems:
296
+ # (K, m) x (m, N) = (K, N)
297
+ # sum(m) = M
298
+ for g in range(G):
299
+ # Get m dimension of current MM problem.
300
+ m = tl.load(group_sizes_ptr + g)
301
+ # m can be zero if group is empty
302
+ tl.device_assert(m >= 0, "m < 0")
303
+
304
+ # Loop through tiles of current MM problem.
305
+ while tile >= last_mm_tile and tile < last_mm_tile + num_tiles:
306
+ # Figure out tile coordinates in current MM problem.
307
+ tile_in_mm = tile - last_mm_tile
308
+ tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")
309
+
310
+ tile_k, tile_n = _remap_xcd_tile_grid(
311
+ tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
312
+ )
313
+
314
+ # Do regular MM:
315
+
316
+ tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
317
+ tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")
318
+
319
+ offs_lhs_k = (
320
+ tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
321
+ ) % K
322
+ offs_rhs_n = (
323
+ tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
324
+ ) % N
325
+ offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
326
+
327
+ if TRANS_LHS:
328
+ lhs_ptrs = (
329
+ lhs_ptr + offs_lhs_k[:, None] + (last_m + offs_m[None, :]) * K
330
+ )
331
+ else:
332
+ lhs_ptrs = (
333
+ lhs_ptr + offs_lhs_k[:, None] * M + (last_m + offs_m[None, :])
334
+ )
335
+
336
+ rhs_ptrs = rhs_ptr + (last_m + offs_m[:, None]) * N + offs_rhs_n[None, :]
337
+
338
+ loop_m = tl.cdiv(m, BLOCK_SIZE_M)
339
+ m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
340
+ if not m_divisible_by_block_m:
341
+ loop_m -= 1
342
+
343
+ acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
344
+
345
+ # Initialize bias accumulator
346
+ bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
347
+
348
+ for _ in range(0, loop_m):
349
+ lhs = tl.load(lhs_ptrs)
350
+ rhs = tl.load(rhs_ptrs)
351
+
352
+ acc = tl.dot(lhs, rhs, acc=acc)
353
+
354
+ # Accumulate for bias gradient: sum lhs across M dimension
355
+ if COMPUTE_BIAS_GRAD and tile_n == 0:
356
+ bias_acc += tl.sum(
357
+ lhs, axis=1
358
+ ) # Sum across M dimension [K, M] -> [K]
359
+
360
+ if TRANS_LHS:
361
+ lhs_ptrs += BLOCK_SIZE_M * K
362
+ else:
363
+ lhs_ptrs += BLOCK_SIZE_M
364
+
365
+ rhs_ptrs += BLOCK_SIZE_M * N
366
+
367
+ if not m_divisible_by_block_m:
368
+ offs_lhs_k = (
369
+ tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
370
+ ) % K
371
+ offs_rhs_n = (
372
+ tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
373
+ ) % N
374
+ offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
375
+ lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
376
+ rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
377
+ acc = tl.dot(lhs, rhs, acc=acc)
378
+
379
+ # Accumulate last chunk for bias gradient
380
+ if COMPUTE_BIAS_GRAD and tile_n == 0:
381
+ bias_acc += tl.sum(lhs, axis=1)
382
+
383
+ acc = acc.to(out_ptr.type.element_ty)
384
+
385
+ offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
386
+ offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
387
+
388
+ out_ptrs = (
389
+ out_ptr
390
+ + g.to(tl.int64) * K * N
391
+ + offs_out_k[:, None] * N
392
+ + offs_out_n[None, :]
393
+ )
394
+
395
+ mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
396
+ if ACCUMULATE:
397
+ # Load existing values and add to them (like beta=1 in BLAS)
398
+ old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
399
+ tl.store(out_ptrs, acc + old_vals, mask=mask)
400
+ else:
401
+ # Overwrite output (like beta=0 in BLAS)
402
+ tl.store(out_ptrs, acc, mask=mask)
403
+
404
+ # Store bias gradient (only for first N tile, sum across all M)
405
+ if COMPUTE_BIAS_GRAD and tile_n == 0:
406
+ # Keep as float32 for atomic_add (bf16 not supported for atomics)
407
+ bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
408
+ # Use atomic add since multiple K-tiles may write to same expert's bias
409
+ tl.atomic_add(
410
+ bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
411
+ )
412
+
413
+ # Go to the next tile by advancing number of programs.
414
+ tile += GRID_DIM
415
+ tl.device_assert(tile > 0, "tile <= 0 (at update)")
416
+
417
+ # Get ready to go to the next MM problem.
418
+
419
+ last_mm_tile += num_tiles
420
+ # last_mm_tile can be zero if group 0 is skipped
421
+ tl.device_assert(last_mm_tile >= 0, "last_mm_tile < 0 (at update)")
422
+
423
+ last_m += m
424
+ # last_m can be zero if group 0 is skipped
425
+ tl.device_assert(last_m >= 0, "last_m < 0 (at update)")
426
+ tl.device_assert(last_m <= M, "last_m > M (at update)")
427
+
428
+
429
+ # Regular non-persistent TGMM kernel.
430
+ # ------------------------------------------------------------------------------
431
+
432
+
433
@triton.heuristics({"BLOCK_SIZE_G": lambda META: triton.next_power_of_2(META["G"])})
@triton.jit
def tgmm_non_persistent_kernel(
    # Tensor pointers:
    lhs_ptr,
    rhs_ptr,
    group_sizes_ptr,
    out_ptr,
    bias_grad_ptr,
    # Tensor shapes:
    M: int,
    K: int,
    N: int,
    G: int,
    # Meta-parameters:
    TRANS_LHS: tl.constexpr,
    BLOCK_SIZE_G: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    COMPUTE_BIAS_GRAD: tl.constexpr,
    ACCUMULATE: tl.constexpr,
):
    """
    Non-persistent TGMM kernel: one program computes one (K, N) output tile of one group.

    For group ``g`` the program accumulates, over the ``m = group_sizes[g]`` entries of the
    M dimension that belong to the group, the product of a (BLOCK_SIZE_K, BLOCK_SIZE_M) block
    of lhs with a (BLOCK_SIZE_M, BLOCK_SIZE_N) block of rhs, and writes the result into the
    ``g``-th (K, N) plane of out.

    Grid layout
    -----------
    - ``program_id(0)`` is the group index ``g``.
    - ``program_id(1)`` is the tile index inside the group's output plane. It is mapped to
      ``(tile_k, tile_n)`` by ``_remap_xcd_tile_grid``.

    Addressing (element offsets, as computed below)
    -----------------------------------------------
    - lhs: ``k * M + m`` when ``TRANS_LHS`` is false, ``k + m * K`` when it is true.
    - rhs: ``m * N + n``.
    - out: ``g * K * N + k * N + n``.
    - bias_grad: ``g * K + k``.

    Meta-parameters
    ---------------
    - ``BLOCK_SIZE_G``: next power of 2 of ``G`` (set by the heuristic), used to load the
      group sizes in one vector.
    - ``GROUP_SIZE``: forwarded unchanged to ``_remap_xcd_tile_grid``.
    - ``COMPUTE_BIAS_GRAD``: when true, programs with ``tile_n == 0`` also sum lhs over the M
      dimension and atomically add the float32 result to ``bias_grad``.
    - ``ACCUMULATE``: when true, the tile is added to the values already in ``out``.
      Otherwise ``out`` is overwritten.

    Empty groups (``m == 0``) contribute an all-zero product. When accumulating there is
    nothing to add, so the program exits early. When overwriting, the zero tile is still
    stored so that ``out[g]`` never keeps stale or uninitialized values, matching what the
    persistent kernel does for empty groups.
    """
    tl.assume(M > 0)
    tl.assume(K > 0)
    tl.assume(N > 0)
    tl.assume(G > 0)

    # Get group ID from grid.
    g = tl.program_id(0)
    tl.device_assert(g >= 0, "g < 0")
    tl.device_assert(g < G, "g >= G")

    # Get m dimension of current MM group.
    m = tl.load(group_sizes_ptr + g)
    # m can be zero if group is empty.
    tl.device_assert(m >= 0, "m < 0")

    # Skip empty groups only when accumulating (adding a zero tile is a no-op).
    # In overwrite mode fall through: the M loop below runs zero times, the tail
    # block is skipped (0 % BLOCK_SIZE_M == 0) and the zero accumulator is stored.
    if ACCUMULATE and m == 0:
        return

    # Compute sum(group_sizes) until current group g.
    # It's the starting column of lhs and starting row of rhs.
    # The mask keeps only the sizes of groups that precede g.
    offs_g = tl.arange(0, BLOCK_SIZE_G)
    group_sizes = tl.load(group_sizes_ptr + offs_g, mask=offs_g < g, other=0)
    start_m = tl.sum(group_sizes)

    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tl.device_assert(num_k_tiles > 0, "num_k_tiles <= 0")

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tl.device_assert(num_n_tiles > 0, "num_n_tiles <= 0")

    # Get MM tile from grid.
    tile_in_mm = tl.program_id(1)
    tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0")

    tile_k, tile_n = _remap_xcd_tile_grid(
        tile_in_mm, num_k_tiles, num_n_tiles, GROUP_SIZE=GROUP_SIZE
    )

    tl.device_assert(tile_k * BLOCK_SIZE_K >= 0, "tile_k * BLOCK_SIZE_K < 0")
    tl.device_assert(tile_n * BLOCK_SIZE_N >= 0, "tile_n * BLOCK_SIZE_N < 0")

    # Input offsets wrap around K / N so the unmasked loads in the main loop stay
    # inside the tensors; the out-of-range part of the tile is masked at store time.
    offs_lhs_k = (tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
    offs_rhs_n = (tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)

    if TRANS_LHS:
        lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] + (start_m + offs_m[None, :]) * K
    else:
        lhs_ptrs = lhs_ptr + offs_lhs_k[:, None] * M + (start_m + offs_m[None, :])

    rhs_ptrs = rhs_ptr + (start_m + offs_m[:, None]) * N + offs_rhs_n[None, :]

    # Number of full BLOCK_SIZE_M steps. A trailing partial step, if any, is
    # handled separately below with masked loads.
    loop_m = tl.cdiv(m, BLOCK_SIZE_M)
    m_divisible_by_block_m = m % BLOCK_SIZE_M == 0
    if not m_divisible_by_block_m:
        loop_m -= 1

    acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
    # Initialize bias accumulator
    bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)

    for _ in range(0, loop_m):
        lhs = tl.load(lhs_ptrs)
        rhs = tl.load(rhs_ptrs)

        acc = tl.dot(lhs, rhs, acc=acc)

        # Accumulate for bias gradient: sum lhs across M dimension
        if COMPUTE_BIAS_GRAD and tile_n == 0:
            bias_acc += tl.sum(lhs, axis=1)  # [K, M] -> [K]

        # Advance both operands by one block along the M dimension.
        if TRANS_LHS:
            lhs_ptrs += BLOCK_SIZE_M * K
        else:
            lhs_ptrs += BLOCK_SIZE_M

        rhs_ptrs += BLOCK_SIZE_M * N

    if not m_divisible_by_block_m:
        # Last, partial block: entries at or beyond the end of the group load as 0.
        offs_m = loop_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
        rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
        acc = tl.dot(lhs, rhs, acc=acc)
        # Accumulate last chunk for bias gradient
        if COMPUTE_BIAS_GRAD and tile_n == 0:
            bias_acc += tl.sum(lhs, axis=1)

    acc = acc.to(out_ptr.type.element_ty)

    # Output offsets are NOT wrapped: out-of-range rows/columns are masked instead.
    offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_out_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    out_ptrs = (
        out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
    )

    mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
    if ACCUMULATE:
        # Load existing values and add to them (like beta=1 in BLAS)
        old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
        tl.store(out_ptrs, acc + old_vals, mask=mask)
    else:
        # Overwrite output (like beta=0 in BLAS)
        tl.store(out_ptrs, acc, mask=mask)

    # Store bias gradient (only for first N tile, sum across all M)
    if COMPUTE_BIAS_GRAD and tile_n == 0:
        # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
        bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
        # Use atomic add since multiple K-tiles may write to same expert's bias
        tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/adapter.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Adapt AITER's Triton grouped GEMM to MegaBlocks' ``gmm`` calling convention.
3
+
4
+ MegaBlocks (following tgale96/grouped_gemm) uses a single ``gmm`` entry point
5
+ with ``trans_a`` / ``trans_b`` flags:
6
+
7
+ * ``trans_a=False, trans_b=False``: a(M,K) @ b(G,K,N) -> c(M,N)
8
+ * ``trans_a=False, trans_b=True`` : a(M,K) @ b(G,N,K)^T -> c(M,N) (dgrad)
9
+ * ``trans_a=True`` : a(M,K)^T @ b(M,N) per group -> c(G,K,N) (wgrad)
10
+
11
+ AITER exposes these as two kernels: ``gmm`` ((M,K)@(G,K,N)->(M,N), transposition
12
+ of the 3D operand inferred from strides) and ``ptgmm`` ((K,M)@(M,N)->(G,K,N),
13
+ transposition of the 2D operand inferred from strides).
14
+ """
15
+
16
+ import torch
17
+
18
+ from .gmm import gmm as _aiter_gmm
19
+ from .gmm import ptgmm as _aiter_ptgmm
20
+
21
+
22
def gmm(a, b, c, batch_sizes, trans_a=False, trans_b=False):
    """Run a MegaBlocks-style grouped GEMM on top of the AITER Triton kernels.

    Dispatch, as implemented below:

    * ``trans_a=True``: calls AITER ``ptgmm`` with ``a`` transposed, ``b`` as is.
      ``trans_b`` is not consulted on this path.
    * ``trans_a=False``: calls AITER ``gmm``; when ``trans_b`` is set the 3D operand
      is passed as a ``transpose(1, 2)`` view.

    In both cases the result is written into ``c`` (passed as ``existing_out``, with
    ``c.dtype`` as the preferred element type) and ``c`` is returned.
    """
    # AITER requires group sizes to be int32 and to live on the compute device.
    group_sizes = batch_sizes.to(device=a.device, dtype=torch.int32)

    # AITER asserts exact strides, so every operand is made contiguous before any
    # transposed view is taken; the view then has exactly the row-/column-major
    # strides the kernels look for. `.contiguous()` costs nothing when the tensor
    # is already contiguous.
    if trans_a:
        # Weight gradient: a(M,K), b(M,N) -> c(G,K,N).
        # The transposed view makes AITER see lhs(K,M) column-major (TRANS_LHS).
        lhs_km = a.contiguous().transpose(0, 1)
        _aiter_ptgmm(
            lhs_km,
            b.contiguous(),
            group_sizes,
            preferred_element_type=c.dtype,
            existing_out=c,
        )
        return c

    # trans_b contracts b's last dim: pass a column-major (G,K,N) view.
    weights = b.contiguous()
    rhs = weights.transpose(1, 2) if trans_b else weights
    _aiter_gmm(
        a.contiguous(),
        rhs,
        group_sizes,
        preferred_element_type=c.dtype,
        existing_out=c,
    )
    return c
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/configs.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Tuned GMM configs vendored from ROCm/aiter (aiter/ops/triton/configs/).
3
+ # Inlined as a Python module so packaging always includes them.
4
+
5
# Layout: CONFIGS[arch][kernel][variant] -> Triton launch/meta-parameters.
# "gmm" and "ptgmm" entries carry GRID_DIM; "nptgmm" entries do not.
CONFIGS = {
    "gfx1250": {
        "gmm": {
            "default": {
                "BLOCK_SIZE_M": 256,
                "BLOCK_SIZE_K": 64,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "GRID_DIM": 256,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
        "ptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 256,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "GRID_DIM": 256,
                "num_warps": 8,
                "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 128,
                "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1,
                "GRID_DIM": 256,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
        "nptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 256,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "num_warps": 8,
                "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 128,
                "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
    },
    "gfx942": {
        "gmm": {
            "default": {
                "BLOCK_SIZE_M": 256,
                "BLOCK_SIZE_K": 64,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "GRID_DIM": 304,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
        "ptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 256,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "GRID_DIM": 304,
                "num_warps": 8,
                "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 128,
                "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1,
                "GRID_DIM": 304,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
        "nptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 256,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "num_warps": 8,
                "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 128,
                "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
    },
    "gfx950": {
        "gmm": {
            "default": {
                "BLOCK_SIZE_M": 256,
                "BLOCK_SIZE_K": 64,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "GRID_DIM": 256,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
        "ptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 256,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "GRID_DIM": 256,
                "num_warps": 8,
                "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 128,
                "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1,
                "GRID_DIM": 256,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
        "nptgmm": {
            "default": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 256,
                "BLOCK_SIZE_N": 256,
                "GROUP_SIZE": 1,
                "num_warps": 8,
                "num_stages": 1,
            },
            "accumulate": {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_K": 128,
                "BLOCK_SIZE_N": 128,
                "GROUP_SIZE": 1,
                "num_warps": 8,
                "num_stages": 1,
            },
        },
    },
}
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/gmm.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+
5
+ # Imports.
6
+ # ------------------------------------------------------------------------------
7
+
8
+ # PyTorch
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ # Triton
13
+ import triton
14
+
15
+ # AITER: GMM utility functions
16
+ from .utils.gmm_common import (
17
+ DTYPE,
18
+ is_power_of_2,
19
+ check_input_device_dtype,
20
+ check_bias_shape_stride,
21
+ get_gmm_shape,
22
+ get_gmm_output,
23
+ get_gmm_transposition,
24
+ get_tgmm_shape,
25
+ get_tgmm_output,
26
+ get_tgmm_bias_grad,
27
+ get_tgmm_transposition,
28
+ )
29
+
30
+ # AITER: GMM Triton kernels
31
+ from ._triton_kernels.gmm import (
32
+ gmm_kernel,
33
+ tgmm_persistent_kernel,
34
+ tgmm_non_persistent_kernel,
35
+ get_config,
36
+ )
37
+
38
+ # GMM PyTorch wrapper.
39
+ # ------------------------------------------------------------------------------
40
+
41
+
42
def _gmm_grid(
    N: int,
    block_size_m: int,
    block_size_n: int,
    group_sizes: Tensor,
    grid_dim: int,
) -> tuple[int]:
    """
    Compute the 1D launch grid of the persistent GMM kernel.

    The total number of output tiles is the sum, over all groups, of
    ``ceil(group_size / block_size_m) * ceil(N / block_size_n)``. The grid has one program
    per tile, capped at ``grid_dim``.

    Note that the checks on ``group_sizes`` and the tile count call ``.item()``, i.e. they
    read tensor values back to the host.
    """
    # Argument validation.
    assert N > 0, f"N must be positive, it's {N}."
    assert is_power_of_2(block_size_m), (
        f"M-dimension tile size must be a power of 2 (it's {block_size_m})."
    )
    assert is_power_of_2(block_size_n), (
        f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    )
    assert (group_sizes >= 0).all().item(), "All group_sizes must be non-negative."
    assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."

    # Per-group number of tiles along M (ceiling division, element-wise).
    num_m_tiles = (group_sizes + block_size_m - 1) // block_size_m
    assert (num_m_tiles >= 0).all().item(), "All num_m_tiles must be non-negative."

    # Number of tiles along N is the same for every group.
    num_n_tiles = triton.cdiv(N, block_size_n)
    assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."

    num_tiles = (num_m_tiles * num_n_tiles).sum().item()
    assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."

    # Never launch more programs than there are tiles to process.
    num_programs = int(min(grid_dim, num_tiles))
    assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
    return (num_programs,)
67
+
68
+
69
def gmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias: Tensor | None = None,
) -> Tensor:
    """
    Grouped matrix multiplication (GMM): out = lhs @ rhs + bias.

    The rows of lhs are split into G consecutive groups. Each group is multiplied by the
    matching plane of the 3D rhs tensor and written to the same rows of out. For a group g:
        out[group_start:group_end, :] = lhs[group_start:group_end, :] @ rhs[g] + bias[g]

    Group boundaries come from group_sizes. Example: with group_sizes = [3, 2, 4, 1] there are
    4 groups covering rows 0-2, rows 3-4, rows 5-8 and row 9 (the 10th and last row of lhs).

    Parameters
    ----------
    lhs : torch.Tensor
        Left-hand side 2D input tensor. Shape: (M, K).
        Data type must be torch.float16 or torch.bfloat16 and must match rhs.
        Must be on the same device as rhs and group_sizes.
    rhs : torch.Tensor
        Right-hand side 3D input tensor. Shape: (G, K, N).
        Data type must be torch.float16 or torch.bfloat16 and must match lhs.
        Must be on the same device as lhs and group_sizes.
    group_sizes : torch.Tensor
        1D tensor with the size of each group. Shape: (G,).
        Data type must be torch.int32 and every element must be non-negative.
        Must be on the same device as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Data type of the output tensor. Default is torch.bfloat16.
        Supported output types are torch.float16 and torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated output tensor. Default is None, in which case a new tensor is allocated.
        If given, it must have shape (M, N), its data type must match preferred_element_type
        and it must be on the same device as the inputs. Results are written into it.
    config : dict[str, int] or None, optional
        Kernel meta-parameters. If None, the config is looked up in the internal tuning
        database.
    bias : torch.Tensor or None, optional
        Optional bias tensor. Shape: (G, N).
        If given, its data type must match lhs and rhs and it must be on the same device as the
        other inputs. bias[g] is added to the output rows of group g.

    Returns
    -------
    torch.Tensor
        The 2D output tensor. Shape: (M, N), data type preferred_element_type.
        When existing_out is given, that same tensor is returned.

    Implementation Notes
    --------------------
    - GMM is implemented with a persistent Triton kernel.
    - lhs must be row-major (lhs.stride() == (K, 1)).
    - rhs can be row-major (rhs.stride() == (K * N, N, 1)) or column-major (rhs.stride() ==
      (K * N, 1, K)). Row-major rhs means kernel parameter TRANS_RHS == False, which is the
      forward-pass case. Column-major rhs means TRANS_RHS == True, which computes the lhs
      derivative of the backward pass with the transposition fused in.
    - out must be row-major (out.stride() == (N, 1)).
    - bias, if provided, must be row-major (bias.stride() == (N, 1)).
    """
    check_input_device_dtype(lhs, rhs, group_sizes, bias)

    M, K, N, G = get_gmm_shape(lhs, rhs, group_sizes)

    use_bias = bias is not None
    if use_bias:
        check_bias_shape_stride(bias, G, N)

    out = get_gmm_output(
        M,
        N,
        device=lhs.device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    # Only the rhs transposition flag is needed by the kernel.
    trans_rhs, _ = get_gmm_transposition(lhs, rhs, out)

    if config is None:
        config = get_config("gmm", M, K, N, G)

    def _valid_entry(key: str) -> bool:
        # Every required key must map to an int: a power of 2 for tile sizes,
        # any positive value for the remaining parameters.
        if key not in config or not isinstance(config[key], int):
            return False
        if key.startswith("BLOCK_SIZE_"):
            return is_power_of_2(config[key])
        return config[key] > 0

    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_K",
        "BLOCK_SIZE_N",
        "GROUP_SIZE",
        "GRID_DIM",
    )
    assert all(_valid_entry(key) for key in required_keys), "Invalid GMM kernel config."

    grid = _gmm_grid(
        N,
        config["BLOCK_SIZE_M"],
        config["BLOCK_SIZE_N"],
        group_sizes,
        config["GRID_DIM"],
    )

    # fmt: off
    gmm_kernel[grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_RHS=trans_rhs,
        USE_BIAS=use_bias,
        **config,
    )
    # fmt: on

    return out
202
+
203
+
204
+ # Persistent TGMM PyTorch wrapper.
205
+ # ------------------------------------------------------------------------------
206
+
207
+
208
def _ptgmm_grid(
    K: int,
    N: int,
    G: int,
    block_size_k: int,
    block_size_n: int,
    grid_dim: int,
) -> tuple[int]:
    """
    Compute the 1D launch grid of the persistent TGMM kernel.

    Every group owns ``ceil(K / block_size_k) * ceil(N / block_size_n)`` output tiles. The
    grid has one program per tile across all G groups, capped at ``grid_dim``.
    """
    # Problem sizes.
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}."
    assert G > 0, f"G must be positive, it's {G}."

    # Tile sizes and grid cap.
    assert is_power_of_2(block_size_k), (
        f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
    )
    assert is_power_of_2(block_size_n), (
        f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    )
    assert grid_dim > 0, f"Grid dimension must be positive (it's {grid_dim})."

    # Tiles per output plane, then across all groups.
    num_k_tiles = triton.cdiv(K, block_size_k)
    assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
    num_n_tiles = triton.cdiv(N, block_size_n)
    assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
    num_tiles = G * num_k_tiles * num_n_tiles
    assert num_tiles > 0, f"num_tiles must be positive, it's {num_tiles}."

    # Never launch more programs than there are tiles to process.
    num_programs = min(grid_dim, num_tiles)
    assert num_programs > 0, f"num_programs must be positive, it's {num_programs}."
    return (num_programs,)
235
+
236
+
237
def ptgmm(
    lhs: Tensor,
    rhs: Tensor,
    group_sizes: Tensor,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
    config: dict[str, int] | None = None,
    bias_grad: Tensor | None = None,
    accumulate: bool = False,
) -> Tensor:
    """
    Persistent transposed grouped matrix multiplication (PTGMM): out = lhs @ rhs.

    The columns of lhs and the rows of rhs are split into G consecutive groups. Each lhs group
    is multiplied by the matching rhs group and stored in one plane of the 3D output. For a
    group g:
        out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]

    The 't' in the name comes from the MaxText implementation
    (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
    which was the initial inspiration for this one. TGMM differs from GMM in tensor shapes:
    GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).

    The 'p' means the operator is implemented with a persistent kernel. nptgmm is the
    non-persistent variant of the same computation; which one to use is purely a matter of
    performance for the target workload.

    Parameters
    ----------
    lhs : torch.Tensor
        Left-hand side 2D input tensor. Shape: (K, M).
        Data type must be torch.float16 or torch.bfloat16 and must match rhs.
        Must be on the same device as rhs and group_sizes.
    rhs : torch.Tensor
        Right-hand side 2D input tensor. Shape: (M, N).
        Data type must be torch.float16 or torch.bfloat16 and must match lhs.
        Must be on the same device as lhs and group_sizes.
    group_sizes : torch.Tensor
        1D tensor with the size of each group. Shape: (G,).
        Data type must be torch.int32 and every element must be non-negative.
        Must be on the same device as lhs and rhs.
    preferred_element_type : torch.dtype, optional
        Data type of the output tensor. Default is torch.bfloat16.
        Supported output types are torch.float16 and torch.bfloat16.
    existing_out : torch.Tensor or None, optional
        Preallocated output tensor. Default is None, in which case a new tensor is allocated.
        If given, it must have shape (G, K, N), its data type must match
        preferred_element_type and it must be on the same device as the inputs. Results are
        written into it.
    config : dict[str, int] or None, optional
        Kernel meta-parameters. If None, the config is looked up in the internal tuning
        database.
    bias_grad : torch.Tensor or None, optional
        Optional bias gradient output tensor. Shape: (G, K).
        If given, the kernel also computes the bias gradient and writes it to this tensor.
        It must be torch.float32, because the kernel uses atomic_add, which requires float32.
    accumulate : bool, optional
        Whether to accumulate into the existing values of the output tensor. Default is False.
        If False, the output is overwritten with the fresh result.
        If True, the result is added to the values already in the output tensor.

    Returns
    -------
    torch.Tensor
        The 3D output tensor. Shape: (G, K, N), data type preferred_element_type.
        When existing_out is given, that same tensor is returned.

    Implementation Notes
    --------------------
    - PTGMM is implemented with a persistent Triton kernel.
    - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)).
      Row-major lhs means kernel parameter TRANS_LHS == False. Column-major lhs means
      TRANS_LHS == True, which computes the rhs derivative of the backward pass with the
      transposition fused in.
    - rhs must be row-major (rhs.stride() == (N, 1)).
    - out must be row-major (out.stride() == (K * N, N, 1)).
    """
    check_input_device_dtype(lhs, rhs, group_sizes)

    M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)

    out = get_tgmm_output(
        K,
        N,
        G,
        device=lhs.device,
        preferred_element_type=preferred_element_type,
        existing_out=existing_out,
    )

    # Only the lhs transposition flag is needed by the kernel.
    trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)

    if config is None:
        config = get_config("ptgmm", M, K, N, G, accumulate)

    def _valid_entry(key: str) -> bool:
        # Every required key must map to an int: a power of 2 for tile sizes,
        # any positive value for the remaining parameters.
        if key not in config or not isinstance(config[key], int):
            return False
        if key.startswith("BLOCK_SIZE_"):
            return is_power_of_2(config[key])
        return config[key] > 0

    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_K",
        "BLOCK_SIZE_N",
        "GROUP_SIZE",
        "GRID_DIM",
    )
    assert all(_valid_entry(key) for key in required_keys), "Invalid PTGMM kernel config."

    # Bias gradient handling.
    # -----------------------
    # The gradient is computed only when the caller supplied a tensor for it. The helper is
    # called either way and its result is what gets passed to the kernel.
    compute_bias_grad = bias_grad is not None
    bias_grad_ptr = get_tgmm_bias_grad(
        K,
        G,
        device=lhs.device,
        existing_bias_grad=bias_grad,
    )

    grid = _ptgmm_grid(
        K,
        N,
        G,
        config["BLOCK_SIZE_K"],
        config["BLOCK_SIZE_N"],
        config["GRID_DIM"],
    )

    # fmt: off
    tgmm_persistent_kernel[grid](
        # Tensor pointers:
        lhs, rhs, group_sizes, out, bias_grad_ptr,
        # Tensor shapes:
        M, K, N, G,
        # Meta-parameters:
        TRANS_LHS=trans_lhs,
        COMPUTE_BIAS_GRAD=compute_bias_grad,
        ACCUMULATE=accumulate,
        **config,
    )
    # fmt: on

    return out
387
+
388
+
389
+ # Regular non-persistent TGMM PyTorch wrapper.
390
+ # ------------------------------------------------------------------------------
391
+
392
+
393
def _nptgmm_grid(
    K: int,
    N: int,
    G: int,
    block_size_k: int,
    block_size_n: int,
) -> tuple[int, int]:
    """
    Compute the 2D launch grid of the non-persistent TGMM kernel.

    The first grid axis has one entry per group (G). The second axis has one entry per tile
    of a single (K, N) output plane, i.e. ``ceil(K / block_size_k) * ceil(N / block_size_n)``.
    """
    # Problem sizes.
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}."
    assert G > 0, f"G must be positive, it's {G}."

    # Tile sizes.
    assert is_power_of_2(block_size_k), (
        f"K-dimension tile size must be a power of 2 (it's {block_size_k})."
    )
    assert is_power_of_2(block_size_n), (
        f"N-dimension tile size must be a power of 2 (it's {block_size_n})."
    )

    # Tiles along each output dimension, then per output plane.
    num_k_tiles = triton.cdiv(K, block_size_k)
    assert num_k_tiles > 0, f"num_k_tiles must be positive, it's {num_k_tiles}."
    num_n_tiles = triton.cdiv(N, block_size_n)
    assert num_n_tiles > 0, f"num_n_tiles must be positive, it's {num_n_tiles}."
    num_tiles_per_mm = num_k_tiles * num_n_tiles
    assert num_tiles_per_mm > 0, (
        f"num_tiles_per_mm must be positive, it's {num_tiles_per_mm}."
    )
    return (G, num_tiles_per_mm)
418
+
419
+
420
+ def nptgmm(
421
+ lhs: Tensor,
422
+ rhs: Tensor,
423
+ group_sizes: Tensor,
424
+ preferred_element_type: torch.dtype = DTYPE,
425
+ existing_out: Tensor | None = None,
426
+ config: dict[str, int] | None = None,
427
+ bias_grad: Tensor | None = None,
428
+ accumulate: bool = False,
429
+ ) -> Tensor:
430
+ """
431
+ Perform a Group Matrix Multiplication (GMM) variant: out = lhs @ rhs
432
+
433
+ lhs columns and rhs rows are divided into G groups. Each group of lhs is matrix multiplied with
434
+ the respective group of rhs and then stored in a plane of the output 3D tensor. In PyTorch
435
+ parlance, it can be implemented as follows for a given group g:
436
+ out[g] = lhs[:, group_start:group_end] @ rhs[group_start:group_end, :]
437
+
438
+ The 't' in the operator name derives from MaxText implementation
439
+ (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py),
440
+ which served as the initial inspiration for this one. TGMM differs from GMM in terms of tensor
441
+ shapes. GMM does (M, K) @ (G, K, N) = (M, N) while TGMM does (K, M) @ (M, N) = (G, K, N).
442
+
443
+ The 'np' in the operator name means that it is implemented with a non-persistent, i.e. regular
444
+ kernel. There is also the persistent variation, which is implemented with a persistent kernel.
445
+ Please take a look at ptgmm operator. Both nptgmm and ptgmm implement the same computation,
446
+ choosing one or the other is a matter of performance for the target workload.
447
+
448
+ Parameters
449
+ ----------
450
+ lhs : torch.Tensor
451
+ Left-hand side 2D input tensor. Shape: (K, M).
452
+ lhs data type must be torch.float16 or torch.bfloat16, and must match rhs data type.
453
+ lhs must be on the same device of rhs and group_sizes.
454
+ rhs : torch.Tensor
455
+ Right-hand side 2D input tensor. Shape: (M, N).
456
+ rhs data type must be torch.float16 or torch.bfloat16, and must match lhs data type.
457
+ rhs must be on the same device of lhs and group_sizes.
458
+ group_sizes : torch.Tensor
459
+ 1D input tensor describing group sizes. Shape: (G,).
460
+ group_sizes data type must be torch.int32 and all its elements must be non-negative.
461
+ group_sizes must be on the same device of lhs and rhs.
462
+ preferred_element_type : torch.dtype, optional
463
+ Desired data type for output tensor. Default is torch.bfloat16.
464
+ Supported output types are torch.float16 and torch.bfloat16.
465
+ existing_out : torch.Tensor or None, optional
466
+ Preallocated output tensor. Default is None.
467
+ If provided, results are written into this tensor. Otherwise, a new output tensor is
468
+ allocated.
469
+ If provided then it must have shape (G, K, N), its data type must match
470
+ preferred_element_type and it must be on the same device of other input tensors.
471
+ config : dict[str, int] or None, optional
472
+ Optional dictionary with kernel metaparameters. If absent, config will be queried from
473
+ internal tuning database.
474
+ bias_grad : torch.Tensor or None, optional
475
+ Optional bias gradient output tensor. Shape: (G, K).
476
+ If provided, the kernel will compute the bias gradient and write it to this tensor.
477
+ bias_grad must be torch.float32 (kernel uses atomic_add which requires float32),
478
+ accumulate : bool, optional
479
+ Whether to accumulate into existing output tensor values. Default is False.
480
+ If False, output will be overwritten with fresh computation.
481
+ If True, results will be added to existing output tensor values.
482
+
483
+ Returns
484
+ -------
485
+ torch.Tensor
486
+ The computed output 3D tensor. Shape: (G, K, N).
487
+ Output tensor data type is given by preferred_element_type.
488
+ If existing_out is provided then existing_out is also returned.
489
+
490
+ Implementation Notes
491
+ --------------------
492
+ - NPTGMM is implemented with a non-persistent regular Triton kernel.
493
+ - lhs can be row-major (lhs.stride() == (M, 1)) or column-major (lhs.stride() == (1, K)). If lhs
494
+ is row-major then kernel parameter TRANS_LHS == False. If lhs is column-major then kernel
495
+ parameter TRANS_LHS == True, this is useful for computing the rhs derivative in the backward
496
+ pass, while fusing the transposition.
497
+ - rhs must be row-major (rhs.stride() == (N, 1)).
498
+ - out must be row-major (out.stride() == (K * N, N, 1)).
499
+ """
500
+ check_input_device_dtype(lhs, rhs, group_sizes)
501
+
502
+ M, K, N, G = get_tgmm_shape(lhs, rhs, group_sizes)
503
+
504
+ out = get_tgmm_output(
505
+ K,
506
+ N,
507
+ G,
508
+ device=lhs.device,
509
+ preferred_element_type=preferred_element_type,
510
+ existing_out=existing_out,
511
+ )
512
+
513
+ trans_lhs, _ = get_tgmm_transposition(lhs, rhs, out)
514
+
515
+ # Bias gradient handling.
516
+ # -----------------------
517
+ # Get or validate bias gradient tensor.
518
+ compute_bias_grad = bias_grad is not None
519
+ bias_grad_ptr = get_tgmm_bias_grad(
520
+ K,
521
+ G,
522
+ device=lhs.device,
523
+ existing_bias_grad=bias_grad,
524
+ )
525
+
526
+ if config is None:
527
+ config = get_config("nptgmm", M, K, N, G, accumulate)
528
+
529
+ assert all(
530
+ key in config
531
+ and isinstance(config[key], int)
532
+ and (
533
+ is_power_of_2(config[key])
534
+ if key.startswith("BLOCK_SIZE_")
535
+ else config[key] > 0
536
+ )
537
+ for key in {
538
+ "BLOCK_SIZE_M",
539
+ "BLOCK_SIZE_K",
540
+ "BLOCK_SIZE_N",
541
+ "GROUP_SIZE",
542
+ }
543
+ ), "Invalid NPTGMM kernel config."
544
+
545
+ grid = _nptgmm_grid(
546
+ K,
547
+ N,
548
+ G,
549
+ config["BLOCK_SIZE_K"],
550
+ config["BLOCK_SIZE_N"],
551
+ )
552
+
553
+ # fmt: off
554
+ tgmm_non_persistent_kernel[grid](
555
+ # Tensor pointers:
556
+ lhs, rhs, group_sizes, out, bias_grad_ptr,
557
+ # Tensor shapes:
558
+ M, K, N, G,
559
+ # Meta-parameters:
560
+ TRANS_LHS=trans_lhs,
561
+ COMPUTE_BIAS_GRAD=compute_bias_grad,
562
+ ACCUMULATE=accumulate,
563
+ **config,
564
+ )
565
+ # fmt: on
566
+
567
+ return out
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/__init__.py ADDED
File without changes
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/__init__.py ADDED
File without changes
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/arch_info.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+
3
+ # Detect the GPU arch lazily: querying the triton driver at import time fails
4
+ # in headless environments (e.g. the kernel-builder ABI check sandbox has no
5
+ # GPU), and the original JAX fallback pulled in an unrelated runtime dep. The
6
+ # arch is only actually needed when a GMM kernel is dispatched, so resolve and
7
+ # cache on first call.
8
+ _CACHED_ARCH = None
9
+
10
+
11
def get_arch():
    """
    Return the GPU architecture identifier, resolving it on first use.

    The value is stored in the module-level ``_CACHED_ARCH`` so that later calls skip
    the driver query. The Triton driver is asked first; if that raises ``RuntimeError``
    the JAX GPU Triton bindings are tried, keeping the text before the first ``:`` of
    the arch details reported for device "0".

    Raises
    ------
    RuntimeError
        If the Triton driver query fails and JAX cannot be imported.
    """
    global _CACHED_ARCH
    if _CACHED_ARCH is None:
        try:
            target = triton.runtime.driver.active.get_current_target()
            _CACHED_ARCH = target.arch
        except RuntimeError:
            # Triton could not provide a target: fall back to JAX if it is installed.
            try:
                from jax._src.lib import gpu_triton as jax_gpu_triton

                arch_details = jax_gpu_triton.get_arch_details("0")
                _CACHED_ARCH = arch_details.split(":")[0]
            except ImportError as import_error:
                raise RuntimeError(
                    "Cannot determine GPU arch: triton driver is inactive and "
                    "JAX is not available. A GPU is required for grouped GEMM."
                ) from import_error
    return _CACHED_ARCH
27
+
28
+
29
def is_gluon_avail():
    """Return True when the detected arch is ``gfx950`` or ``gfx1250``."""
    arch = get_arch()
    return arch in ("gfx950", "gfx1250")
31
+
32
+
33
def is_fp4_avail():
    """Return True when the detected arch is ``gfx950`` or ``gfx1250``."""
    arch = get_arch()
    return arch in ("gfx950", "gfx1250")
35
+
36
+
37
def is_fp8_avail():
    """Return True when the detected arch is one of the arches listed below."""
    fp8_arches = ("gfx942", "gfx950", "gfx1250", "gfx1200", "gfx1201")
    return get_arch() in fp8_arches
39
+
40
+
41
def is_mx_scale_preshuffling_avail():
    """Return True when the detected arch is ``gfx950`` or ``gfx1250``."""
    arch = get_arch()
    return arch in ("gfx950", "gfx1250")
43
+
44
+
45
def is_tdm_avail():
    """Return True when the detected arch is ``gfx1250``."""
    tdm_arches = ("gfx1250",)
    return get_arch() in tdm_arches
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/_triton/pid_preprocessing.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
@triton.jit
def remap_xcd_chunked(
    pid, GRID_MN, NUM_XCDS: tl.constexpr = 8, CHUNK_SIZE: tl.constexpr = 2
):
    """
    Remap a 1D pid so that each ``pid % NUM_XCDS`` class owns runs of CHUNK_SIZE pids.

    pids greater than the largest multiple of ``NUM_XCDS * CHUNK_SIZE`` that fits in
    GRID_MN are returned unchanged.
    NOTE(review): presumably the launcher hands pids to XCDs round-robin, so
    ``pid % NUM_XCDS`` identifies the XCD - confirm against the target hardware.
    """
    # Round-robin class of this pid.
    xcd_idx = pid % NUM_XCDS
    # Leftover pids beyond the last whole block keep their original id.
    if pid > (GRID_MN // (NUM_XCDS * CHUNK_SIZE)) * (NUM_XCDS * CHUNK_SIZE):
        return pid
    # Position of this pid among the pids of its class.
    pid_in_xcd = pid // NUM_XCDS
    # Which chunk that position falls in, and the offset inside the chunk.
    chunk = pid_in_xcd // CHUNK_SIZE
    offset = pid_in_xcd % CHUNK_SIZE
    # Chunks of all classes are laid out back to back inside each block.
    remapped = chunk * NUM_XCDS * CHUNK_SIZE + xcd_idx * CHUNK_SIZE + offset
    return remapped
25
+
26
+
27
@triton.jit
def remap_xcd(pid, GRID_MN, NUM_XCDS: tl.constexpr = 8):
    """
    Remap a 1D pid so that pids sharing ``pid % NUM_XCDS`` become contiguous.

    NOTE(review): presumably the launcher hands pids to XCDs round-robin, so
    ``pid % NUM_XCDS`` identifies the XCD - confirm against the target hardware.
    """
    # Largest number of pids any single class receives (ceil division).
    max_pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
    # When GRID_MN is not a multiple of NUM_XCDS only the first `num_tall` classes get
    # `max_pids_per_xcd` pids; the remaining ones get one fewer.
    num_tall = GRID_MN % NUM_XCDS
    num_tall = NUM_XCDS if num_tall == 0 else num_tall
    # Class of this pid and its position within that class.
    xcd_idx = pid % NUM_XCDS
    pid_in_xcd = pid // NUM_XCDS
    if xcd_idx < num_tall:
        # Only "tall" classes precede this one.
        pid = xcd_idx * max_pids_per_xcd + pid_in_xcd
    else:
        # Skip every tall class, then the "short" classes that precede this one.
        pid = (
            num_tall * max_pids_per_xcd
            + (xcd_idx - num_tall) * (max_pids_per_xcd - 1)
            + pid_in_xcd
        )

    return pid
55
+
56
+
57
@triton.jit
def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexpr = 1):
    """
    Convert a 1D pid into 2D grid coordinates (pid_m, pid_n).

    With GROUP_SIZE_M == 1 the mapping is plain row-major. Otherwise rows are handled in
    groups of up to GROUP_SIZE_M, and a whole group is traversed before moving to the
    next one.

    Args:
    - pid: 1D pid
    - num_pid_m: number of pids along m
    - num_pid_n: number of pids along n
    - GROUP_SIZE_M: tl.constexpr: rows per group, default is 1
    """
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        pids_per_group = GROUP_SIZE_M * num_pid_n
        group_idx = pid // pids_per_group
        group_first_m = group_idx * GROUP_SIZE_M
        # The last group may hold fewer than GROUP_SIZE_M rows.
        rows_in_group = min(num_pid_m - group_first_m, GROUP_SIZE_M)
        tl.assume(rows_in_group >= 0)
        pid_m = group_first_m + (pid % rows_in_group)
        pid_n = (pid % pids_per_group) // rows_in_group

    return pid_m, pid_n
81
+
82
+
83
@triton.jit
def pid_grid_3d(pid: int, num_pid_m: int, num_pid_n: int, num_pid_k):
    """
    Convert a 1D pid into 3D grid coordinates (pid_m, pid_n, pid_k).

    m varies fastest, then n, then k.

    Args:
    - pid: 1D pid
    - num_pid_m: number of pids along m
    - num_pid_n: number of pids along n
    - num_pid_k: number of pids along k

    Returns:
    - pid_m, pid_n, pid_k: 3D grid coordinates
    """
    coord_m = pid % num_pid_m
    coord_n = (pid // num_pid_m) % num_pid_n
    coord_k = (pid // (num_pid_m * num_pid_n)) % num_pid_k

    return coord_m, coord_n, coord_k
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/gmm_common.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ # Imports.
5
+ # ------------------------------------------------------------------------------
6
+
7
+ # PyTorch
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ # AITER: logging
12
+ from .logger import AiterTritonLogger
13
+
14
+ _LOGGER: AiterTritonLogger = AiterTritonLogger()
15
+
16
+
17
+ # Supported data types.
18
+ # ------------------------------------------------------------------------------
19
+
20
# Supported data types, as strings.
SUPPORTED_DTYPES_STR: set[str] = {"fp16", "bf16"}


def dtype_from_str(dtype_str: str) -> torch.dtype:
    """
    Convert a string data type to the matching PyTorch data type.

    The string is stripped and lower-cased, and a single leading "i" or "o" is dropped
    before the lookup (so "ibf16" and "obf16" both map to torch.bfloat16).

    Raises
    ------
    AssertionError
        If the normalized string is not in SUPPORTED_DTYPES_STR. This includes empty and
        whitespace-only strings.
    """
    dtype_str = dtype_str.strip().lower()
    # Slice instead of indexing: an empty string must reach the assertion below rather
    # than fail with IndexError.
    if dtype_str[:1] in {"i", "o"}:
        dtype_str = dtype_str[1:]
    assert (
        dtype_str in SUPPORTED_DTYPES_STR
    ), "String data type isn't in set of supported string data types."
    return {"fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]


# Supported data types, as PyTorch types.
SUPPORTED_DTYPES: set[torch.dtype] = {
    dtype_from_str(dtype_str) for dtype_str in SUPPORTED_DTYPES_STR
}


def str_from_dtype(dtype: torch.dtype) -> str:
    """
    Convert a PyTorch data type to its string name ("fp16" or "bf16").

    Raises
    ------
    AssertionError
        If dtype is not in SUPPORTED_DTYPES.
    """
    assert (
        dtype in SUPPORTED_DTYPES
    ), "PyTorch data type isn't in set of supported PyTorch data types."
    return {torch.float16: "fp16", torch.bfloat16: "bf16"}[dtype]


# Default data type, as string.
DTYPE_STR: str = "bf16"
assert (
    DTYPE_STR in SUPPORTED_DTYPES_STR
), "Default string data type isn't in set of supported string data types."


# Default data type, as PyTorch type.
DTYPE: torch.dtype = dtype_from_str(DTYPE_STR)
57
+
58
+
59
+ # Other defaults.
60
+ # ------------------------------------------------------------------------------
61
+
62
+ # Default device.
63
+ DEVICE: torch.device | str = "cuda"
64
+
65
+ # Default RNG seed for input generation.
66
+ RNG_SEED: int = 0
67
+
68
+ # Default number of group sizes.
69
+ NUM_GROUP_SIZES: int = 1
70
+
71
+ # Default transposition (NN).
72
+ TRANS_LHS: bool = False
73
+ TRANS_RHS: bool = False
74
+
75
+
76
+ # Parameter checking functions.
77
+ # ------------------------------------------------------------------------------
78
+
79
+
80
def is_power_of_2(x: int) -> bool:
    """Tell whether ``x`` is a positive power of two (1, 2, 4, 8, ...)."""
    if not x > 0:
        return False
    # A power of two has a single bit set, so clearing the lowest set bit leaves zero.
    return x & (x - 1) == 0
82
+
83
+
84
+ def check_input_device_dtype(
85
+ lhs: Tensor, rhs: Tensor, group_sizes: Tensor, bias: Tensor | None = None
86
+ ) -> None:
87
+ assert (
88
+ lhs.device == rhs.device == group_sizes.device
89
+ ), f"All input tensors must be in the same device (lhs = {lhs.device}, rhs = {rhs.device}, group_sizes = {group_sizes.device})."
90
+ assert (
91
+ lhs.dtype == rhs.dtype
92
+ ), f"lhs and rhs types must match (lhs = {lhs.dtype}, rhs = {rhs.dtype})."
93
+ assert group_sizes.dtype == torch.int32, "group_sizes type must be int32."
94
+
95
+ if bias is not None:
96
+ assert (
97
+ bias.device == lhs.device
98
+ ), f"bias must be on the same device as lhs (bias = {bias.device}, lhs = {lhs.device})."
99
+ assert (
100
+ bias.dtype == lhs.dtype
101
+ ), f"bias dtype must match lhs dtype (bias = {bias.dtype}, lhs = {lhs.dtype})."
102
+
103
+
104
def check_bias_shape_stride(bias: Tensor, G: int, N: int) -> None:
    """
    Assert that bias has shape (G, N) and row-major strides (N, 1).

    Raises
    ------
    AssertionError
        If the shape or the strides differ from the expected ones.
    """
    expected_shape = (G, N)
    assert (
        bias.shape == expected_shape
    ), f"bias must have shape (G, N) = ({G}, {N}), got {bias.shape}."
    expected_stride = (N, 1)
    assert (
        bias.stride() == expected_stride
    ), "bias must be row-major (bias.stride() == (N, 1))."
110
+
111
+
112
+ # Generation of group sizes.
113
+ # ------------------------------------------------------------------------------
114
+
115
+
116
# Default probabilities used when generating random group sizes: the chance that a
# token is dropped, and the chance that an expert receives no tokens.
UNUSED_TOKENS_PROB: float = 0.0
UNUSED_EXPERTS_PROB: float = 0.1
119
+
120
+
121
def gen_uniform_group_sizes(
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
) -> Tensor:
    """
    Split M tokens across G groups as evenly as possible.

    Every group receives ``M // G`` tokens and the first ``M % G`` groups receive one
    more, so the sizes add up to M.

    Returns
    -------
    torch.Tensor
        int32 tensor of shape (G,) on ``device``.
    """
    assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
    assert G > 0, f"Number of experts G must be positive (it's {G})."

    tokens_per_group, leftover = divmod(M, G)
    group_sizes = torch.full(
        (G,), tokens_per_group, dtype=torch.int32, device=device
    )
    if leftover > 0:
        # Hand the leftover tokens to the leading groups, one each.
        group_sizes[:leftover] += 1

    # Sanity checks on the generated tensor.
    assert (
        len(group_sizes) == G
    ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
    assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
    assert (
        torch.sum(group_sizes).item() == M
    ), f"Group sizes don't add up to total tokens {M}."
    assert group_sizes.dtype == torch.int32, "Group sizes must be int32."

    return group_sizes
145
+
146
+
147
def gen_group_sizes(
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
    rng_seed: int | None = RNG_SEED,
    unused_tokens_prob: float = UNUSED_TOKENS_PROB,
    unused_experts_prob: float = UNUSED_EXPERTS_PROB,
) -> Tensor:
    """
    Generate random group sizes by assigning tokens to experts.

    A random subset of tokens may be dropped (``unused_tokens_prob``) and a random
    subset of experts may receive no tokens (``unused_experts_prob``). At least one
    token and one expert are always kept. The remaining tokens are assigned to the
    remaining experts uniformly at random.

    Parameters
    ----------
    M : int
        Total number of tokens.
    G : int
        Number of experts (groups).
    device : torch.device or str, optional
        Device of the returned tensor.
    rng_seed : int or None, optional
        If not None, the PyTorch RNG is seeded with this value first.
    unused_tokens_prob : float, optional
        Per-token probability of being dropped. Must be in [0, 1).
    unused_experts_prob : float, optional
        Per-expert probability of receiving no tokens. Must be in [0, 1).

    Returns
    -------
    torch.Tensor
        int32 tensor of shape (G,) whose elements add up to the number of used tokens.

    Raises
    ------
    AssertionError
        If an argument is out of range or no token would be left (e.g. M == 0).
    """
    assert M >= 0, f"Number of tokens M must be non-negative (it's {M})."
    assert G > 0, f"Number of experts G must be positive (it's {G})."
    assert (
        0 <= unused_tokens_prob <= 1
    ), f"Probability of unused tokens must be in [0, 1] interval (it's {unused_tokens_prob})."
    assert (
        0 <= unused_experts_prob <= 1
    ), f"Probability of unused experts must be in [0, 1] interval (it's {unused_experts_prob})."
    # A probability of exactly 1 would drop everything, and the rejection loops below
    # would never terminate. Fail fast instead of hanging.
    assert (
        unused_tokens_prob < 1
    ), "Probability of unused tokens must be less than 1, otherwise every token is dropped."
    assert (
        unused_experts_prob < 1
    ), "Probability of unused experts must be less than 1, otherwise every expert is dropped."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    if unused_tokens_prob > 0:
        # Optionally drop tokens to simulate routing sparsity, some tokens may not be routed.
        # Resample until at least one token survives. The `M > 0` guard avoids spinning
        # forever when there are no tokens at all; that case is rejected further down.
        num_unused_tokens = M
        while M > 0 and num_unused_tokens == M:
            num_unused_tokens = int(
                torch.binomial(
                    torch.tensor(float(M), device=device),
                    torch.tensor(unused_tokens_prob, device=device),
                ).item()
            )
    else:
        num_unused_tokens = 0
    num_used_tokens = M - num_unused_tokens
    assert (
        num_unused_tokens >= 0
    ), f"Number of unused tokens must be non-negative (it's {num_unused_tokens})."
    assert (
        num_used_tokens > 0
    ), f"Number of used tokens must be positive (it's {num_used_tokens})."
    assert (
        num_used_tokens + num_unused_tokens == M
    ), f"Unused + used tokens don't add up total tokens ({num_used_tokens} + {num_unused_tokens} != {M})."

    if num_unused_tokens > 0:
        _LOGGER.debug(
            f"Group sizes generation: dropped {num_unused_tokens} token{'s' if num_unused_tokens > 1 else ''}.",
        )

    if unused_experts_prob > 0:
        # Some experts may have zero tokens assigned to them.
        # Resample until at least one expert survives.
        num_used_experts = 0
        while num_used_experts == 0:
            # nonzero() on a 1D mask yields shape (n, 1). Squeeze only the last
            # dimension so that a single surviving expert stays a 1D tensor and can
            # still be indexed below.
            used_experts = torch.nonzero(
                torch.rand((G,), device=device) >= unused_experts_prob
            ).squeeze(-1)
            num_used_experts = used_experts.numel()
    else:
        used_experts = torch.arange(0, G, device=device)
        num_used_experts = G
    num_unused_experts = G - num_used_experts
    assert (
        num_unused_experts >= 0
    ), f"Number of unused experts must be non-negative (it's {num_unused_experts})."
    assert (
        num_used_experts >= 1
    ), f"At least one expert must be used (it's {num_used_experts})."
    assert (
        num_unused_experts + num_used_experts == G
    ), f"Unused + used experts don't add up total experts ({num_unused_experts} + {num_used_experts} != {G})."

    if num_unused_experts > 0:
        _LOGGER.debug(
            f"Group sizes generation: dropped {num_unused_experts} expert{'s' if num_unused_experts > 1 else ''}.",
        )

    # Pick a used expert for every used token, then count tokens per expert.
    group_sizes = torch.bincount(
        used_experts[
            torch.randint(low=0, high=num_used_experts, size=(num_used_tokens,))
        ],
        minlength=G,
    ).to(torch.int32)

    assert (
        len(group_sizes) == G
    ), f"Group sizes don't have {G} elements (it's {len(group_sizes)})."
    assert torch.all(group_sizes >= 0).item(), "All group sizes must be non-negative."
    assert (
        torch.sum(group_sizes).item() == num_used_tokens
    ), f"Group sizes don't add up to used tokens {num_used_tokens}."
    assert group_sizes.dtype == torch.int32, "Group sizes must be int32."

    return group_sizes
239
+
240
+
241
def gen_multiple_group_sizes(
    num_group_sizes: int,
    M: int,
    G: int,
    device: torch.device | str = DEVICE,
    rng_seed: int | None = RNG_SEED,
    unused_tokens_prob: float = UNUSED_TOKENS_PROB,
    unused_experts_prob: float = UNUSED_EXPERTS_PROB,
    group_sizes_0: Tensor | None = None,
) -> list[Tensor]:
    """
    Build a list of ``num_group_sizes`` group-size tensors.

    When ``group_sizes_0`` is given it becomes the first element and one fewer tensor
    is generated. Only the first generated tensor receives ``rng_seed``; the following
    ones are generated with ``rng_seed=None``.
    """
    assert (
        num_group_sizes > 0
    ), f"Number of group sizes to be generated must be positive, it's {num_group_sizes}."

    # A caller-supplied tensor takes the first slot.
    if group_sizes_0 is None:
        multiple_group_sizes = []
        num_to_generate = num_group_sizes
    else:
        multiple_group_sizes = [group_sizes_0]
        num_to_generate = num_group_sizes - 1

    for index in range(num_to_generate):
        multiple_group_sizes.append(
            gen_group_sizes(
                M,
                G,
                device=device,
                rng_seed=rng_seed if index == 0 else None,
                unused_tokens_prob=unused_tokens_prob,
                unused_experts_prob=unused_experts_prob,
            )
        )

    assert (
        len(multiple_group_sizes) == num_group_sizes
    ), f"Expecting {num_group_sizes} distinct group sizes (it's {len(multiple_group_sizes)})."
    return multiple_group_sizes
273
+
274
+
275
+ # GMM helpers: tensor generation.
276
+ # ------------------------------------------------------------------------------
277
+
278
+
279
def gen_gmm_input(
    M: int,
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    trans_rhs: bool = TRANS_RHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Generate random GMM inputs: lhs (M, K), rhs (G, K, N) and group sizes (G,).

    Values are drawn in float32 and then converted to ``preferred_element_type``. With
    ``trans_rhs`` the rhs is drawn as (G, N, K) and permuted to (G, K, N). Group sizes
    are uniform when ``unif_group_sizes`` is set, random otherwise.
    """
    assert M > 0, f"Number of lhs rows M must be positive (M = {M})."
    assert K > 0, f"Number of lhs columns / rhs rows K must be positive (K = {K})."
    assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    lhs = torch.randn((M, K), dtype=torch.float32, device=device).to(
        preferred_element_type
    )

    rhs_storage_shape = (G, N, K) if trans_rhs else (G, K, N)
    rhs = torch.randn(rhs_storage_shape, dtype=torch.float32, device=device)
    if trans_rhs:
        # Swap the last two dimensions to present the (G, N, K) data as (G, K, N).
        rhs = rhs.permute(0, 2, 1)
    rhs = rhs.to(preferred_element_type)

    if unif_group_sizes:
        group_sizes = gen_uniform_group_sizes(M, G, device=device)
    else:
        # The RNG was already seeded above, so it is not reseeded here.
        group_sizes = gen_group_sizes(M, G, device=device, rng_seed=None)

    return lhs, rhs, group_sizes
316
+
317
+
318
def gen_gmm_output(
    M: int,
    N: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
) -> Tensor:
    """Allocate an uninitialized (M, N) GMM output tensor of the requested type."""
    assert M > 0, f"Number of out rows M must be positive (M = {M})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."

    return torch.empty((M, N), dtype=preferred_element_type, device=device)
330
+
331
+
332
def gen_gmm_tensors(
    M: int,
    K: int,
    N: int,
    G: int,
    num_group_sizes: int,
    device: torch.device | str = DEVICE,
    input_type: torch.dtype = DTYPE,
    output_type: torch.dtype = DTYPE,
    trans_lhs: bool = False,
    trans_rhs: bool = TRANS_RHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
    use_bias: bool = False,
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
    """
    Generate every tensor needed to run GMM.

    Returns
    -------
    tuple
        ``(lhs, rhs, multiple_group_sizes, out, bias)`` where lhs is (M, K), rhs is
        (G, K, N), multiple_group_sizes is a list of ``num_group_sizes`` tensors of
        shape (G,), out is an uninitialized (M, N) tensor and bias is a random (G, N)
        tensor, or None when ``use_bias`` is False.

    Notes
    -----
    ``trans_lhs`` is accepted but not used by this function.
    """
    lhs, rhs, group_sizes_0 = gen_gmm_input(
        M,
        K,
        N,
        G,
        device=device,
        preferred_element_type=input_type,
        trans_rhs=trans_rhs,
        rng_seed=rng_seed,
        unif_group_sizes=unif_group_sizes,
    )
    multiple_group_sizes = gen_multiple_group_sizes(
        num_group_sizes, M, G, device=device, rng_seed=None, group_sizes_0=group_sizes_0
    )
    out = gen_gmm_output(M, N, device=device, preferred_element_type=output_type)
    bias = None
    if use_bias:
        # rng_seed may be None (meaning "do not reseed"); only derive the bias seed
        # when a seed was actually provided.
        if rng_seed is not None:
            torch.manual_seed(rng_seed + 1000)  # Different seed for bias
        bias = torch.randn(G, N, dtype=input_type, device=device)

    return lhs, rhs, multiple_group_sizes, out, bias
368
+
369
+
370
+ # GMM helpers: get information from tensors.
371
+ # ------------------------------------------------------------------------------
372
+
373
+
374
def get_gmm_shape(
    lhs: Tensor, rhs: Tensor, group_sizes: Tensor
) -> tuple[int, int, int, int]:
    """
    Extract and validate the GMM problem shape (M, K, N, G).

    lhs must be (M, K), rhs must be (G, K, N) and group_sizes must be (G,).

    Raises
    ------
    AssertionError
        If a tensor has the wrong rank, the shared dimensions disagree, or a dimension
        is not positive.
    """
    # Rank checks first, so the unpacking below cannot fail.
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
    assert (
        group_sizes.dim() == 1
    ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."

    M, lhs_k = lhs.size()
    rhs_g, rhs_k, N = rhs.size()
    (group_sizes_g,) = group_sizes.size()

    # Dimensions shared between tensors must agree.
    assert (
        lhs_k == rhs_k
    ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
    assert (
        rhs_g == group_sizes_g
    ), f"G dimension of rhs and group_sizes don't match (rhs = {rhs_g}, group_sizes = {group_sizes_g})."
    K, G = lhs_k, rhs_g

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    return M, K, N, G
402
+
403
+
404
def get_gmm_output(
    M: int,
    N: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
) -> Tensor:
    """
    Return the GMM output tensor, validating ``existing_out`` or allocating a new one.

    Parameters
    ----------
    M, N : int
        Output shape, both must be positive.
    device : torch.device or str, optional
        Expected (or allocation) device. A device without an index, such as "cuda",
        matches an existing tensor on any device of that type.
    preferred_element_type : torch.dtype, optional
        Expected (or allocation) data type.
    existing_out : torch.Tensor or None, optional
        Preallocated output. If given it is validated and returned as is.

    Raises
    ------
    AssertionError
        If a dimension is not positive or ``existing_out`` has the wrong device, dtype
        or shape.
    """
    assert M > 0, f"Number of out rows M must be positive (M = {M})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."

    if existing_out is not None:
        # A torch.device never compares equal to a plain string, and "cuda" differs
        # from "cuda:0". Normalize first, and compare the index only when the caller
        # specified one.
        expected_device = torch.device(device)
        same_device = existing_out.device.type == expected_device.type and (
            expected_device.index is None
            or existing_out.device.index == expected_device.index
        )
        assert (
            same_device
        ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
        assert (
            existing_out.dtype == preferred_element_type
        ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
        assert existing_out.shape == (
            M,
            N,
        ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(M, N)})."
        return existing_out

    return gen_gmm_output(
        M,
        N,
        device=device,
        preferred_element_type=preferred_element_type,
    )
433
+
434
+
435
def get_gmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
    """
    Validate GMM tensor layouts and report how rhs is laid out.

    lhs (M, K) and out (M, N) must be row-major. rhs (G, K, N) may be row-major
    (stride (K * N, N, 1)) or column-major (stride (K * N, 1, K)).

    Returns
    -------
    tuple[bool, int]
        ``(trans_rhs, ld_rhs)``: whether rhs is column-major, and the rhs leading
        dimension (N when row-major, K when column-major).

    Raises
    ------
    AssertionError
        If ranks, shapes or strides are not as described above.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 3, f"rhs must have 3 dimensions (it's {rhs.dim()})."
    assert out.dim() == 2, f"out must have 2 dimensions (it's {out.dim()})."

    lhs_m, lhs_k = lhs.shape
    G, rhs_k, rhs_n = rhs.shape
    out_m, out_n = out.shape

    assert (
        lhs_m == out_m
    ), f"M dimension of lhs and out don't match (lhs = {lhs_m}, out = {out_m})."
    M = lhs_m
    assert (
        lhs_k == rhs_k
    ), f"K dimension of lhs and rhs don't match (lhs = {lhs_k}, rhs = {rhs_k})."
    K = lhs_k
    assert (
        rhs_n == out_n
    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
    N = rhs_n

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    is_lhs_row_major = lhs.stride() == (K, 1)
    assert is_lhs_row_major, "lhs must be row-major."
    is_rhs_row_major = rhs.stride() == (K * N, N, 1)
    is_rhs_col_major = rhs.stride() == (K * N, 1, K)
    # Both patterns coincide when K == N == 1, so accept a tensor that matches either
    # one instead of requiring exactly one match.
    assert (
        is_rhs_row_major or is_rhs_col_major
    ), "rhs must be row-major or column-major."
    is_out_row_major = out.stride() == (N, 1)
    assert is_out_row_major, "out must be row-major."

    # In the ambiguous K == N == 1 case treat rhs as row-major (not transposed).
    trans_rhs = is_rhs_col_major and not is_rhs_row_major

    # Get rhs leading dimension according to transposition configuration.
    ld_rhs = N if is_rhs_row_major else K

    return trans_rhs, ld_rhs
476
+
477
+
478
+ # TGMM helpers: tensor generation.
479
+ # ------------------------------------------------------------------------------
480
+
481
+
482
def gen_tgmm_input(
    M: int,
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    trans_lhs: bool = TRANS_LHS,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """
    Generate random TGMM inputs: lhs (K, M), rhs (M, N) and group sizes (G,).

    Values are drawn in float32 and then converted to ``preferred_element_type``. With
    ``trans_lhs`` the lhs is drawn as (M, K) and transposed to (K, M). Group sizes are
    uniform when ``unif_group_sizes`` is set, random otherwise.

    Raises
    ------
    AssertionError
        If any of M, K, N, G is not positive.
    """
    assert K > 0, f"Number of lhs rows K must be positive (K = {K})."
    assert M > 0, f"Number of lhs columns / rhs rows M must be positive (M = {M})."
    assert N > 0, f"Number of rhs columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if rng_seed is not None:
        torch.manual_seed(rng_seed)

    if trans_lhs:
        lhs = torch.randn((M, K), dtype=torch.float32, device=device).T
    else:
        lhs = torch.randn((K, M), dtype=torch.float32, device=device)
    lhs = lhs.to(preferred_element_type)

    rhs = torch.randn((M, N), dtype=torch.float32, device=device)
    rhs = rhs.to(preferred_element_type)

    # The RNG was already seeded above, so random group sizes are not reseeded.
    group_sizes = (
        gen_uniform_group_sizes(M, G, device=device)
        if unif_group_sizes
        else gen_group_sizes(M, G, device=device, rng_seed=None)
    )

    return lhs, rhs, group_sizes
517
+
518
+
519
def gen_tgmm_output(
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
) -> Tensor:
    """Allocate an uninitialized (G, K, N) TGMM output tensor of the requested type."""
    assert K > 0, f"Number of out rows K must be positive (K = {K})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    return torch.empty((G, K, N), dtype=preferred_element_type, device=device)
533
+
534
+
535
def gen_tgmm_bias_grad(
    K: int,
    G: int,
    device: torch.device | str = DEVICE,
    with_bias_grad: bool = False,
) -> Tensor:
    """
    Allocate the TGMM bias-gradient tensor, or an empty placeholder.

    Returns an uninitialized float32 (G, K) tensor when ``with_bias_grad`` is True, and
    an empty float32 tensor otherwise.
    """
    if not with_bias_grad:
        # Return dummy pointer when bias_grad is not needed.
        # Must be float32 because atomic_add does not support bf16/fp16,
        # and Triton validates the pointer dtype even in dead branches.
        return torch.tensor([], device=device, dtype=torch.float32)

    assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."
    return torch.empty((G, K), device=device, dtype=torch.float32)
550
+
551
+
552
def gen_tgmm_tensors(
    M: int,
    K: int,
    N: int,
    G: int,
    num_group_sizes: int,
    device: torch.device | str = DEVICE,
    input_type: torch.dtype = DTYPE,
    output_type: torch.dtype = DTYPE,
    trans_lhs: bool = TRANS_LHS,
    trans_rhs: bool = False,
    rng_seed: int | None = RNG_SEED,
    unif_group_sizes: bool = False,
    use_bias: bool = False,
) -> tuple[Tensor, Tensor, list[Tensor], Tensor, Tensor | None]:
    """
    Generate every tensor needed to run TGMM.

    Returns ``(lhs, rhs, multiple_group_sizes, out, bias_grad)``. ``bias_grad`` is an
    uninitialized float32 (G, K) tensor when ``use_bias`` is True and None otherwise.
    ``trans_rhs`` is accepted but not used by this function.
    """
    lhs, rhs, first_group_sizes = gen_tgmm_input(
        M,
        K,
        N,
        G,
        device=device,
        preferred_element_type=input_type,
        trans_lhs=trans_lhs,
        rng_seed=rng_seed,
        unif_group_sizes=unif_group_sizes,
    )
    # The group sizes generated with the inputs become the first list element.
    multiple_group_sizes = gen_multiple_group_sizes(
        num_group_sizes,
        M,
        G,
        device=device,
        rng_seed=None,
        group_sizes_0=first_group_sizes,
    )
    out = gen_tgmm_output(K, N, G, device=device, preferred_element_type=output_type)
    bias_grad = (
        gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=True)
        if use_bias
        else None
    )
    return lhs, rhs, multiple_group_sizes, out, bias_grad
587
+
588
+
589
+ # TGMM helpers: get information from tensors.
590
+ # ------------------------------------------------------------------------------
591
+
592
+
593
def get_tgmm_shape(
    lhs: Tensor, rhs: Tensor, group_sizes: Tensor
) -> tuple[int, int, int, int]:
    """
    Extract and validate the TGMM problem shape (M, K, N, G).

    lhs must be (K, M), rhs must be (M, N) and group_sizes must be (G,).

    Raises
    ------
    AssertionError
        If a tensor has the wrong rank, the M dimensions disagree, or a dimension is
        not positive.
    """
    # Rank checks first, so the unpacking below cannot fail.
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
    assert (
        group_sizes.dim() == 1
    ), f"group_sizes must have 1 dimension (it's {group_sizes.dim()})."

    K, lhs_m = lhs.size()
    rhs_m, N = rhs.size()
    G = group_sizes.size(0)

    # lhs columns and rhs rows both index tokens, so they must agree.
    assert (
        lhs_m == rhs_m
    ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
    M = lhs_m

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    return M, K, N, G
617
+
618
+
619
def get_tgmm_output(
    K: int,
    N: int,
    G: int,
    device: torch.device | str = DEVICE,
    preferred_element_type: torch.dtype = DTYPE,
    existing_out: Tensor | None = None,
) -> Tensor:
    """
    Get or validate the output tensor for TGMM.

    If existing_out is provided, its device, dtype and shape are validated and
    the same tensor is returned. Otherwise a new tensor is obtained from
    gen_tgmm_output.

    Parameters
    ----------
    K : int
        Number of output rows per group.
    N : int
        Number of output columns per group.
    G : int
        Number of groups.
    device : torch.device or str
        Expected device of the output.
    preferred_element_type : torch.dtype
        Expected element type of the output.
    existing_out : torch.Tensor or None
        Pre-allocated output of shape (G, K, N) to validate and reuse.

    Returns
    -------
    torch.Tensor
        existing_out when given, otherwise the result of gen_tgmm_output.
    """
    assert K > 0, f"Number of out rows K must be positive (K = {K})."
    assert N > 0, f"Number of out columns N must be positive (N = {N})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if existing_out is None:
        return gen_tgmm_output(
            K,
            N,
            G,
            device=device,
            preferred_element_type=preferred_element_type,
        )

    # A torch.device never compares equal to a plain string, so normalize the
    # requested device first; otherwise a str device always fails the check.
    expected_device = torch.device(device)
    assert (
        existing_out.device == expected_device
    ), f"Existing output device and provided device don't match (existing = {existing_out.device}, provided = {device})."
    assert (
        existing_out.dtype == preferred_element_type
    ), f"Existing output type and preferred output type don't match (existing = {existing_out.dtype}, preferred = {preferred_element_type})."
    assert existing_out.shape == (
        G,
        K,
        N,
    ), f"Existing output shape and GMM shape don't match (existing = {tuple(existing_out.shape)}, provided = {(G, K, N)})."
    return existing_out
652
+
653
+
654
def get_tgmm_bias_grad(
    K: int,
    G: int,
    device: torch.device | str = DEVICE,
    existing_bias_grad: Tensor | None = None,
) -> Tensor:
    """
    Get or validate bias gradient tensor for TGMM.

    If existing_bias_grad is provided, validates its shape, device, dtype, and stride,
    and always zeros it in place before returning (since the kernel uses atomic_add).
    If existing_bias_grad is None, returns the result of
    gen_tgmm_bias_grad(..., with_bias_grad=False), intended as a dummy tensor for
    use when COMPUTE_BIAS_GRAD=False.

    Parameters
    ----------
    K : int
        Number of rows in the bias gradient tensor.
    G : int
        Number of groups.
    device : torch.device or str
        Device for the tensor.
    existing_bias_grad : torch.Tensor or None
        Existing bias gradient tensor to validate and use.

    Returns
    -------
    torch.Tensor
        Valid bias gradient tensor or dummy tensor.
    """
    assert K > 0, f"Number of bias_grad rows K must be positive (K = {K})."
    assert G > 0, f"Number of groups G must be positive (G = {G})."

    if existing_bias_grad is None:
        return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False)

    # Validate existing bias_grad tensor.
    expected_shape = (G, K)
    assert (
        tuple(existing_bias_grad.shape) == expected_shape
    ), f"bias_grad must have shape {expected_shape}, got {tuple(existing_bias_grad.shape)}."
    # A torch.device never compares equal to a plain string, so normalize the
    # requested device first; otherwise a str device always fails the check.
    expected_device = torch.device(device)
    assert (
        existing_bias_grad.device == expected_device
    ), f"bias_grad must be on the same device (bias_grad = {existing_bias_grad.device}, device = {device})."
    assert (
        existing_bias_grad.dtype == torch.float32
    ), f"bias_grad must be torch.float32 (kernel uses atomic_add which requires float32), got {existing_bias_grad.dtype}."
    assert existing_bias_grad.stride() == (
        K,
        1,
    ), f"bias_grad must be row-major with stride (K, 1) = ({K}, 1), got {existing_bias_grad.stride()}."

    # Always zero the tensor since bias_grad represents gradients for the current
    # computation and should start fresh. The kernel uses atomic_add which adds to
    # existing values, so we must zero before the kernel runs.
    existing_bias_grad.zero_()

    return existing_bias_grad
710
+
711
+
712
def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]:
    """
    Determine the lhs memory layout for TGMM and validate operand layouts.

    lhs is read as (K, M), rhs as (M, N) and out as (G, K, N). rhs and out must
    be row-major; lhs may be row-major or column-major.

    Returns
    -------
    tuple[bool, int]
        (is_lhs_col_major, ld_lhs), where ld_lhs is M for a row-major lhs and K
        for a column-major lhs.

    Raises
    ------
    AssertionError
        If ranks, shared dimensions, positivity or strides are not as expected.
    """
    assert lhs.dim() == 2, f"lhs must have 2 dimensions (it's {lhs.dim()})."
    assert rhs.dim() == 2, f"rhs must have 2 dimensions (it's {rhs.dim()})."
    assert out.dim() == 3, f"out must have 3 dimensions (it's {out.dim()})."

    lhs_k, lhs_m = lhs.shape
    rhs_m, rhs_n = rhs.shape
    G, out_k, out_n = out.shape

    assert (
        lhs_m == rhs_m
    ), f"M dimension of lhs and rhs don't match (lhs = {lhs_m}, rhs = {rhs_m})."
    M = lhs_m
    # The labels name the tensors actually being compared (lhs/out, rhs/out).
    assert (
        lhs_k == out_k
    ), f"K dimension of lhs and out don't match (lhs = {lhs_k}, out = {out_k})."
    K = lhs_k
    assert (
        rhs_n == out_n
    ), f"N dimension of rhs and out don't match (rhs = {rhs_n}, out = {out_n})."
    N = rhs_n

    assert M > 0, f"M must be positive, it's {M}."
    assert K > 0, f"K must be positive, it's {K}."
    assert N > 0, f"N must be positive, it's {N}"
    assert G > 0, f"G must be positive, it's {G}"

    is_lhs_row_major = lhs.stride() == (M, 1)
    is_lhs_col_major = lhs.stride() == (1, K)
    # When K == M == 1 the strides (1, 1) satisfy both layouts at once, so
    # require at least one match instead of exactly one.
    assert (
        is_lhs_row_major or is_lhs_col_major
    ), "lhs must be row-major or column-major."
    is_rhs_row_major = rhs.stride() == (N, 1)
    assert is_rhs_row_major, "rhs must be row-major."
    is_out_row_major = out.stride() == (K * N, N, 1)
    assert is_out_row_major, "out must be row-major."

    # An ambiguous (1, 1) layout is reported as row-major.
    is_lhs_col_major = is_lhs_col_major and not is_lhs_row_major

    # Get lhs leading dimension according to transposition configuration.
    ld_lhs = K if is_lhs_col_major else M

    return is_lhs_col_major, ld_lhs
build/torch211-cxx11-cu130-x86_64-linux/_grouped_gemm_triton/utils/logger.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+
5
+ # AITER Triton Logger, a singleton wrapper around Python logging.
6
+ # Note: Python logging is also a singleton object, but we want to read the
7
+ # env var AITER_TRITON_LOG_LEVEL once at the beginning. Another alternative is to do
8
+ # this in __init__.py. In fact, that's how CK logger is setup. We can look at
9
+ # switching to that at some point
10
+ #
11
+ # AITER_TRITON_LOG_LEVEL follows python logging levels
12
+ # DEBUG
13
+ # INFO
14
+ # WARNING
15
+ # ERROR
16
+ # CRITICAL
17
+ #
18
class AiterTritonLogger(object):
    """
    Process-wide singleton wrapper around the "AITER_TRITON" Python logger.

    The level is read from the AITER_TRITON_LOG_LEVEL environment variable once,
    when the first instance is created. It is matched case-insensitively against
    the logging module's level names; missing or unrecognized values fall back
    to WARNING.
    """

    # The single shared instance, created lazily by __new__.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super(AiterTritonLogger, cls).__new__(cls)
            level_name = os.getenv("AITER_TRITON_LOG_LEVEL", "WARNING").strip().upper()
            level = getattr(logging, level_name, None)
            # getattr can hit non-level attributes of the logging module (for
            # example BASIC_FORMAT is a str); only integer levels are valid for
            # setLevel, anything else falls back to WARNING.
            if isinstance(level, bool) or not isinstance(level, int):
                level = logging.WARNING
            instance._logger = logging.getLogger("AITER_TRITON")
            instance._logger.setLevel(level)
            # Publish the singleton only once it is fully configured.
            cls._instance = instance

        return cls._instance

    def get_logger(self):
        """Return the underlying logging.Logger."""
        return self._logger

    def debug(self, msg):
        """Log msg at DEBUG level."""
        self._logger.debug(msg)

    def info(self, msg):
        """Log msg at INFO level."""
        self._logger.info(msg)

    def warning(self, msg):
        """Log msg at WARNING level."""
        self._logger.warning(msg)

    def error(self, msg):
        """Log msg at ERROR level."""
        self._logger.error(msg)

    def critical(self, msg):
        """Log msg at CRITICAL level."""
        self._logger.critical(msg)
build/torch211-cxx11-cu130-x86_64-linux/{_megablocks_cuda_ae601bb.abi3.so → _megablocks_cuda_f8f8b50.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8f05428251fcba79071d881be47c1d2778f2fb3a068d029c7f6c4f546efa5b64
3
- size 10113080
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ef673d78d220cea71eace3a5bdb4b952444ab7b95ed15774258ad108ad40d51
3
+ size 11769248
build/torch211-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_ae601bb
3
- ops = torch.ops._megablocks_cuda_ae601bb
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_ae601bb::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_f8f8b50
3
+ ops = torch.ops._megablocks_cuda_f8f8b50
4
 
5
def add_op_namespace_prefix(op_name: str):
    """
    Return op_name qualified with this build's op namespace ("<namespace>::<op_name>").
    """
    namespace = "_megablocks_cuda_f8f8b50"
    return f"{namespace}::{op_name}"
build/torch211-cxx11-cu130-x86_64-linux/grouped_gemm/backend.py CHANGED
@@ -2,16 +2,16 @@
2
  # extensions. Otherwise libc10.so cannot be found.
3
  import torch
4
 
5
- # # TODO(tgale): Wrap this in a try-block with better
6
- # # error message and instructions for building the
7
- # # c++ operations.
8
- # import grouped_gemm_backend as backend
9
 
10
- # We import the backend operations from the megablocks package as
11
- # grouped_gemm is vendored in megablocks in this repository.
12
- # from ... import _ops as backend
13
- # from megablocks._ops import ops as backend # type: ignore
14
- from .._ops import ops as backend # type: ignore
 
15
 
16
  def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
  assert not (trans_a and trans_b)
 
2
  # extensions. Otherwise libc10.so cannot be found.
3
  import torch
4
 
5
+ # On ROCm there is no CUTLASS grouped GEMM; dispatch to the vendored AITER
6
+ # Triton kernels instead. On CUDA we use the compiled CUTLASS `gmm` op.
7
+ _IS_ROCM = torch.version.hip is not None
 
8
 
9
+ if _IS_ROCM:
10
+ from .._grouped_gemm_triton import adapter as backend
11
+ else:
12
+ # We import the backend operations from the megablocks package as
13
+ # grouped_gemm is vendored in megablocks in this repository.
14
+ from .._ops import ops as backend # type: ignore
15
 
16
  def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
  assert not (trans_a and trans_b)