Kernels
danieldk HF Staff commited on
Commit
3b163b3
·
verified ·
1 Parent(s): 71f0dbc

Build uploaded using `kernels` (batch 2/10).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py +34 -0
  3. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py +616 -0
  4. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py +1958 -0
  5. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py +153 -0
  6. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py +268 -0
  7. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py +288 -0
  8. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py +1686 -0
  9. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py +320 -0
  10. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py +357 -0
  11. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py +25 -0
  12. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py +476 -0
  13. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py +121 -0
  14. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py +76 -0
  15. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py +188 -0
  16. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py +201 -0
  17. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py +1962 -0
  18. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py +19 -0
  19. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py +81 -0
  20. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py +165 -0
  21. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py +56 -0
  22. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py +59 -0
  23. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py +319 -0
  24. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py +101 -0
  25. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py +84 -0
  26. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py +349 -0
  27. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +681 -0
  28. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py +108 -0
  29. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py +142 -0
  30. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py +0 -0
  31. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py +445 -0
  32. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py +26 -0
  33. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py +189 -0
  34. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +39 -0
  35. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +471 -0
  36. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +341 -0
  37. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +249 -0
  38. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +62 -0
  39. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +663 -0
  40. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +328 -0
  41. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +1041 -0
  42. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +25 -0
  43. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py +189 -0
  44. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +83 -0
  45. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py +29 -0
  46. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py +109 -0
  47. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +405 -0
  48. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py +510 -0
  49. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py +610 -0
  50. build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py +207 -0
.gitattributes CHANGED
@@ -12,3 +12,4 @@ build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf
12
  build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
13
  build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
14
  build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
12
  build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
13
  build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
14
  build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
15
+ build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides MLIR's OP helper functions
14
+ """
15
+
16
+
17
+ import inspect
18
+ from functools import wraps
19
+
20
+ from ..._mlir import ir
21
+
22
+
23
+ def dsl_user_op(opFunc):
24
+ @wraps(opFunc)
25
+ def wrapper(*args, **kwargs):
26
+ loc = kwargs.pop("loc", None)
27
+ if loc is None:
28
+ frame = inspect.currentframe().f_back
29
+ file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
30
+ loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
31
+ res_or_list = opFunc(*args, **kwargs, loc=loc)
32
+ return res_or_list
33
+
34
+ return wrapper
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides helper functions that are generated by the preprocessor.
14
+ The preprocessor read through python's ast and changes the input code.
15
+ """
16
+
17
+ from typing import Callable, Iterator, Optional, overload
18
+ from typing_extensions import deprecated
19
+ import warnings
20
+ import inspect
21
+ from types import BuiltinFunctionType
22
+ from functools import lru_cache
23
+ from inspect import getmembers
24
+
25
+ from .utils.logger import log
26
+ from .common import *
27
+
28
+ from ._mlir_helpers.arith import ArithValue
29
+
30
+
31
+ class Executor:
32
+ """
33
+ The Executor class handles dynamic and compile-time (constexpr) execution
34
+ of "for" loops and "if-else-elif" statements.
35
+
36
+ Methods:
37
+ set_functions: Assigns the functions for checking loop bounds and
38
+ conditional evaluation.
39
+
40
+ for_execute: Generates MLIR for OP
41
+ while_execute: Generates MLIR while OP
42
+ if_execute: generate MLIR if OP
43
+ """
44
+
45
+ def __init__(self):
46
+ self._is_dynamic_expression = None
47
+ self._loop_execute_range_dynamic = None
48
+ self._if_dynamic = None
49
+ self._while_dynamic = None
50
+ self._compare_executor = None
51
+ self._any_executor = None
52
+ self._all_executor = None
53
+ self._builtin_redirector = None
54
+
55
+ def set_functions(
56
+ self,
57
+ *,
58
+ is_dynamic_expression: Callable,
59
+ loop_execute_range_dynamic: Callable,
60
+ if_dynamic: Callable,
61
+ while_dynamic: Callable,
62
+ compare_executor: Callable,
63
+ any_executor: Callable = None,
64
+ all_executor: Callable = None,
65
+ builtin_redirector: Callable = None,
66
+ ):
67
+ self._is_dynamic_expression = is_dynamic_expression
68
+ self._loop_execute_range_dynamic = loop_execute_range_dynamic
69
+ self._if_dynamic = if_dynamic
70
+ self._while_dynamic = while_dynamic
71
+ self._compare_executor = compare_executor
72
+ self._any_executor = any_executor
73
+ self._all_executor = all_executor
74
+ self._builtin_redirector = builtin_redirector
75
+
76
+ @staticmethod
77
+ def convert_to_list(x):
78
+ """This function is used to convert x to a list.
79
+ If x is None, return an empty list.
80
+ If x is not a list, return a list containing x.
81
+ Otherwise, return x itself.
82
+ """
83
+ if x is None:
84
+ return []
85
+ if not isinstance(x, list):
86
+ return [x]
87
+ return x
88
+
89
+ @staticmethod
90
+ def converge_ret_val(res):
91
+ """This function is used to converge res (the return value) of the function.
92
+ If res is None, return None.
93
+ If res is a list and has only one element, return the element.
94
+ Otherwise, return res itself.
95
+ """
96
+ if res is None:
97
+ return res
98
+ elif isinstance(res, list) and len(res) == 1:
99
+ return res[0]
100
+ return res
101
+
102
+ def for_execute(
103
+ self,
104
+ func,
105
+ start,
106
+ stop,
107
+ step,
108
+ write_args=[],
109
+ full_write_args_count=0,
110
+ write_args_names=[],
111
+ unroll=-1,
112
+ unroll_full=False,
113
+ prefetch_stages=None,
114
+ ):
115
+ assert (
116
+ self._loop_execute_range_dynamic
117
+ ), "Functions must be set before execution."
118
+ log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
119
+
120
+ return self._loop_execute_range_dynamic(
121
+ func,
122
+ start,
123
+ stop,
124
+ step,
125
+ write_args,
126
+ full_write_args_count,
127
+ write_args_names,
128
+ unroll,
129
+ unroll_full,
130
+ prefetch_stages,
131
+ )
132
+
133
+ def if_execute(
134
+ self,
135
+ pred,
136
+ then_block: Callable,
137
+ else_block: Optional[Callable] = None,
138
+ write_args=[],
139
+ full_write_args_count=0,
140
+ write_args_names=[],
141
+ ):
142
+ assert self._if_dynamic, "Functions must be set before execution."
143
+
144
+ # MLIR generation
145
+ return self._if_dynamic(
146
+ pred,
147
+ then_block,
148
+ else_block,
149
+ write_args,
150
+ full_write_args_count,
151
+ write_args_names,
152
+ )
153
+
154
+ def while_execute(
155
+ self,
156
+ pred,
157
+ while_before_block: Callable,
158
+ while_after_block: Callable,
159
+ write_args=[],
160
+ full_write_args_count=0,
161
+ write_args_names=[],
162
+ ):
163
+ assert self._while_dynamic, "Functions must be set before execution."
164
+
165
+ # MLIR generation
166
+ return self._while_dynamic(
167
+ while_before_block,
168
+ while_after_block,
169
+ write_args,
170
+ full_write_args_count,
171
+ write_args_names,
172
+ )
173
+
174
+
175
+ # =============================================================================
176
+ # Decorator
177
+ # =============================================================================
178
+
179
+ executor = Executor()
180
+
181
+
182
+ def loop_selector(
183
+ start,
184
+ stop,
185
+ step,
186
+ *,
187
+ write_args=[],
188
+ full_write_args_count=0,
189
+ write_args_names=[],
190
+ unroll=-1,
191
+ unroll_full=False,
192
+ prefetch_stages=None,
193
+ ):
194
+ log().debug(
195
+ "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
196
+ start,
197
+ stop,
198
+ step,
199
+ write_args,
200
+ full_write_args_count,
201
+ write_args_names,
202
+ unroll,
203
+ unroll_full,
204
+ prefetch_stages,
205
+ )
206
+ from .typing import Integer, Numeric
207
+
208
+ def _maybe_upcast(value):
209
+ if isinstance(value, Integer):
210
+ value = value.ir_value()
211
+
212
+ return value
213
+
214
+ start = _maybe_upcast(start)
215
+ stop = _maybe_upcast(stop)
216
+ step = _maybe_upcast(step)
217
+
218
+ def ir_loop(func):
219
+ return executor.for_execute(
220
+ func,
221
+ start,
222
+ stop,
223
+ step,
224
+ write_args,
225
+ full_write_args_count,
226
+ write_args_names,
227
+ unroll,
228
+ unroll_full,
229
+ prefetch_stages,
230
+ )
231
+
232
+ return ir_loop
233
+
234
+
235
+ def if_selector(pred, write_args=[]):
236
+ log().debug("pred [%s] write_args [%s]", pred, write_args)
237
+ # Handle Numeric types here?
238
+
239
+ from .typing import Numeric
240
+
241
+ if isinstance(pred, Numeric):
242
+ pred = pred.value
243
+
244
+ def ir_loop(func):
245
+ return func(pred, *write_args)
246
+
247
+ return ir_loop
248
+
249
+
250
+ def while_selector(pred, write_args=[]):
251
+ def ir_while_loop(func):
252
+ return func(pred, *write_args)
253
+
254
+ return ir_while_loop
255
+
256
+
257
+ def while_executor(
258
+ pred,
259
+ while_before_block: Callable,
260
+ while_after_block: Callable,
261
+ write_args=[],
262
+ full_write_args_count=0,
263
+ write_args_names=[],
264
+ ):
265
+ return executor.while_execute(
266
+ pred,
267
+ while_before_block,
268
+ while_after_block,
269
+ write_args,
270
+ full_write_args_count,
271
+ write_args_names,
272
+ )
273
+
274
+
275
+ def if_executor(
276
+ pred,
277
+ then_block: Callable,
278
+ else_block: Optional[Callable] = None,
279
+ write_args=[],
280
+ full_write_args_count=0,
281
+ write_args_names=[],
282
+ ):
283
+ return executor.if_execute(
284
+ pred,
285
+ then_block,
286
+ else_block,
287
+ write_args,
288
+ full_write_args_count,
289
+ write_args_names,
290
+ )
291
+
292
+
293
+ # =============================================================================
294
+ # Range
295
+ # =============================================================================
296
+
297
+
298
+ class range:
299
+ """
300
+ A range-like object for dynamic loop iteration in the DSL.
301
+
302
+ This class provides a range interface similar to Python's built-in range,
303
+ but is designed to be preprocessed into constructs for dynamic
304
+ loop execution.
305
+
306
+ The class supports both single-argument (stop) and three-argument
307
+ (start, stop, step) constructors with additional parameters for loop
308
+ optimization:
309
+
310
+ - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
311
+ - unroll_full: Whether to fully unroll the loop
312
+ - prefetch_stages: Number of prefetch stages to generate
313
+ """
314
+
315
+ @overload
316
+ def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None):
317
+ pass
318
+
319
+ @overload
320
+ def __new__(
321
+ cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
322
+ ):
323
+ pass
324
+
325
+ def __new__(cls, *args, **kwargs):
326
+ raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
327
+
328
+ def __iter__(self) -> Iterator[int]:
329
+ raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
330
+
331
+
332
+ @deprecated(
333
+ "range_dynamic is deprecated and will be removed in the future, please remove it."
334
+ )
335
+ def range_dynamic(*args, **kwargs):
336
+ raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
337
+
338
+
339
+ def range_constexpr(*args):
340
+ raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
341
+
342
+
343
+ # =============================================================================
344
+ # If expressions
345
+ # =============================================================================
346
+
347
+
348
+ def const_expr(expression):
349
+ """
350
+ This function is used to check if the expression is a python value.
351
+ If the expression is a python value, return the boolean value of the expression.
352
+ If the expression is a dynamic expression, raise an error.
353
+ """
354
+ from .typing import Numeric
355
+
356
+ failed = False
357
+
358
+ if isinstance(expression, Numeric):
359
+ if isinstance(expression.value, (int, float, bool)):
360
+ return expression.value
361
+ else:
362
+ failed = True
363
+ elif executor._is_dynamic_expression(expression):
364
+ failed = True
365
+
366
+ if failed:
367
+ raise DSLRuntimeError(
368
+ f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
369
+ context={
370
+ "If your expression depends on dynamic values": "Remove `const_expr()`",
371
+ },
372
+ )
373
+ return expression
374
+
375
+
376
+ @deprecated(
377
+ "dynamic_expr is deprecated and will be removed in the future, please remove it."
378
+ )
379
+ def dynamic_expr(expression):
380
+ return expression
381
+
382
+
383
+ # =============================================================================
384
+ # Assertion & casting
385
+ # =============================================================================
386
+
387
+
388
+ def assert_executor(test, msg=None):
389
+ from .typing import Numeric
390
+
391
+ fail = False
392
+ # Implicit convert dynamic expression to bool is not allowed
393
+ # So here explicitly do a None check
394
+ if test is not None and executor._is_dynamic_expression(test):
395
+ if isinstance(test, Numeric):
396
+ try:
397
+ test = test.to(bool)
398
+ except:
399
+ fail = True
400
+ else:
401
+ fail = True
402
+
403
+ if not fail:
404
+ assert test, msg
405
+ else:
406
+ raise DSLRuntimeError(
407
+ "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
408
+ suggestion="Please replace with runtime assert.",
409
+ )
410
+
411
+
412
+ def bool_cast(value):
413
+ if executor._is_dynamic_expression(value):
414
+ raise DSLRuntimeError(
415
+ "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
416
+ suggestion="Please explicitly convert to boolean with expressions like comparision.",
417
+ )
418
+ return bool(value)
419
+
420
+
421
+ def compare_executor(left, comparators, ops):
422
+ """
423
+ Executes comparison operations with a left operand and a list of comparators.
424
+
425
+ Args:
426
+ left: The leftmost value in the comparison chain
427
+ comparators: A list of values to compare against
428
+ ops: A list of comparison operators to apply
429
+
430
+ Returns:
431
+ The result of the comparison chain
432
+
433
+ Raises:
434
+ AssertionError: If the executor function is not set before execution
435
+ """
436
+ assert (
437
+ executor._compare_executor is not None
438
+ ), "Function must be set before execution."
439
+ return executor._compare_executor(left, comparators, ops)
440
+
441
+
442
+ def any_executor(iterable):
443
+ """Executes the 'any' operation on an iterable, handling both dynamic and static expressions.
444
+
445
+ :param iterable: An iterable to check if any elements evaluate to True
446
+ :type iterable: Iterable
447
+ :return: boolean of Python value or IR value
448
+ :rtype: bool or cutlass.Boolean
449
+
450
+ """
451
+ if executor._any_executor and executor._is_dynamic_expression(iterable):
452
+ return executor._any_executor(iterable)
453
+ else:
454
+ return any(iterable)
455
+
456
+
457
+ def all_executor(iterable):
458
+ """Executes the 'all' operation on an iterable, handling both dynamic and static expressions.
459
+
460
+ :param iterable: An iterable to check if all elements evaluate to True
461
+ :type iterable: Iterable
462
+ :return: boolean of Python value or IR value
463
+ :rtype: bool or cutlass.Boolean
464
+ """
465
+ if executor._all_executor and executor._is_dynamic_expression(iterable):
466
+ return executor._all_executor(iterable)
467
+ else:
468
+ return all(iterable)
469
+
470
+
471
+ # =============================================================================
472
+ # Control flow checks
473
+ # =============================================================================
474
+ class DSLOptimizationWarning(Warning):
475
+ """
476
+ This warning is used to warn the user about the optimization related issues in DSL.
477
+ """
478
+
479
+ def __init__(self, message):
480
+ self.message = message
481
+ super().__init__()
482
+
483
+ def __str__(self):
484
+ return self.message
485
+
486
+
487
+ def range_value_check(*args):
488
+ """
489
+ Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
490
+ """
491
+ try:
492
+ args = tuple(arg.__index__() for arg in args)
493
+
494
+ # Compute range size and warn if it's too large
495
+ start = 0
496
+ end = 0
497
+ step = 1
498
+ if len(args) == 1:
499
+ end = args[0]
500
+ elif len(args) == 2:
501
+ start = args[0]
502
+ end = args[1]
503
+ elif len(args) == 3:
504
+ start = args[0]
505
+ end = args[1]
506
+ step = args[2]
507
+
508
+ range_length = (abs(end - start) - 1) // abs(step) + 1
509
+ if range_length >= 64:
510
+ warnings.warn(
511
+ f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
512
+ category=DSLOptimizationWarning,
513
+ stacklevel=2,
514
+ )
515
+
516
+ return (start, end, step)
517
+ except:
518
+ raise DSLRuntimeError(
519
+ "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
520
+ suggestion="Use `range` instead of `range_constexpr`.",
521
+ )
522
+
523
+
524
+ def range_perf_warning(filename, lineno, *args):
525
+ has_dynamic_expr = False
526
+ for arg in args:
527
+ if executor._is_dynamic_expression(arg):
528
+ has_dynamic_expr = True
529
+ break
530
+ if not has_dynamic_expr:
531
+ warnings.warn_explicit(
532
+ (
533
+ "This loop is no longer unrolled and may cause performance regression. "
534
+ "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
535
+ ),
536
+ category=DSLOptimizationWarning,
537
+ filename=filename,
538
+ lineno=lineno,
539
+ )
540
+
541
+
542
+ @lru_cache(maxsize=1)
543
+ def _get_self_module():
544
+ """
545
+ This function is used to get the owning module of this function.
546
+ """
547
+ return inspect.getmodule(_get_self_module)
548
+
549
+
550
+ def cf_symbol_check(symbol):
551
+ """
552
+ Check if the symbol is control flow symbol from current module.
553
+ """
554
+
555
+ failed = False
556
+ name = symbol.__name__
557
+ self_module = _get_self_module()
558
+ if inspect.ismodule(symbol):
559
+ name = "range"
560
+ if not self_module.__name__.startswith(symbol.__name__):
561
+ failed = True
562
+ else:
563
+ owning_module = inspect.getmodule(symbol)
564
+ if owning_module != self_module:
565
+ failed = True
566
+
567
+ if failed:
568
+ raise DSLRuntimeError(
569
+ f"Incorrect {symbol.__name__} is used.",
570
+ suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
571
+ )
572
+
573
+
574
+ def redirect_builtin_function(fcn):
575
+ """
576
+ This function is used to redirect built-in function call
577
+ to the function defined in DSL package.
578
+ """
579
+ # Only redirect if it's a built-in
580
+ if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
581
+ return executor._builtin_redirector(fcn)
582
+ return fcn
583
+
584
+
585
+ def copy_members(dest, src):
586
+ """
587
+ Copies all non-callable, non-dunder members from src to dest if they exist in src.
588
+ Skips members that are callables or have names starting with double underscores.
589
+ """
590
+ if id(dest) == id(src):
591
+ return
592
+
593
+ members = getmembers(dest)
594
+ for name, value in members:
595
+ if (
596
+ name.startswith("__")
597
+ or isinstance(value, Callable)
598
+ or not hasattr(src, name)
599
+ ):
600
+ continue
601
+ setattr(dest, name, getattr(src, name))
602
+
603
+
604
+ def get_locals_or_none(locals, symbols):
605
+ """
606
+ Given a locals() dictionary and a list of symbol names, return a list of their values
607
+ in the same order as the symbols list. If a symbol is not present in locals, None is returned
608
+ for that symbol.
609
+ """
610
+ variables = []
611
+ for symbol in symbols:
612
+ if symbol in locals:
613
+ variables.append(locals[symbol])
614
+ else:
615
+ variables.append(None)
616
+ return variables
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py ADDED
@@ -0,0 +1,1958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module defines the `DSLPreprocessor` class, which acts as a Python preprocessor.
14
+ It uses Python's AST and rewrites specific Python statements such as `for` and `if-else`.
15
+
16
+ The preprocessor operates on the following constructs:
17
+ - `for` loops:
18
+ - Rewrites `for` loops with the `@loop_selector` decorator.
19
+ - Supports `range`, `range_dynamic` for loop iteration.
20
+ - `if-elif-else` statements:
21
+ - Rewrites conditional statements with the `@if_selector` decorator.
22
+ - Supports `dynamic_expr` and `const_expr` in the condition expressions.
23
+
24
+ Additionally, both `for` loops and `if-else` statements require `yield`
25
+ operation generation. The preprocessor handles this by:
26
+ - Using a `ScopeManager` to track symbols across different scopes during AST traversal.
27
+ - Identifying read-only, read-write, and active variables for DSL constructs.
28
+ - Generating `yield` operations for symbols that are classified as read-write or write.
29
+
30
+ It is designed to be generic and can handle `for` and `if` constructs from other dialects.
31
+ In such cases, the user's DSL should implement `@loop_selector` and `@if_selector`
32
+ to generate dialect-specific operations for `for` and `if` statements.
33
+ """
34
+
35
+ import ast
36
+ import importlib
37
+ import inspect
38
+ import textwrap
39
+ import warnings
40
+ from dataclasses import dataclass
41
+ from typing import List, Set, Dict, Any, Callable, Optional
42
+ from types import ModuleType
43
+ from collections import OrderedDict
44
+ from copy import deepcopy
45
+
46
+ from .common import *
47
+ from .utils.logger import log
48
+
49
+
50
class OrderedSet:
    """
    A set with deterministic (insertion-ordered) iteration.

    Backed by a ``dict`` (insertion-ordered since Python 3.7) whose keys are
    the members and whose values are unused.  Determinism matters because the
    symbol sets collected here drive IR generation, which must be
    reproducible run to run.
    """

    def __init__(self, iterable=None):
        # dict.fromkeys keeps first-seen order and drops duplicates.
        self._dict = dict.fromkeys(iterable or [])

    def add(self, item):
        self._dict[item] = None

    def __contains__(self, item):
        # O(1) membership; without this, `in` falls back to a linear scan
        # through __iter__ (which __and__/__sub__/intersections rely on when
        # `other` is itself an OrderedSet).
        return item in self._dict

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __and__(self, other):
        # Intersection, preserving this set's insertion order.
        return OrderedSet(key for key in self._dict if key in other)

    def __or__(self, other):
        # Union: self's elements first, then other's previously-unseen ones.
        new_dict = self._dict.copy()
        new_dict.update(dict.fromkeys(other))
        return OrderedSet(new_dict)

    def __sub__(self, other):
        # Difference, preserving this set's insertion order.
        return OrderedSet(key for key in self._dict if key not in other)

    def intersections(self, others):
        """Compute the intersection of this set with multiple other sets.

        :param others: A list of sets to compute intersections with
        :type others: List[Set[str]]
        :return: A new ordered set containing elements that appear in this set
            and at least one of the other sets
        """
        result = OrderedSet()
        for key in self._dict:
            for other in reversed(others):
                if key in other:
                    result.add(key)
                    break
        return result
90
+
91
+
92
@dataclass
class ImportInfo:
    """
    Information about an import expression.
    """

    # Dotted module path being imported, e.g. "package.sub".
    module_path: str
    # Attribute pulled from the module for `from X import Y` (may be "*");
    # None for a plain `import X`.
    attr_name: Optional[str]
    # Name the import binds in the namespace (the asname if given,
    # otherwise the imported name itself).
    alias_name: str
100
+
101
+
102
@dataclass
class ScopeManager:
    """
    Tracks symbol scopes during AST traversal.

    Used as a re-entrant context manager: entering pushes a fresh scope,
    exiting pops it.  Symbols recorded via :meth:`add_to_scope` land in the
    innermost (most recently entered) scope.
    """

    scopes: List[Set[str]]

    @classmethod
    def create(cls) -> "ScopeManager":
        """Build a manager with no open scopes."""
        return cls(scopes=[])

    def add_to_scope(self, name: str) -> None:
        """Record *name* in the innermost scope; `_` placeholders are ignored."""
        if name != "_":
            self.scopes[-1].add(name)

    def get_active_symbols(self) -> List[Set[str]]:
        """Return a shallow copy of the current scope stack."""
        return list(self.scopes)

    def __enter__(self) -> "ScopeManager":
        # Open a fresh, empty scope for the region about to be visited.
        self.scopes.append(set())
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Close the innermost scope.
        self.scopes.pop()
129
+
130
+
131
+ class DSLPreprocessor(ast.NodeTransformer):
132
+ """
133
+ A preprocessor for transforming Python ASTs. It supports:
134
+
135
+ - Rewriting `for` loops with the `@loop_selector` decorator.
136
+ - Rewriting `if-elif-else` statements with the `@if_selector` decorator.
137
+ - Generating `yield` operations for read-write or write symbols.
138
+ """
139
+
140
+ DECORATOR_FOR_STATEMENT = "loop_selector"
141
+ DECORATOR_IF_STATEMENT = "if_selector"
142
+ DECORATOR_WHILE_STATEMENT = "while_selector"
143
+ IF_EXECUTOR = "if_executor"
144
+ WHILE_EXECUTOR = "while_executor"
145
+ ASSERT_EXECUTOR = "assert_executor"
146
+ BOOL_CAST = "bool_cast"
147
+ IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType"
148
+ SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"}
149
+ COMPARE_EXECUTOR = "compare_executor"
150
+ ANY_EXECUTOR = "any_executor"
151
+ ALL_EXECUTOR = "all_executor"
152
+
153
+ def __init__(self, client_module_name):
154
+ super().__init__()
155
+ self.counter = 0 # Unique function names for multiple loops
156
+ self.scope_manager = ScopeManager.create()
157
+ self.processed_functions = set()
158
+ self.function_counter = 0
159
+ self.function_name = "<unknown function>"
160
+ self.class_name = None
161
+ self.file_name = "<unknown filename>"
162
+ self.function_depth = 0
163
+ self.local_closures = set()
164
+ self.function_globals = None
165
+ self.client_module_name = client_module_name
166
+ self.import_top_module = False
167
+
168
+ def _create_module_attribute(
169
+ self,
170
+ func_name,
171
+ *,
172
+ top_module_name="_dsl_",
173
+ submodule_name="ast_helpers",
174
+ lineno=None,
175
+ col_offset=None,
176
+ ):
177
+ # If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong.
178
+ def set_location(node, lineno, col_offset):
179
+ if lineno and col_offset:
180
+ node.lineno = lineno
181
+ node.end_lineno = lineno
182
+ node.col_offset = col_offset
183
+ node.end_col_offset = col_offset
184
+
185
+ base = ast.Name(id=top_module_name, ctx=ast.Load())
186
+ set_location(base, lineno, col_offset)
187
+ if submodule_name:
188
+ base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load())
189
+ set_location(base, lineno, col_offset)
190
+ node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load())
191
+ set_location(node, lineno, col_offset)
192
+ return node
193
+
194
    def _get_module_imports(self, decorated_func):
        """Extract imports from the module containing the decorated function.

        :param decorated_func: function whose defining module is scanned
        :return: list of :class:`ImportInfo`; empty when the module or its
            source cannot be retrieved (e.g. interactive sessions).
        """
        imports = []

        # Get the module containing the decorated function
        if module := inspect.getmodule(decorated_func):
            try:
                # Get the module source code
                source = inspect.getsource(module)
                module_ast = ast.parse(source)

                # Extract imports from the full module
                # Effective bound name: asname when present, else the name.
                alias = lambda n: n.asname if n.asname else n.name
                for node in ast.walk(module_ast):
                    if isinstance(node, ast.Import):
                        for name in node.names:
                            imports.append(
                                ImportInfo(
                                    module_path=name.name,
                                    attr_name=None,
                                    alias_name=alias(name),
                                )
                            )
                    elif isinstance(node, ast.ImportFrom):
                        module_name = node.module
                        if node.level > 0:
                            # Handle relative imports: resolve against the
                            # defining module's package, stripping one package
                            # component per extra leading dot.
                            package_name = module.__package__.rsplit(
                                ".", node.level - 1
                            )[0]
                            module_name = f"{package_name}.{module_name}"
                        for name in node.names:
                            imports.append(
                                ImportInfo(
                                    module_path=module_name,
                                    attr_name=name.name,
                                    alias_name=alias(name),
                                )
                            )
            except (IOError, TypeError):
                # Source unavailable; return whatever was collected so far.
                pass

        return imports
237
+
238
+ def exec(self, function_name, original_function, code_object, exec_globals):
239
+ # Get imports from the original module
240
+ module_imports = self._get_module_imports(original_function)
241
+
242
+ # Import all required modules
243
+ for import_info in module_imports:
244
+ module_path, attr_name, alias_name = (
245
+ import_info.module_path,
246
+ import_info.attr_name,
247
+ import_info.alias_name,
248
+ )
249
+ try:
250
+ module = importlib.import_module(module_path)
251
+ if attr_name:
252
+ if attr_name == "*":
253
+ if hasattr(module, "__all__"):
254
+ attrs = module.__all__
255
+ else:
256
+ attrs = [
257
+ name for name in dir(module) if not name.startswith("_")
258
+ ]
259
+ else:
260
+ attrs = [attr_name]
261
+
262
+ for attr in attrs:
263
+ alias = attr if attr_name == "*" else alias_name
264
+ exec_globals[alias] = getattr(module, attr)
265
+ else:
266
+ exec_globals[alias_name] = module
267
+ except (ImportError, AttributeError) as e:
268
+ raise ImportError(f"Failed to import {module_path}: {str(e)}")
269
+
270
+ # Execute the transformed code
271
+ log().info(
272
+ "ASTPreprocessor Executing transformed code for function [%s]",
273
+ function_name,
274
+ )
275
+ exec(code_object, exec_globals)
276
+ return exec_globals.get(function_name)
277
+
278
+ @staticmethod
279
+ def print_ast(transformed_tree=None):
280
+ print("#", "-" * 40, "Transformed AST", "-" * 40)
281
+ unparsed_code = ast.unparse(transformed_tree)
282
+ print(unparsed_code)
283
+ print("#", "-" * 40, "End Transformed AST", "-" * 40)
284
+
285
+ def make_func_param_name(self, base_name, used_names):
286
+ """Generate a unique parameter name that doesn't collide with existing names."""
287
+ if base_name not in used_names:
288
+ return base_name
289
+
290
+ i = 0
291
+ while f"{base_name}_{i}" in used_names:
292
+ i += 1
293
+ return f"{base_name}_{i}"
294
+
295
    def transform_function(self, func_name, function_pointer):
        """
        Transforms a function.

        Parses *function_pointer*'s source, verifies it carries the expected
        decorator, rewrites its AST, and prepends the imports the rewritten
        code needs.

        :return: list of top-level statements of the transformed module
            (empty when the function is skipped).
        """
        # Skip if the function has already been processed
        if function_pointer in self.processed_functions:
            log().info(
                "ASTPreprocessor Skipping already processed function [%s]", func_name
            )
            return []

        # Step 1. Parse the given function
        file_name = inspect.getsourcefile(function_pointer)
        lines, start_line = inspect.getsourcelines(function_pointer)
        # Dedent so nested/indented definitions parse as top-level code.
        dedented_source = textwrap.dedent("".join(lines))
        tree = ast.parse(dedented_source, filename=file_name)
        # Bump the line numbers so they match the real source file
        ast.increment_lineno(tree, start_line - 1)

        # Step 1.2 Check the decorator
        if not self.check_decorator(tree.body[0]):
            log().info(
                "[%s] - Skipping function due to missing decorator",
                func_name,
            )
            return []

        self.processed_functions.add(function_pointer)
        log().info("ASTPreprocessor Transforming function [%s]", func_name)

        # Step 2. Transform the function
        transformed_tree = self.visit(tree)

        # Step 3. Import cutlass and base_dsl
        top_module_name = ".".join(self.client_module_name)
        import_stmts = []
        # The top-level module import is only needed when a rewrite
        # referenced it (see import_top_module).
        if self.import_top_module:
            import_stmts.append(ast.Import(names=[ast.alias(name=top_module_name)]))
        import_stmts.append(
            ast.Import(
                names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")]
            )
        )
        transformed_tree.body = import_stmts + transformed_tree.body

        # Step 4. Import cutlass and base_dsl
        ast.fix_missing_locations(transformed_tree)
        combined_body = transformed_tree.body

        # Step 5. Return the transformed tree
        return combined_body
346
+
347
    def check_early_exit(self, tree, kind):
        """
        Checks if a given region or scope in the provided Python code has early exits.

        :param tree: AST node of the dynamic control-flow region being rewritten
        :param kind: region kind (e.g. "if"); break/continue are tolerated in
            "if" regions and inside nested loops
        :raises DSLAstPreprocessorError: when a return/raise (or a region-level
            break/continue in a loop region) is found
        """

        class EarlyExitChecker(ast.NodeVisitor):
            def __init__(self, kind):
                self.has_early_exit = False
                self.early_exit_node = None
                self.early_exit_type = None
                self.kind = kind
                # Depth of loops nested below the region root; break/continue
                # at depth > 0 belong to an inner loop and are allowed.
                self.loop_nest_level = 0

            # Early exit is not allowed in any level of dynamic control flow
            def visit_Return(self, node):
                self.has_early_exit = True
                self.early_exit_node = node
                self.early_exit_type = "return"

            def visit_Raise(self, node):
                self.has_early_exit = True
                self.early_exit_node = node
                self.early_exit_type = "raise"

            def visit_Break(self, node):
                # For break/continue in inner loops, we don't consider it as early exit
                if self.loop_nest_level == 0 and self.kind != "if":
                    self.has_early_exit = True
                    self.early_exit_node = node
                    self.early_exit_type = "break"

            def visit_Continue(self, node):
                if self.loop_nest_level == 0 and self.kind != "if":
                    self.has_early_exit = True
                    self.early_exit_node = node
                    self.early_exit_type = "continue"

            def visit_For(self, node):
                self.loop_nest_level += 1
                self.generic_visit(node)
                self.loop_nest_level -= 1

            def visit_While(self, node):
                self.loop_nest_level += 1
                self.generic_visit(node)
                self.loop_nest_level -= 1

        checker = EarlyExitChecker(kind)
        # generic_visit scans only the children, so the region's own node is
        # not itself counted as a loop level.
        checker.generic_visit(tree)
        if not checker.has_early_exit:
            return
        raise DSLAstPreprocessorError(
            message=f"Early exit ({checker.early_exit_type}) is not allowed in `{self.function_name}`"
            + (f" in `{self.class_name}`" if self.class_name else ""),
            filename=self.file_name,
            snippet=ast.unparse(tree),
            suggestion=(
                "If predicates are constant expression, write like "
                "`if const_expr(...)` or `for ... in range_constexpr(...)`. "
                "In that case, early exit will be executed by Python "
                "interpreter, so it's supported."
            ),
        )
410
+
411
+ def is_node_constexpr(self, node) -> bool:
412
+ """
413
+ Determines if the node is a constexpr.
414
+ Supported nodes are if, while statements.
415
+ """
416
+ if isinstance(node, ast.If) or isinstance(node, ast.While):
417
+ if isinstance(node.test, ast.Call):
418
+ func = node.test.func
419
+
420
+ if isinstance(func, ast.Attribute) and func.attr == "const_expr":
421
+ return True
422
+
423
+ elif isinstance(func, ast.Name) and func.id == "const_expr":
424
+ return True
425
+ return False
426
+
427
+ def _get_range_kind(self, iter_node):
428
+ """
429
+ Return "range", "range_dynamic", "range_constexpr" or None for the iterable
430
+ """
431
+ if isinstance(iter_node, ast.Call):
432
+ func = iter_node.func
433
+ if (
434
+ isinstance(func, ast.Name)
435
+ and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
436
+ ):
437
+ return func.id, True, len(iter_node.keywords) != 0
438
+ if (
439
+ isinstance(func, ast.Attribute)
440
+ and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
441
+ ):
442
+ return func.attr, False, len(iter_node.keywords) != 0
443
+ return None, None, None
444
+
445
    def transform(self, original_function, exec_globals):
        """
        Transforms the provided function using the preprocessor.

        :param original_function: function whose source is rewritten
        :param exec_globals: globals used while transforming; stashed on the
            instance only for the duration of the call
        :return: an ``ast.Module`` wrapping the transformed statements, with
            locations filled in
        """
        self.file_name = inspect.getsourcefile(original_function)
        self.function_globals = exec_globals
        transformed_tree = self.transform_function(
            original_function.__name__, original_function
        )
        # Drop the reference so state does not leak across transforms.
        self.function_globals = None
        unified_tree = ast.Module(body=transformed_tree, type_ignores=[])
        unified_tree = ast.fix_missing_locations(unified_tree)

        return unified_tree
459
+
460
+ def analyze_region_variables(
461
+ self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]]
462
+ ):
463
+ """
464
+ Analyze variables in different code regions to identify read-only, write-only,
465
+ and active variables for DSL constructs.
466
+ """
467
+
468
+ # we need orderedset to keep the insertion order the same. otherwise generated IR is different each time
469
+ write_args = OrderedSet()
470
+ invoked_args = OrderedSet()
471
+ local_closure = self.local_closures
472
+ file_name = self.file_name
473
+ region_node = node
474
+
475
+ class RegionAnalyzer(ast.NodeVisitor):
476
+ force_store = False
477
+
478
+ def visit_Name(self, node):
479
+ """
480
+ Mark every store as write.
481
+ """
482
+ if isinstance(node.ctx, ast.Store) or self.force_store:
483
+ write_args.add(node.id)
484
+
485
+ def visit_Subscript(self, node):
486
+ # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`.
487
+ # We need to force the store for the `Name` to be marked as write.
488
+ if isinstance(node.ctx, ast.Store):
489
+ self.force_store = True
490
+ self.visit(node.value)
491
+ self.force_store = False
492
+ self.visit(node.slice)
493
+ else:
494
+ self.generic_visit(node)
495
+
496
+ def visit_Assign(self, node):
497
+ self.force_store = True
498
+ [self.visit(target) for target in node.targets]
499
+ self.force_store = False
500
+ self.visit(node.value)
501
+
502
+ def visit_AugAssign(self, node):
503
+ self.force_store = True
504
+ self.visit(node.target)
505
+ self.force_store = False
506
+ self.visit(node.value)
507
+
508
+ @staticmethod
509
+ def get_call_base(func_node):
510
+ if isinstance(func_node, ast.Attribute):
511
+ # If the .value is another Attribute, keep digging
512
+ if isinstance(func_node.value, ast.Attribute):
513
+ return RegionAnalyzer.get_call_base(func_node.value)
514
+ # If the .value is a Name, that's our base
515
+ elif isinstance(func_node.value, ast.Name):
516
+ return func_node.value.id
517
+ else:
518
+ # Could be something else (lambda, call, etc.)
519
+ return None
520
+ elif isinstance(func_node, ast.Name):
521
+ return None
522
+ return None
523
+
524
+ @staticmethod
525
+ def get_function_name(func_node: ast.Call):
526
+ if isinstance(func_node.func, ast.Name):
527
+ function_name = func_node.func.id
528
+ # Check if it's a method or attribute call
529
+ elif isinstance(func_node.func, ast.Attribute):
530
+ function_name = func_node.func.attr
531
+ else:
532
+ function_name = None
533
+ return function_name
534
+
535
+ def visit_Call(self, node):
536
+ base_name = RegionAnalyzer.get_call_base(node.func)
537
+
538
+ if isinstance(node.func, ast.Name):
539
+ func_name = node.func.id
540
+ if func_name in local_closure:
541
+ raise DSLAstPreprocessorError(
542
+ f"Function `{func_name}` is a closure and is not supported in for/if statements",
543
+ filename=file_name,
544
+ snippet=ast.unparse(region_node),
545
+ )
546
+
547
+ # Classes are mutable by default. Mark them as write. If they are
548
+ # dataclass(frozen=True), treat them as read in runtime.
549
+ if base_name is not None and base_name not in ("self"):
550
+ invoked_args.add(base_name)
551
+
552
+ self.generic_visit(node)
553
+
554
+ analyzer = RegionAnalyzer()
555
+ analyzer.visit(ast.Module(body=node))
556
+
557
+ # If arg is both write and invoke, remove from invoked_args
558
+ invoked_args = invoked_args - write_args
559
+
560
+ write_args = list(write_args.intersections(active_symbols))
561
+ invoked_args = list(invoked_args.intersections(active_symbols))
562
+
563
+ return write_args + invoked_args, len(write_args)
564
+
565
+ def extract_range_args(self, iter_node):
566
+ args = iter_node.args
567
+ if len(args) == 1:
568
+ return (
569
+ self.visit(ast.Constant(value=0)),
570
+ self.visit(args[0]),
571
+ self.visit(ast.Constant(value=1)),
572
+ False,
573
+ )
574
+ elif len(args) == 2:
575
+ return (
576
+ self.visit(args[0]),
577
+ self.visit(args[1]),
578
+ self.visit(ast.Constant(value=1)),
579
+ False,
580
+ )
581
+ elif len(args) == 3:
582
+ return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]), True
583
+ else:
584
+ raise DSLAstPreprocessorError(
585
+ "Unsupported number of arguments in range", filename=self.file_name
586
+ )
587
+
588
+ def extract_unroll_args(self, iter_node):
589
+ keywords = {kw.arg: kw.value for kw in iter_node.keywords}
590
+ return (
591
+ keywords.get("unroll", ast.Constant(value=-1)),
592
+ keywords.get("unroll_full", ast.Constant(value=False)),
593
+ )
594
+
595
+ def issue_deprecation_warning(self, *, message, category, filename, lineno):
596
+ warnings.simplefilter("always", category) # turn off filter
597
+ warnings.warn_explicit(
598
+ message, category=category, filename=filename, lineno=lineno
599
+ )
600
+ warnings.simplefilter("default", category) # reset filter
601
+
602
+ def extract_prefetch_stages_args(self, iter_node):
603
+ keywords = {kw.arg: kw.value for kw in iter_node.keywords}
604
+ if "pipelining" in keywords:
605
+ self.issue_deprecation_warning(
606
+ message="pipelining is deprecated, use prefetch_stages instead",
607
+ category=DeprecationWarning,
608
+ filename=self.file_name,
609
+ lineno=iter_node.lineno,
610
+ )
611
+ return keywords.get("pipelining", ast.Constant(value=None))
612
+ return keywords.get("prefetch_stages", ast.Constant(value=None))
613
+
614
    def create_loop_function(
        self,
        func_name,
        node,
        start,
        stop,
        step,
        unroll,
        unroll_full,
        prefetch_stages,
        write_args,
        full_write_args_count,
    ):
        """
        Creates a loop body function with the `loop_selector` decorator.

        The original ``for`` body becomes a function whose first parameter is
        the induction variable and whose remaining parameters are the symbols
        written inside the loop; the function returns those symbols so the
        selector can thread them through as loop-carried values.

        :param node: the original ``ast.For`` node
        :param start, stop, step: already-visited bound expressions
        :param unroll, unroll_full, prefetch_stages: keyword value nodes
        :param write_args: names written in (and live after) the loop body
        :param full_write_args_count: how many of *write_args* are true
            writes (the rest are invoked/mutated objects)
        :return: the decorated ``ast.FunctionDef``
        """

        func_args = [ast.arg(arg=node.target.id, annotation=None)]
        func_args += [ast.arg(arg=var, annotation=None) for var in write_args]

        # Create the loop body
        transformed_body = []
        for stmt in node.body:
            transformed_stmt = self.visit(stmt)  # Recursively visit inner statements
            if isinstance(transformed_stmt, list):
                transformed_body.extend(transformed_stmt)
            else:
                transformed_body.append(transformed_stmt)

        # Handle the return for a single iterated argument correctly
        if len(write_args) == 0:
            transformed_body.append(ast.Return())
        else:
            transformed_body.append(
                ast.Return(
                    value=ast.List(
                        elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
                        ctx=ast.Load(),
                    )
                )
            )

        # Define the decorator with parameters
        decorator = ast.copy_location(
            ast.Call(
                func=self._create_module_attribute(
                    self.DECORATOR_FOR_STATEMENT,
                    lineno=node.lineno,
                    col_offset=node.col_offset,
                ),
                args=[start, stop, step],
                keywords=[
                    ast.keyword(arg="unroll", value=unroll),
                    ast.keyword(arg="unroll_full", value=unroll_full),
                    ast.keyword(arg="prefetch_stages", value=prefetch_stages),
                    ast.keyword(
                        arg="write_args",
                        value=self.generate_get_locals_or_none_call(write_args),
                    ),
                    ast.keyword(
                        arg="full_write_args_count",
                        value=ast.Constant(value=full_write_args_count),
                    ),
                    ast.keyword(
                        arg="write_args_names",
                        value=ast.List(
                            elts=[ast.Constant(value=arg) for arg in write_args],
                            ctx=ast.Load(),
                        ),
                    ),
                ],
            ),
            node,
        )

        return ast.copy_location(
            ast.FunctionDef(
                name=func_name,
                args=ast.arguments(
                    posonlyargs=[],
                    args=func_args,
                    kwonlyargs=[],
                    kw_defaults=[],
                    defaults=[],
                ),
                body=transformed_body,
                decorator_list=[decorator],
            ),
            node,
        )
704
+
705
    def visit_BoolOp(self, node):
        """
        Rewrite ``and``/``or`` chains into calls to the client module's
        ``and_``/``or_`` helpers, wrapped in ``IfExp`` nodes that reproduce
        Python's short-circuit behavior for plain-bool operands.
        """
        # Visit child nodes first
        self.generic_visit(node)

        # It is necessary to expand short circuit evaluation explicit here
        # Although we do not support inline if-else for IR generation, this is actually evaluated in Python
        # So it's fine here
        # Transform "and" to "and_"
        if isinstance(node.op, ast.And):
            # Create an if-else statement in AST form
            # if type(lhs) == bool and lhs == False:
            #     return lhs
            # else
            #     return and_(lhs, rhs)
            short_circuit_value = ast.Constant(value=False)
            helper_func = self._create_module_attribute(
                "and_",
                top_module_name="cutlass",
                submodule_name=None,
                lineno=node.lineno,
                col_offset=node.col_offset,
            )
            self.import_top_module = True
        # Transform "or" to "or_"
        elif isinstance(node.op, ast.Or):
            # Create an if-else statement in AST form
            # if type(lhs) == bool and lhs == True:
            #     return lhs
            # else
            #     return or_(lhs, rhs)
            short_circuit_value = ast.Constant(value=True)
            helper_func = self._create_module_attribute(
                "or_",
                top_module_name="cutlass",
                submodule_name=None,
                lineno=node.lineno,
                col_offset=node.col_offset,
            )
            self.import_top_module = True
        else:
            # BoolOp should be either And or Or
            raise DSLAstPreprocessorError(
                f"Unsupported boolean operation: {node.op}",
                filename=self.file_name,
                snippet=ast.unparse(node),
            )

        def short_circuit_eval(value, short_circuit_value):
            # Builds `type(value) == bool and value == short_circuit_value`:
            # only a genuine Python bool may short-circuit; DSL values must
            # always reach the helper call.
            return ast.BoolOp(
                op=ast.And(),
                values=[
                    ast.Compare(
                        left=ast.Call(
                            func=ast.Name(id="type", ctx=ast.Load()),
                            args=[value],
                            keywords=[],
                        ),
                        ops=[ast.Eq()],
                        comparators=[ast.Name(id="bool", ctx=ast.Load())],
                    ),
                    ast.Compare(
                        left=value,
                        ops=[ast.Eq()],
                        comparators=[short_circuit_value],
                    ),
                ],
            )

        lhs = node.values[0]

        # Left-fold the operand chain: each step wraps the accumulated lhs in
        # an IfExp that short-circuits or calls the helper with the next rhs.
        for i in range(1, len(node.values)):
            test = short_circuit_eval(lhs, short_circuit_value)
            lhs = ast.IfExp(
                test=test,
                body=lhs,
                orelse=ast.Call(
                    func=helper_func,
                    args=[lhs, node.values[i]],
                    keywords=[],
                ),
            )

        return ast.copy_location(lhs, node)
788
+
789
+ def visit_UnaryOp(self, node):
790
+ # Visit child nodes first
791
+ self.generic_visit(node)
792
+
793
+ # Transform "not" to "~" as we overload __invert__
794
+ if isinstance(node.op, ast.Not):
795
+ func_name = self._create_module_attribute(
796
+ "not_",
797
+ top_module_name="cutlass",
798
+ submodule_name=None,
799
+ lineno=node.lineno,
800
+ col_offset=node.col_offset,
801
+ )
802
+ self.import_top_module = True
803
+ return ast.copy_location(
804
+ ast.Call(func=func_name, args=[node.operand], keywords=[]), node
805
+ )
806
+
807
+ return node
808
+
809
    def _insert_range_value_check(self, node):
        """
        Insert a check for range arguments.

        Rewrites ``node.iter`` in place into
        ``range(*_dsl_.ast_helpers.range_value_check(<original args>))`` so
        the arguments are validated at runtime before the Python loop runs.
        """
        range_inputs = node.iter.args
        check_call = ast.copy_location(
            ast.Call(
                func=self._create_module_attribute(
                    "range_value_check", lineno=node.lineno, col_offset=node.col_offset
                ),
                args=range_inputs,
                keywords=[],
            ),
            node.iter,
        )
        # Splat the checker's result back into a builtin range() call.
        node.iter = ast.copy_location(
            ast.Call(
                func=ast.Name(id="range", ctx=ast.Load()),
                args=[ast.Starred(value=check_call, ctx=ast.Load())],
                keywords=[],
            ),
            node.iter,
        )
832
+
833
    def _insert_cf_symbol_check(self, func):
        """
        Insert a check for range symbol.

        Builds an expression statement
        ``_dsl_.ast_helpers.cf_symbol_check(<copy of func>)`` so the runtime
        can verify the control-flow symbol.  *func* is deep-copied because
        the original node stays in (and is later mutated within) the tree.
        """
        check_call = ast.copy_location(
            ast.Call(
                func=self._create_module_attribute(
                    "cf_symbol_check", lineno=func.lineno, col_offset=func.col_offset
                ),
                args=[deepcopy(func)],
                keywords=[],
            ),
            func,
        )
        return ast.Expr(check_call)
848
+
849
+ def visit_For(self, node):
850
+ # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop.
851
+ range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
852
+ if range_kind == "range_constexpr" or range_kind == None:
853
+ self.generic_visit(node)
854
+ if range_kind == "range_constexpr":
855
+ check_call = self._insert_cf_symbol_check(node.iter.func)
856
+ # Rewrite range_constexpr to range
857
+ node.iter.func = ast.Name(id="range", ctx=ast.Load())
858
+ self._insert_range_value_check(node)
859
+ return [check_call, node]
860
+ return node
861
+
862
+ active_symbols = self.scope_manager.get_active_symbols()
863
+
864
+ with self.scope_manager:
865
+ if isinstance(node.target, ast.Name):
866
+ self.scope_manager.add_to_scope(node.target.id)
867
+
868
+ if range_kind == "range_dynamic":
869
+ # Generate a warning
870
+ self.issue_deprecation_warning(
871
+ message="range_dynamic is deprecated and will be removed in the future, please remove it.",
872
+ category=DeprecationWarning,
873
+ filename=self.file_name,
874
+ lineno=node.iter.lineno,
875
+ )
876
+
877
+ warning_call = None
878
+ if range_kind == "range" and is_builtin_range and not has_keyword:
879
+ # Warn about possible performance regression due to behavior change
880
+ warning_call = ast.Expr(
881
+ ast.Call(
882
+ func=self._create_module_attribute(
883
+ "range_perf_warning",
884
+ lineno=node.lineno,
885
+ col_offset=node.col_offset,
886
+ ),
887
+ args=[
888
+ ast.Constant(value=self.file_name),
889
+ ast.Constant(value=node.iter.lineno),
890
+ ]
891
+ + node.iter.args,
892
+ keywords=[],
893
+ )
894
+ )
895
+ ast.copy_location(warning_call, node.iter)
896
+
897
+ is_prefixed_range = range_kind == "range" and not is_builtin_range
898
+ check_call = None
899
+ if range_kind == "range_dynamic" or is_prefixed_range:
900
+ # Insert a check for range symbol
901
+ if not is_prefixed_range:
902
+ check_call = self._insert_cf_symbol_check(node.iter.func)
903
+ else:
904
+ # Get toplevel module
905
+ check_call = self._insert_cf_symbol_check(node.iter.func.value)
906
+
907
+ new_for_node = self.transform_for_loop(node, active_symbols)
908
+ if check_call is not None:
909
+ new_for_node = [check_call] + new_for_node
910
+
911
+ return new_for_node if warning_call is None else [warning_call] + new_for_node
912
+
913
+ @staticmethod
914
+ def _hoist_expr_to_assignments(expr, name):
915
+ return ast.copy_location(
916
+ ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr
917
+ )
918
+
919
+ def _build_select_and_assign(self, *, name, test, body, orelse, location):
920
+ node = ast.copy_location(
921
+ ast.Assign(
922
+ targets=[ast.Name(id=name, ctx=ast.Store())],
923
+ value=ast.IfExp(
924
+ test=test,
925
+ body=body,
926
+ orelse=orelse,
927
+ ),
928
+ ),
929
+ location,
930
+ )
931
+ self.generic_visit(node)
932
+ return node
933
+
934
+ def _handle_negative_step(self, node, start_expr, stop_expr, step_expr):
935
+ # hoist start, stop, step to assignments
936
+ start_ori_name = f"start_ori_{self.counter}"
937
+ start = self._hoist_expr_to_assignments(start_expr, start_ori_name)
938
+ stop_ori_name = f"stop_ori_{self.counter}"
939
+ stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name)
940
+ step_ori_name = f"step_ori_{self.counter}"
941
+ step = self._hoist_expr_to_assignments(step_expr, step_ori_name)
942
+
943
+ extra_exprs = [start, stop, step]
944
+
945
+ # Handle possible negative step, generates the following code in Python:
946
+ # isNegative = step < 0
947
+ isNegative_name = f"isNegative_{self.counter}"
948
+ isNegative = ast.copy_location(
949
+ ast.Assign(
950
+ targets=[ast.Name(id=isNegative_name, ctx=ast.Store())],
951
+ value=ast.Compare(
952
+ left=ast.Name(id=step_ori_name, ctx=ast.Load()),
953
+ ops=[ast.Lt()],
954
+ comparators=[ast.Constant(value=0)],
955
+ ),
956
+ ),
957
+ step,
958
+ )
959
+
960
+ # start = stop if isNegative else start
961
+ start_name = f"start_{self.counter}"
962
+ start = self._build_select_and_assign(
963
+ name=start_name,
964
+ test=ast.Name(id=isNegative_name, ctx=ast.Load()),
965
+ body=ast.Name(id=stop_ori_name, ctx=ast.Load()),
966
+ orelse=ast.Name(id=start_ori_name, ctx=ast.Load()),
967
+ location=start,
968
+ )
969
+
970
+ # stop = start if isNegative else stop
971
+ stop_name = f"stop_{self.counter}"
972
+ stop = self._build_select_and_assign(
973
+ name=stop_name,
974
+ test=ast.Name(id=isNegative_name, ctx=ast.Load()),
975
+ body=ast.Name(id=start_ori_name, ctx=ast.Load()),
976
+ orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()),
977
+ location=stop,
978
+ )
979
+
980
+ # step = -step if isNegative else step
981
+ step_name = f"step_{self.counter}"
982
+ step = self._build_select_and_assign(
983
+ name=step_name,
984
+ test=ast.Name(id=isNegative_name, ctx=ast.Load()),
985
+ body=ast.UnaryOp(
986
+ op=ast.USub(), operand=ast.Name(id=step_ori_name, ctx=ast.Load())
987
+ ),
988
+ orelse=ast.Name(id=step_ori_name, ctx=ast.Load()),
989
+ location=step,
990
+ )
991
+
992
+ # offset = start + stop if isNegative else 0
993
+ offset_name = f"offset_{self.counter}"
994
+ offset = self._build_select_and_assign(
995
+ name=offset_name,
996
+ test=ast.Name(id=isNegative_name, ctx=ast.Load()),
997
+ body=ast.BinOp(
998
+ op=ast.Add(),
999
+ left=ast.Name(id=start_name, ctx=ast.Load()),
1000
+ right=ast.Name(id=stop_name, ctx=ast.Load()),
1001
+ ),
1002
+ orelse=ast.Constant(value=0),
1003
+ location=node,
1004
+ )
1005
+
1006
+ extra_exprs.append(isNegative)
1007
+ extra_exprs.append(start)
1008
+ extra_exprs.append(stop)
1009
+ extra_exprs.append(step)
1010
+ extra_exprs.append(offset)
1011
+
1012
+ # Add this to begining of loop body
1013
+ # for i in range(start, stop, step):
1014
+ # i = offset - i if isNegative else i
1015
+ assert isinstance(node.target, ast.Name)
1016
+
1017
+ target_name = node.target.id
1018
+ target = self._build_select_and_assign(
1019
+ name=target_name,
1020
+ test=ast.Name(id=isNegative_name, ctx=ast.Load()),
1021
+ body=ast.BinOp(
1022
+ op=ast.Sub(),
1023
+ left=ast.Name(id=offset_name, ctx=ast.Load()),
1024
+ right=ast.Name(id=target_name, ctx=ast.Load()),
1025
+ ),
1026
+ orelse=ast.Name(id=target_name, ctx=ast.Load()),
1027
+ location=node.target,
1028
+ )
1029
+
1030
+ node.body.insert(0, target)
1031
+
1032
+ return (
1033
+ ast.Name(id=start_name, ctx=ast.Load()),
1034
+ ast.Name(id=stop_name, ctx=ast.Load()),
1035
+ ast.Name(id=step_name, ctx=ast.Load()),
1036
+ extra_exprs,
1037
+ )
1038
+
1039
+ def transform_for_loop(self, node, active_symbols):
1040
+ # Check for early exit and raise exception
1041
+ self.check_early_exit(node, "for")
1042
+ if node.orelse:
1043
+ raise DSLAstPreprocessorError(
1044
+ "dynamic for loop with else is not supported",
1045
+ filename=self.file_name,
1046
+ snippet=ast.unparse(node),
1047
+ )
1048
+
1049
+ # Get loop target variable name
1050
+ target_var_name = None
1051
+ target_var_is_active_before_loop = False
1052
+ if isinstance(node.target, ast.Name):
1053
+ target_var_name = node.target.id
1054
+ for active_symbol in active_symbols:
1055
+ if target_var_name in active_symbol:
1056
+ target_var_is_active_before_loop = True
1057
+ active_symbols.remove(active_symbol)
1058
+ break
1059
+
1060
+ # Add necessary exprs to handle this
1061
+ if target_var_is_active_before_loop:
1062
+ # Initialize an extra loop carried variable
1063
+ loop_carried_var_name = f"loop_carried_var_{self.counter}"
1064
+ pre_loop_expr = ast.copy_location(
1065
+ ast.Assign(
1066
+ targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
1067
+ value=ast.Name(id=target_var_name, ctx=ast.Load()),
1068
+ ),
1069
+ node,
1070
+ )
1071
+ # append an extra assignment to the loop carried variable
1072
+ node.body.append(
1073
+ ast.copy_location(
1074
+ ast.Assign(
1075
+ targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
1076
+ value=ast.Name(id=target_var_name, ctx=ast.Load()),
1077
+ ),
1078
+ node,
1079
+ )
1080
+ )
1081
+ active_symbols.append({loop_carried_var_name})
1082
+
1083
+ start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter)
1084
+ unroll, unroll_full = self.extract_unroll_args(node.iter)
1085
+ prefetch_stages = self.extract_prefetch_stages_args(node.iter)
1086
+ write_args, full_write_args_count = self.analyze_region_variables(
1087
+ node, active_symbols
1088
+ )
1089
+
1090
+ if has_step and self.client_module_name[0] == "cutlass":
1091
+ start, stop, step, exprs = self._handle_negative_step(
1092
+ node, start_expr, stop_expr, step_expr
1093
+ )
1094
+ else:
1095
+ start, stop, step, exprs = start_expr, stop_expr, step_expr, []
1096
+
1097
+ if target_var_is_active_before_loop:
1098
+ exprs.append(pre_loop_expr)
1099
+
1100
+ func_name = f"loop_body_{self.counter}"
1101
+ self.counter += 1
1102
+
1103
+ func_def = self.create_loop_function(
1104
+ func_name,
1105
+ node,
1106
+ start,
1107
+ stop,
1108
+ step,
1109
+ unroll,
1110
+ unroll_full,
1111
+ prefetch_stages,
1112
+ write_args,
1113
+ full_write_args_count,
1114
+ )
1115
+
1116
+ assign = self.create_cf_call(func_name, write_args, node)
1117
+
1118
+ # This should work fine as it modifies the AST structure
1119
+ exprs = exprs + [func_def] + assign
1120
+
1121
+ if target_var_is_active_before_loop:
1122
+ # Create a new assignment to the target variable
1123
+ exprs.append(
1124
+ ast.copy_location(
1125
+ ast.Assign(
1126
+ targets=[ast.Name(id=target_var_name, ctx=ast.Store())],
1127
+ value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()),
1128
+ ),
1129
+ node,
1130
+ )
1131
+ )
1132
+
1133
+ return exprs
1134
+
1135
+ def visit_Assert(self, node):
1136
+ test = self.visit(node.test)
1137
+
1138
+ args = [ast.keyword(arg="test", value=test)]
1139
+ if node.msg:
1140
+ msg = self.visit(node.msg)
1141
+ args.append(ast.keyword(arg="msg", value=msg))
1142
+
1143
+ # Rewrite to assert_executor(test, msg)
1144
+ new_node = ast.Expr(
1145
+ ast.Call(
1146
+ func=self._create_module_attribute(
1147
+ self.ASSERT_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
1148
+ ),
1149
+ args=[],
1150
+ keywords=args,
1151
+ )
1152
+ )
1153
+
1154
+ # Propagate line number from original node to new node
1155
+ ast.copy_location(new_node, node)
1156
+ return new_node
1157
+
1158
+ def visit_Call(self, node):
1159
+ func = node.func
1160
+ # Visit args and kwargs
1161
+ node.args = [self.visit(arg) for arg in node.args]
1162
+ node.keywords = [self.visit(kwarg) for kwarg in node.keywords]
1163
+
1164
+ # Rewrite call to some built-in functions
1165
+ if isinstance(func, ast.Name):
1166
+ # Check if the function is 'bool'
1167
+ if func.id == "bool":
1168
+ return ast.copy_location(
1169
+ ast.Call(
1170
+ func=self._create_module_attribute(
1171
+ self.BOOL_CAST,
1172
+ lineno=node.lineno,
1173
+ col_offset=node.col_offset,
1174
+ ),
1175
+ args=[node.args[0]],
1176
+ keywords=[],
1177
+ ),
1178
+ node,
1179
+ )
1180
+ elif func.id in ["any", "all"]:
1181
+ helper_func = (
1182
+ self.ANY_EXECUTOR if func.id == "any" else self.ALL_EXECUTOR
1183
+ )
1184
+ return ast.copy_location(
1185
+ ast.Call(
1186
+ func=self._create_module_attribute(
1187
+ helper_func, lineno=node.lineno, col_offset=node.col_offset
1188
+ ),
1189
+ args=[node.args[0]],
1190
+ keywords=[],
1191
+ ),
1192
+ node,
1193
+ )
1194
+ elif func.id in ["min", "max"]:
1195
+ return ast.copy_location(
1196
+ ast.Call(
1197
+ func=self._create_module_attribute(
1198
+ func.id,
1199
+ top_module_name="cutlass",
1200
+ submodule_name=None,
1201
+ lineno=node.lineno,
1202
+ col_offset=node.col_offset,
1203
+ ),
1204
+ args=[node.args[0], node.args[1]],
1205
+ keywords=[],
1206
+ ),
1207
+ node,
1208
+ )
1209
+ elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
1210
+ def create_downcast_call(arg):
1211
+ return ast.copy_location(
1212
+ ast.Call(
1213
+ func=self._create_module_attribute(
1214
+ self.IMPLICIT_DOWNCAST_NUMERIC_TYPE,
1215
+ submodule_name="typing",
1216
+ lineno=node.lineno,
1217
+ col_offset=node.col_offset,
1218
+ ),
1219
+ args=[arg],
1220
+ keywords=[],
1221
+ ),
1222
+ arg,
1223
+ )
1224
+ module = self.function_globals.get(func.value.id)
1225
+ if isinstance(module, ModuleType) and module.__package__.endswith(
1226
+ "._mlir.dialects"
1227
+ ):
1228
+ # Check if argument is Numeric, if so, call ir_value()
1229
+ args = []
1230
+ for arg in node.args:
1231
+ args.append(create_downcast_call(arg))
1232
+ kwargs = []
1233
+ for kwarg in node.keywords:
1234
+ kwargs.append(
1235
+ ast.copy_location(
1236
+ ast.keyword(
1237
+ arg=kwarg.arg,
1238
+ value=create_downcast_call(kwarg.value),
1239
+ ),
1240
+ kwarg,
1241
+ )
1242
+ )
1243
+ return ast.copy_location(
1244
+ ast.Call(func=func, args=args, keywords=kwargs), node
1245
+ )
1246
+ else:
1247
+ node.func = self.visit(node.func)
1248
+
1249
+ return node
1250
+
1251
+ def visit_ClassDef(self, node):
1252
+ self.class_name = node.name
1253
+ self.generic_visit(node)
1254
+ self.class_name = None
1255
+ return node
1256
+
1257
+ def _visit_target(self, target):
1258
+ if isinstance(target, ast.Name):
1259
+ self.scope_manager.add_to_scope(target.id)
1260
+ elif isinstance(target, ast.Tuple):
1261
+ for t in target.elts:
1262
+ if isinstance(t, ast.Name):
1263
+ self.scope_manager.add_to_scope(t.id)
1264
+
1265
+ def visit_Assign(self, node):
1266
+ for target in node.targets:
1267
+ self._visit_target(target)
1268
+ self.generic_visit(node)
1269
+ return node
1270
+
1271
+ def visit_AugAssign(self, node):
1272
+ self._visit_target(node.target)
1273
+ self.generic_visit(node)
1274
+ return node
1275
+
1276
+ def visit_Name(self, node):
1277
+ isLoad = isinstance(node.ctx, ast.Load)
1278
+ if node.id in ["max", "min", "any", "all"] and isLoad:
1279
+ return ast.copy_location(
1280
+ ast.Call(
1281
+ func=self._create_module_attribute(
1282
+ "redirect_builtin_function",
1283
+ lineno=node.lineno,
1284
+ col_offset=node.col_offset,
1285
+ ),
1286
+ args=[node],
1287
+ keywords=[],
1288
+ ),
1289
+ node,
1290
+ )
1291
+ elif node.id == "_" and isLoad:
1292
+ raise DSLAstPreprocessorError("Read '_' is not allowed")
1293
+ else:
1294
+ self.generic_visit(node)
1295
+ return node
1296
+
1297
+ def check_decorator(self, node: ast.AST) -> bool:
1298
+ """
1299
+ Check if the function has the correct decorator for preprocessing.
1300
+ """
1301
+ if not isinstance(node, ast.FunctionDef):
1302
+ return False
1303
+ decorator_list = node.decorator_list
1304
+ if len(decorator_list) == 0:
1305
+ return False
1306
+
1307
+ for d in decorator_list:
1308
+ if isinstance(d, ast.Call):
1309
+ if isinstance(d.func, ast.Attribute):
1310
+ if d.func.attr in ["jit", "kernel"]:
1311
+ if d.keywords == []:
1312
+ return True
1313
+ for keyword in d.keywords:
1314
+ if keyword.arg == "preprocess":
1315
+ try:
1316
+ if isinstance(keyword.value, ast.Constant):
1317
+ return keyword.value.value
1318
+ else:
1319
+ return ast.literal_eval(keyword.value)
1320
+ except:
1321
+ pass
1322
+
1323
+ elif isinstance(d, ast.Attribute):
1324
+ if d.attr in ["jit", "kernel"]:
1325
+ return True
1326
+
1327
+ return False
1328
+
1329
+ def remove_dsl_decorator(self, decorator_list):
1330
+ """
1331
+ Remove .jit and .kernel decorators
1332
+ The decorator can be in two forms:
1333
+ - @jit(...)
1334
+ - @jit
1335
+ """
1336
+ new_decorator_list = []
1337
+ decorator_names = ["jit", "kernel"]
1338
+ for d in decorator_list:
1339
+ is_jit_or_kernel = False
1340
+ if isinstance(d, ast.Call):
1341
+ if isinstance(d.func, ast.Attribute):
1342
+ if d.func.attr in decorator_names:
1343
+ is_jit_or_kernel = True
1344
+ elif isinstance(d, ast.Attribute):
1345
+ if d.attr in decorator_names:
1346
+ is_jit_or_kernel = True
1347
+
1348
+ if not is_jit_or_kernel:
1349
+ new_decorator_list.append(d)
1350
+ return new_decorator_list
1351
+
1352
+ def visit_FunctionDef(self, node):
1353
+ with self.scope_manager:
1354
+ self.function_counter += 1
1355
+ self.function_name = node.name
1356
+ if self.function_depth > 0:
1357
+ self.local_closures.add(node.name)
1358
+
1359
+ self.function_depth += 1
1360
+
1361
+ # Add function name and arguments
1362
+ self.scope_manager.add_to_scope(node.name)
1363
+ for arg in node.args.args:
1364
+ self.scope_manager.add_to_scope(arg.arg)
1365
+
1366
+ self.generic_visit(node)
1367
+
1368
+ self.function_depth -= 1
1369
+
1370
+ # Remove .jit and .kernel decorators
1371
+ node.decorator_list = self.remove_dsl_decorator(node.decorator_list)
1372
+ return node
1373
+
1374
+ def visit_With(self, node):
1375
+ with self.scope_manager:
1376
+ for item in node.items:
1377
+ if isinstance(item.optional_vars, ast.Name):
1378
+ self.scope_manager.add_to_scope(item.optional_vars.id)
1379
+ self.generic_visit(node)
1380
+
1381
+ return node
1382
+
1383
+ def visit_While(self, node):
1384
+ # Constexpr doesn't get preprocessed
1385
+ if self.is_node_constexpr(node):
1386
+ self.generic_visit(node)
1387
+ check = self._insert_cf_symbol_check(node.test.func)
1388
+ return [check, node]
1389
+
1390
+ active_symbols = self.scope_manager.get_active_symbols()
1391
+
1392
+ with self.scope_manager:
1393
+ # Check for early exit and raise exception
1394
+ self.check_early_exit(node, "while")
1395
+
1396
+ write_args, full_write_args_count = self.analyze_region_variables(
1397
+ node, active_symbols
1398
+ )
1399
+ func_name = f"while_region_{self.counter}"
1400
+ self.counter += 1
1401
+
1402
+ func_def = self.create_while_function(
1403
+ func_name, node, write_args, full_write_args_count
1404
+ )
1405
+ assign = self.create_cf_call(func_name, write_args, node)
1406
+
1407
+ return [func_def] + assign
1408
+
1409
+ def visit_Try(self, node):
1410
+ with self.scope_manager:
1411
+ self.generic_visit(node)
1412
+ return node
1413
+
1414
+ def visit_ExceptHandler(self, node):
1415
+ with self.scope_manager:
1416
+ if node.name: # Exception variable
1417
+ self.scope_manager.add_to_scope(node.name)
1418
+ self.generic_visit(node)
1419
+ return node
1420
+
1421
+ def create_cf_call(self, func_name, yield_args, node):
1422
+ """Creates the assignment statement for the if function call"""
1423
+ if not yield_args:
1424
+ return [
1425
+ ast.copy_location(
1426
+ ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node
1427
+ )
1428
+ ]
1429
+ has_self = False
1430
+ for i, arg in enumerate(yield_args):
1431
+ if arg == "self":
1432
+ has_self = True
1433
+ yield_args[i] = "yield_self"
1434
+ break
1435
+ if len(yield_args) == 1:
1436
+ assign = ast.Assign(
1437
+ targets=[ast.Name(id=yield_args[0], ctx=ast.Store())],
1438
+ value=ast.Name(id=func_name, ctx=ast.Load()),
1439
+ )
1440
+ else:
1441
+ assign = ast.Assign(
1442
+ targets=[
1443
+ ast.Tuple(
1444
+ elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args],
1445
+ ctx=ast.Store(),
1446
+ )
1447
+ ],
1448
+ value=ast.Name(id=func_name, ctx=ast.Load()),
1449
+ )
1450
+
1451
+ if has_self:
1452
+ fix_self = ast.Expr(
1453
+ value=ast.Call(
1454
+ func=self._create_module_attribute(
1455
+ "copy_members", lineno=node.lineno, col_offset=node.col_offset
1456
+ ),
1457
+ args=[
1458
+ ast.Name(id="self", ctx=ast.Load()),
1459
+ ast.Name(id="yield_self", ctx=ast.Load()),
1460
+ ],
1461
+ keywords=[],
1462
+ )
1463
+ )
1464
+ return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)]
1465
+ else:
1466
+ return [ast.copy_location(assign, node)]
1467
+
1468
+ def visit_IfExp(self, node):
1469
+ """
1470
+ Visits an inline if-else expression (ternary operator).
1471
+ This is the Python equivalent of `x if condition else y`.
1472
+ """
1473
+ self.generic_visit(node)
1474
+ # Emit
1475
+ # node if type(pred) == bool else select_(pred, body, orelse)
1476
+ # so if pred is a python bool, use python to short-circuit and avoid emit arith.select
1477
+ self.import_top_module = True
1478
+ return ast.copy_location(
1479
+ ast.IfExp(
1480
+ test=ast.Compare(
1481
+ left=ast.Call(
1482
+ func=ast.Name(id="type", ctx=ast.Load()),
1483
+ args=[node.test],
1484
+ keywords=[],
1485
+ ),
1486
+ ops=[ast.Eq()],
1487
+ comparators=[ast.Name(id="bool", ctx=ast.Load())],
1488
+ ),
1489
+ body=node, # Original ternary expression
1490
+ orelse=ast.Call(
1491
+ func=self._create_module_attribute(
1492
+ "select_", top_module_name="cutlass", submodule_name=None
1493
+ ),
1494
+ args=[
1495
+ node.test,
1496
+ node.body,
1497
+ node.orelse,
1498
+ ],
1499
+ keywords=[],
1500
+ ),
1501
+ ),
1502
+ node,
1503
+ )
1504
+
1505
+ cmpops = {
1506
+ "Eq": "==",
1507
+ "NotEq": "!=",
1508
+ "Lt": "<",
1509
+ "LtE": "<=",
1510
+ "Gt": ">",
1511
+ "GtE": ">=",
1512
+ "Is": "is",
1513
+ "IsNot": "is not",
1514
+ "In": "in",
1515
+ "NotIn": "not in",
1516
+ }
1517
+ def compare_ops_to_str(self, node):
1518
+ names = [
1519
+ ast.Constant(value=self.cmpops[op.__class__.__name__]) for op in node.ops
1520
+ ]
1521
+ return ast.List(elts=names, ctx=ast.Load())
1522
+
1523
+ def visit_Compare(self, node):
1524
+ self.generic_visit(node)
1525
+
1526
+ comparator_strs = self.compare_ops_to_str(node)
1527
+
1528
+ keywords = [
1529
+ ast.keyword(arg="left", value=node.left),
1530
+ ast.keyword(
1531
+ arg="comparators", value=ast.List(elts=node.comparators, ctx=ast.Load())
1532
+ ),
1533
+ ast.keyword(arg="ops", value=comparator_strs),
1534
+ ]
1535
+
1536
+ call = ast.copy_location(
1537
+ ast.Call(
1538
+ func=self._create_module_attribute(self.COMPARE_EXECUTOR),
1539
+ args=[],
1540
+ keywords=keywords,
1541
+ ),
1542
+ node,
1543
+ )
1544
+
1545
+ return call
1546
+
1547
+ def visit_If(self, node):
1548
+ # const_expr doesn't get preprocessed
1549
+ if self.is_node_constexpr(node):
1550
+ self.generic_visit(node)
1551
+ check = self._insert_cf_symbol_check(node.test.func)
1552
+ return [check, node]
1553
+
1554
+ active_symbols = self.scope_manager.get_active_symbols()
1555
+ with self.scope_manager:
1556
+ # Check for early exit and raise exception
1557
+ self.check_early_exit(node, "if")
1558
+
1559
+ yield_args, full_write_args_count = self.analyze_region_variables(
1560
+ node, active_symbols
1561
+ )
1562
+ func_name = f"if_region_{self.counter}"
1563
+ self.counter += 1
1564
+
1565
+ func_def = self.create_if_function(
1566
+ func_name, node, yield_args, full_write_args_count
1567
+ )
1568
+ assign = self.create_cf_call(func_name, yield_args, node)
1569
+
1570
+ return [func_def] + assign
1571
+
1572
+ def generate_get_locals_or_none_call(self, write_args):
1573
+ return ast.Call(
1574
+ func=self._create_module_attribute("get_locals_or_none"),
1575
+ args=[
1576
+ ast.Call(
1577
+ func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]
1578
+ ),
1579
+ ast.List(
1580
+ elts=[ast.Constant(value=arg) for arg in write_args],
1581
+ ctx=ast.Load(),
1582
+ ),
1583
+ ],
1584
+ keywords=[],
1585
+ )
1586
+
1587
+ def create_if_function(self, func_name, node, write_args, full_write_args_count):
1588
+ test_expr = self.visit(node.test)
1589
+ pred_name = self.make_func_param_name("pred", write_args)
1590
+ func_args = [ast.arg(arg=pred_name, annotation=None)]
1591
+ func_args += [ast.arg(arg=var, annotation=None) for var in write_args]
1592
+ func_args_then_else = [ast.arg(arg=var, annotation=None) for var in write_args]
1593
+
1594
+ then_body = []
1595
+ for stmt in node.body:
1596
+ transformed_stmt = self.visit(stmt) # Recursively visit inner statements
1597
+ if isinstance(transformed_stmt, list):
1598
+ then_body.extend(transformed_stmt)
1599
+ else:
1600
+ then_body.append(transformed_stmt)
1601
+
1602
+ # Create common return list for all blocks
1603
+ return_list = ast.List(
1604
+ elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
1605
+ ctx=ast.Load(),
1606
+ )
1607
+
1608
+ # Create common function arguments
1609
+ func_decorator_arguments = ast.arguments(
1610
+ posonlyargs=[], args=func_args, kwonlyargs=[], kw_defaults=[], defaults=[]
1611
+ )
1612
+ func_then_else_arguments = ast.arguments(
1613
+ posonlyargs=[],
1614
+ args=func_args_then_else,
1615
+ kwonlyargs=[],
1616
+ kw_defaults=[],
1617
+ defaults=[],
1618
+ )
1619
+
1620
+ then_block_name = f"then_block_{self.counter}"
1621
+ else_block_name = f"else_block_{self.counter}"
1622
+ elif_region_name = f"elif_region_{self.counter}"
1623
+ self.counter += 1
1624
+
1625
+ # Create then block
1626
+ then_block = ast.copy_location(
1627
+ ast.FunctionDef(
1628
+ name=then_block_name,
1629
+ args=func_then_else_arguments,
1630
+ body=then_body + [ast.Return(value=return_list)],
1631
+ decorator_list=[],
1632
+ ),
1633
+ node,
1634
+ )
1635
+
1636
+ # Decorator keywords
1637
+ decorator_keywords = [
1638
+ ast.keyword(
1639
+ arg="pred", value=test_expr
1640
+ ), # ast.Name(id="pred", ctx=ast.Load())
1641
+ ast.keyword(
1642
+ arg="write_args",
1643
+ value=self.generate_get_locals_or_none_call(write_args),
1644
+ ),
1645
+ ]
1646
+
1647
+ # Create decorator
1648
+ decorator = ast.copy_location(
1649
+ ast.Call(
1650
+ func=self._create_module_attribute(
1651
+ self.DECORATOR_IF_STATEMENT,
1652
+ lineno=node.lineno,
1653
+ col_offset=node.col_offset,
1654
+ ),
1655
+ args=[],
1656
+ keywords=decorator_keywords,
1657
+ ),
1658
+ node,
1659
+ )
1660
+
1661
+ # Executor keywords
1662
+ execute_keywords = [
1663
+ ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
1664
+ ast.keyword(
1665
+ arg="write_args",
1666
+ value=ast.List(
1667
+ elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
1668
+ ctx=ast.Load(),
1669
+ ),
1670
+ ),
1671
+ ast.keyword(
1672
+ arg="full_write_args_count",
1673
+ value=ast.Constant(value=full_write_args_count),
1674
+ ),
1675
+ ast.keyword(
1676
+ arg="write_args_names",
1677
+ value=ast.List(
1678
+ elts=[ast.Constant(value=arg) for arg in write_args],
1679
+ ctx=ast.Load(),
1680
+ ),
1681
+ ),
1682
+ ast.keyword(
1683
+ arg="then_block", value=ast.Name(id=then_block_name, ctx=ast.Load())
1684
+ ),
1685
+ ]
1686
+
1687
+ # Handle different cases
1688
+ if not write_args and node.orelse == []:
1689
+ # No write_args case - only then_block needed
1690
+ execute_call = ast.copy_location(
1691
+ ast.Call(
1692
+ func=self._create_module_attribute(
1693
+ self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
1694
+ ),
1695
+ args=[],
1696
+ keywords=execute_keywords,
1697
+ ),
1698
+ node,
1699
+ )
1700
+ func_body = [then_block, ast.Return(value=execute_call)]
1701
+ else:
1702
+ # Create else block based on node.orelse
1703
+ if node.orelse:
1704
+ if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
1705
+ # Handle elif case
1706
+ elif_node = node.orelse[0]
1707
+ nested_if_name = elif_region_name
1708
+ # Recursion for nested elif
1709
+ nested_if = self.create_if_function(
1710
+ nested_if_name, elif_node, write_args, full_write_args_count
1711
+ )
1712
+ else_block = ast.FunctionDef(
1713
+ name=else_block_name,
1714
+ args=func_then_else_arguments,
1715
+ body=[
1716
+ nested_if,
1717
+ ast.Return(
1718
+ value=ast.Name(id=nested_if_name, ctx=ast.Load())
1719
+ ),
1720
+ ],
1721
+ decorator_list=[],
1722
+ )
1723
+ else:
1724
+
1725
+ else_body = []
1726
+ for stmt in node.orelse:
1727
+ transformed_stmt = self.visit(
1728
+ stmt
1729
+ ) # Recursively visit inner statements
1730
+ if isinstance(transformed_stmt, list):
1731
+ else_body.extend(transformed_stmt)
1732
+ else:
1733
+ else_body.append(transformed_stmt)
1734
+
1735
+ # Regular else block
1736
+ else_block = ast.FunctionDef(
1737
+ name=else_block_name,
1738
+ args=func_then_else_arguments,
1739
+ body=else_body + [ast.Return(value=return_list)],
1740
+ decorator_list=[],
1741
+ )
1742
+ else:
1743
+ # Default else block
1744
+ else_block = ast.FunctionDef(
1745
+ name=else_block_name,
1746
+ args=func_then_else_arguments,
1747
+ body=[ast.Return(value=return_list)],
1748
+ decorator_list=[],
1749
+ )
1750
+
1751
+ # Add else_block to execute keywords
1752
+ execute_keywords.append(
1753
+ ast.keyword(
1754
+ arg="else_block", value=ast.Name(id=else_block_name, ctx=ast.Load())
1755
+ )
1756
+ )
1757
+
1758
+ execute_call = ast.copy_location(
1759
+ ast.Call(
1760
+ func=self._create_module_attribute(
1761
+ self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
1762
+ ),
1763
+ args=[],
1764
+ keywords=execute_keywords,
1765
+ ),
1766
+ node,
1767
+ )
1768
+ func_body = [
1769
+ then_block,
1770
+ ast.copy_location(else_block, node),
1771
+ ast.Return(value=execute_call),
1772
+ ]
1773
+
1774
+ return ast.copy_location(
1775
+ ast.FunctionDef(
1776
+ name=func_name,
1777
+ args=func_decorator_arguments,
1778
+ body=func_body,
1779
+ decorator_list=[decorator],
1780
+ ),
1781
+ node,
1782
+ )
1783
+
1784
+ def create_while_function(self, func_name, node, write_args, full_write_args_count):
1785
+ """Create a while function that looks like:
1786
+
1787
+ @while_selector(pred, write_args=[])
1788
+ def while_region(pred, write_args):
1789
+ def while_before_block(*write_args):
1790
+ # Note that during eval of pred can possibly alter yield_args
1791
+ return *pred, write_args
1792
+ def while_after_block(*write_args):
1793
+ ...loop_body_transformed...
1794
+ return write_args
1795
+ return self.while_executor(pred, write_args,
1796
+ while_before_block, while_after_block, constexpr)
1797
+ write_args = while_region(pred, write_args)
1798
+
1799
+ Which will later be executed as psuedo-code:
1800
+
1801
+ # Dynamic mode:
1802
+ scf.WhileOp(types(write_args), write_args)
1803
+ with InsertionPoint(before_block):
1804
+ cond, write_args = while_before_block(*write_args)
1805
+ scf.ConditionOp(cond, write_args)
1806
+ with InsertionPoint(after_block):
1807
+ write_args = while_after_block(write_args)
1808
+ scf.YieldOp(write_args)
1809
+ return while_op.results_
1810
+
1811
+ # Const mode:
1812
+ cond, write_args = while_before_block(write_args)
1813
+ while pred:
1814
+ write_args = body_block(write_args)
1815
+ cond, write_args = while_before_block(write_args)
1816
+ return write_args
1817
+ """
1818
+ test_expr = self.visit(node.test)
1819
+ pred_name = self.make_func_param_name("pred", write_args)
1820
+
1821
+ # Section: decorator construction
1822
+ decorator_keywords = [
1823
+ ast.keyword(arg="pred", value=test_expr),
1824
+ ast.keyword(
1825
+ arg="write_args",
1826
+ value=self.generate_get_locals_or_none_call(write_args),
1827
+ ),
1828
+ ]
1829
+ decorator = ast.copy_location(
1830
+ ast.Call(
1831
+ func=self._create_module_attribute(
1832
+ self.DECORATOR_WHILE_STATEMENT,
1833
+ lineno=node.lineno,
1834
+ col_offset=node.col_offset,
1835
+ ),
1836
+ args=[],
1837
+ keywords=decorator_keywords,
1838
+ ),
1839
+ node,
1840
+ )
1841
+
1842
+ # Section: Shared initialization for before and after blocks
1843
+ while_before_block_name = f"while_before_block_{self.counter}"
1844
+ while_after_block_name = f"while_after_block_{self.counter}"
1845
+ self.counter += 1
1846
+ block_args_args = [ast.arg(arg=var, annotation=None) for var in write_args]
1847
+ block_args = ast.arguments(
1848
+ posonlyargs=[],
1849
+ args=block_args_args,
1850
+ kwonlyargs=[],
1851
+ kw_defaults=[],
1852
+ defaults=[],
1853
+ )
1854
+
1855
+ yield_args_ast_name_list = ast.List(
1856
+ elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
1857
+ ctx=ast.Load(),
1858
+ )
1859
+
1860
+ # Section: while_before_block FunctionDef, which contains condition
1861
+ while_before_return_list = ast.List(
1862
+ elts=[test_expr, yield_args_ast_name_list],
1863
+ ctx=ast.Load(),
1864
+ )
1865
+ while_before_stmts = [ast.Return(value=while_before_return_list)]
1866
+ while_before_block = ast.copy_location(
1867
+ ast.FunctionDef(
1868
+ name=while_before_block_name,
1869
+ args=block_args,
1870
+ body=while_before_stmts,
1871
+ decorator_list=[],
1872
+ ),
1873
+ test_expr,
1874
+ )
1875
+
1876
+ # Section: while_after_block FunctionDef, which contains loop body
1877
+ while_after_stmts = []
1878
+ for stmt in node.body:
1879
+ transformed_stmt = self.visit(stmt) # Recursively visit inner statements
1880
+ if isinstance(transformed_stmt, list):
1881
+ while_after_stmts.extend(transformed_stmt)
1882
+ else:
1883
+ while_after_stmts.append(transformed_stmt)
1884
+ while_after_stmts.append(ast.Return(value=yield_args_ast_name_list))
1885
+
1886
+ while_after_block = ast.copy_location(
1887
+ ast.FunctionDef(
1888
+ name=while_after_block_name,
1889
+ args=block_args,
1890
+ body=while_after_stmts,
1891
+ decorator_list=[],
1892
+ ),
1893
+ node,
1894
+ )
1895
+
1896
+ # Section: Execute via executor
1897
+ execute_keywords = [
1898
+ ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
1899
+ ast.keyword(
1900
+ arg="write_args",
1901
+ value=ast.List(
1902
+ elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
1903
+ ctx=ast.Load(),
1904
+ ),
1905
+ ),
1906
+ ast.keyword(
1907
+ arg="full_write_args_count",
1908
+ value=ast.Constant(value=full_write_args_count),
1909
+ ),
1910
+ ast.keyword(
1911
+ arg="while_before_block",
1912
+ value=ast.Name(id=while_before_block_name, ctx=ast.Load()),
1913
+ ),
1914
+ ast.keyword(
1915
+ arg="while_after_block",
1916
+ value=ast.Name(id=while_after_block_name, ctx=ast.Load()),
1917
+ ),
1918
+ ast.keyword(
1919
+ arg="write_args_names",
1920
+ value=ast.List(
1921
+ elts=[ast.Constant(value=arg) for arg in write_args],
1922
+ ctx=ast.Load(),
1923
+ ),
1924
+ ),
1925
+ ]
1926
+
1927
+ execute_call = ast.Call(
1928
+ func=self._create_module_attribute(
1929
+ self.WHILE_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
1930
+ ),
1931
+ args=[],
1932
+ keywords=execute_keywords,
1933
+ )
1934
+
1935
+ # Putting everything together, FunctionDef for while_region
1936
+ func_args_args = [ast.arg(arg=pred_name, annotation=None)]
1937
+ func_args_args += [ast.arg(arg=var, annotation=None) for var in write_args]
1938
+ func_args = ast.arguments(
1939
+ posonlyargs=[],
1940
+ args=func_args_args,
1941
+ kwonlyargs=[],
1942
+ kw_defaults=[],
1943
+ defaults=[],
1944
+ )
1945
+
1946
+ return ast.copy_location(
1947
+ ast.FunctionDef(
1948
+ name=func_name,
1949
+ args=func_args,
1950
+ body=[
1951
+ while_before_block,
1952
+ while_after_block,
1953
+ ast.Return(value=execute_call),
1954
+ ],
1955
+ decorator_list=[decorator],
1956
+ ),
1957
+ node,
1958
+ )
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides jit cache load/dump helper functions
14
+ """
15
+
16
+ import os
17
+ import uuid
18
+ import random
19
+ import tempfile
20
+ import pwd
21
+ import time
22
+ from pathlib import Path
23
+ import hashlib
24
+
25
+ from .utils.logger import log
26
+ from .jit_executor import JitExecutor
27
+
28
+ from .._mlir import ir
29
+
30
+ # =============================================================================
31
+ # Jit Cache Helper functions
32
+ # =============================================================================
33
+
34
+
35
+ def get_current_user():
36
+ # Try to get the user from the environment variable first
37
+ user = os.getenv("USER") or os.getenv("USERNAME")
38
+ if not user:
39
+ # Fallback for Unix-like systems
40
+ user = pwd.getpwuid(os.getuid()).pw_name
41
+ return user
42
+
43
+
44
+ try:
45
+ default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/"
46
+ except Exception as e:
47
+ # If all else fails, provide a default fallback path
48
+ default_generated_ir_path = "/tmp/cutlass_python_cache/"
49
+ print(f"Could not determine user, using default path. Error: {e}")
50
+
51
+
52
+ def load_ir(file, asBytecode=False):
53
+ """Load generated IR from a file."""
54
+ assert "mlir" in file
55
+ func_name = file.split(".mlir")[0].split("dsl_")[-1]
56
+ with ir.Context() as ctx:
57
+ with open(file, "rb" if asBytecode else "r") as f:
58
+ module = ir.Module.parse(f.read())
59
+
60
+ return func_name, module
61
+
62
+
63
+ def make_unique_filename(fpath: Path, new_ext: str = None) -> Path:
64
+ """Generate a unique filename with an optional new extension."""
65
+ random_part = random.randint(0, 999999)
66
+ timestamp = time.time()
67
+ hash_input = f"{fpath}_{timestamp}_{random_part}".encode()
68
+ hash_code = hashlib.md5(hash_input).hexdigest()[:16] # Shorter hash for readability
69
+ stem_with_hash = f"{fpath.stem}_{hash_code}"
70
+ return fpath.with_name(stem_with_hash).with_suffix(new_ext or fpath.suffix)
71
+
72
+
73
+ def save_ir(
74
+ dsl_name: str,
75
+ module: object,
76
+ fname: str,
77
+ isTemp: bool = False,
78
+ asBytecode: bool = False,
79
+ ) -> str:
80
+ """Save generated IR to a file."""
81
+ initial_name = f"{dsl_name.lower()}_{fname}.mlir"
82
+ save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd())
83
+ save_fname = save_path / initial_name
84
+ # Random ID to avoid any collisions
85
+ rnd_id = str(uuid.uuid4())
86
+ pid = os.getpid()
87
+ # use temp dir to be robust against program interruptions
88
+ temp_dir = os.path.join(save_path, f"tmp.pid_{pid}_{rnd_id}")
89
+ # If the process exits abnormally, may leave a temporary folder. Needs to be removed manually.
90
+ os.makedirs(temp_dir, exist_ok=False)
91
+ temp_fname = os.path.join(temp_dir, initial_name)
92
+
93
+ if asBytecode:
94
+ with open(temp_fname, "wb") as f:
95
+ module.operation.write_bytecode(f)
96
+ else:
97
+ with open(temp_fname, "w") as f:
98
+ print(module, file=f)
99
+ # os.replace is guaranteed to be atomic on POSIX systems if it succeeds
100
+ # so filepath cannot see a partial write
101
+ os.replace(temp_fname, save_fname)
102
+ os.removedirs(temp_dir)
103
+ log().debug("Generated IR saved into %s", save_fname)
104
+ return save_fname
105
+
106
+
107
+ def check_func_name(jit_cache, func_name):
108
+ if not func_name in jit_cache:
109
+ jit_cache[func_name] = JitExecutor(None, None, None, None, None, None)
110
+ return jit_cache
111
+
112
+
113
+ def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
114
+ """Load cache from a directory path."""
115
+ if not os.path.exists(path):
116
+ return dict()
117
+ files = os.listdir(path)
118
+ jit_cache = dict()
119
+ try:
120
+ for idx, file in enumerate(files):
121
+ if idx >= int(cache_limit):
122
+ break
123
+ # identify dsl prefix
124
+ if not file.startswith(f"{dsl_name.lower()}"):
125
+ continue
126
+ if ".mlir" in file:
127
+ func_name, ir_module = load_ir(
128
+ os.path.join(path, file), asBytecode=True
129
+ )
130
+ jit_cache = check_func_name(jit_cache, func_name)
131
+ jit_cache[func_name].ir_module = ir_module
132
+ except Exception as e:
133
+ print(f"{dsl_name} failed with loading generated IR cache.", e)
134
+ jit_cache = dict()
135
+ return jit_cache
136
+
137
+
138
+ def dump_cache_to_path(
139
+ dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
140
+ ):
141
+ log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
142
+ os.makedirs(path, exist_ok=True)
143
+ original_path = os.getcwd()
144
+ try:
145
+ os.chdir(path)
146
+ for idx, [key, value] in enumerate(jit_cache.items()):
147
+ if idx >= int(cache_limit):
148
+ break
149
+ save_ir(dsl_name, value.ir_module, key, asBytecode=True)
150
+ except Exception as e:
151
+ print(f"{dsl_name} failed with caching generated IR", e)
152
+ finally:
153
+ os.chdir(original_path)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import os
13
+ from typing import Any, Dict, Iterable, Optional, Union
14
+
15
+ """
16
+ This module provides a Exception classes DSL class for any Dialect.
17
+ """
18
+
19
+
20
+ # Add color codes at the top of the file after imports
21
+ class Colors:
22
+ """ANSI color codes for error messages"""
23
+
24
+ RED = "\033[91m"
25
+ YELLOW = "\033[93m"
26
+ BLUE = "\033[94m"
27
+ GREEN = "\033[92m"
28
+ BOLD = "\033[1m"
29
+ RESET = "\033[0m"
30
+
31
+
32
+ # =============================================================================
33
+ # DSL Exceptions
34
+ # =============================================================================
35
+
36
+
37
+ class DSLBaseError(Exception):
38
+ """
39
+ Base exception for DSL-related errors.
40
+ Provides optional contextual metadata to aid in debugging.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ message: str,
46
+ line: Optional[int] = None,
47
+ snippet: Optional[str] = None,
48
+ filename: Optional[str] = None,
49
+ error_code: Optional[Union[str, int]] = None,
50
+ context: Optional[Union[Dict[str, Any], str]] = None,
51
+ suggestion: Optional[str] = None,
52
+ cause: Optional[BaseException] = None,
53
+ ) -> None:
54
+ self.message = message
55
+ self.line = line
56
+ self.filename = filename
57
+ self.snippet = snippet
58
+ self.error_code = error_code
59
+ self.context = context
60
+ self.suggestion = suggestion
61
+ self.cause = cause
62
+
63
+ super().__init__(self._format_message())
64
+
65
+ def _format_message(self):
66
+ """
67
+ Formats the complete error message with available metadata.
68
+ Override this in subclasses if you want to change formatting logic.
69
+ """
70
+ parts = [f"{self.__class__.__name__}: {self.message}"]
71
+
72
+ if self.error_code is not None:
73
+ parts.append(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n")
74
+
75
+ if self.line is not None:
76
+ parts.append(f" Line: {self.line}")
77
+
78
+ if self.filename is not None:
79
+ parts.append(f" File: {self.filename}")
80
+
81
+ if self.snippet:
82
+ # Optionally truncate long snippets for readability
83
+ parts.append(f" Snippet: \n {self.snippet}")
84
+
85
+ if self.cause:
86
+ parts.append(f" Caused exception: {self.cause}")
87
+
88
+ if self.context:
89
+ if isinstance(self.context, dict):
90
+ parts.append(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n")
91
+ for key, value in self.context.items():
92
+ parts.append(f" {key}: {value}")
93
+ else:
94
+ parts.append(
95
+ f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {self.context}"
96
+ )
97
+
98
+ if self.suggestion:
99
+ parts.append(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}")
100
+ if isinstance(self.suggestion, (list, tuple)):
101
+ for suggestion in self.suggestion:
102
+ parts.append(f" {Colors.GREEN}{suggestion}{Colors.RESET}")
103
+ else:
104
+ parts.append(f" {self.suggestion}")
105
+
106
+ return "\n".join(parts)
107
+
108
+
109
+ class DSLRuntimeError(DSLBaseError):
110
+ """
111
+ Raised when an error occurs during JIT-time code generation in the DSL.
112
+ """
113
+
114
+ # Inherits all logic from DSLBaseError; override methods if you need
115
+ # specialized behavior or formatting for runtime errors.
116
+ pass
117
+
118
+
119
+ def _get_friendly_cuda_error_message(error_code, error_name):
120
+ # Avoid circular dependency
121
+ from .runtime.cuda import get_device_info
122
+
123
+ """Get a user-friendly error message for common CUDA errors."""
124
+ # Strip the byte string markers if present
125
+ if isinstance(error_name, bytes):
126
+ error_name = error_name.decode("utf-8")
127
+ elif (
128
+ isinstance(error_name, str)
129
+ and error_name.startswith("b'")
130
+ and error_name.endswith("'")
131
+ ):
132
+ error_name = error_name[2:-1]
133
+
134
+ # Add target architecture info
135
+ target_arch = os.getenv("CUTE_DSL_ARCH", "unknown")
136
+
137
+ error_messages = {
138
+ "CUDA_ERROR_INVALID_SOURCE": (
139
+ f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n"
140
+ ),
141
+ "CUDA_ERROR_NO_BINARY_FOR_GPU": (
142
+ f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n"
143
+ ),
144
+ "CUDA_ERROR_OUT_OF_MEMORY": (
145
+ f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n"
146
+ ),
147
+ "CUDA_ERROR_INVALID_DEVICE": (
148
+ f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n"
149
+ ),
150
+ "CUDA_ERROR_NOT_INITIALIZED": (
151
+ f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n"
152
+ ),
153
+ "CUDA_ERROR_INVALID_VALUE": (
154
+ f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n"
155
+ f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}"
156
+ ),
157
+ }
158
+
159
+ error_suggestions = {
160
+ "CUDA_ERROR_INVALID_SOURCE": (
161
+ f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture",
162
+ f"2. Clear the compilation cache and regenerate the kernel",
163
+ f"3. Check CUDA toolkit installation",
164
+ ),
165
+ "CUDA_ERROR_NO_BINARY_FOR_GPU": (
166
+ f"Set env CUTE_DSL_ARCH to match your GPU architecture",
167
+ ),
168
+ "CUDA_ERROR_OUT_OF_MEMORY": (
169
+ f"1. Reduce batch size",
170
+ f"2. Reduce model size",
171
+ f"3. Free unused GPU memory",
172
+ ),
173
+ "CUDA_ERROR_INVALID_DEVICE": (
174
+ f"1. Check if CUDA device is properly initialized",
175
+ f"2. Verify GPU is detected: nvidia-smi",
176
+ f"3. Check CUDA_VISIBLE_DEVICES environment variable",
177
+ ),
178
+ "CUDA_ERROR_NOT_INITIALIZED": (
179
+ f"1. Check CUDA driver installation",
180
+ f"2. call `cuda.cuInit(0)` before any other CUDA operation",
181
+ f"3. Run nvidia-smi to confirm GPU status",
182
+ ),
183
+ "CUDA_ERROR_INVALID_VALUE": (
184
+ f"1. Your GPU model",
185
+ f"2. SM ARCH setting",
186
+ f"3. Steps to reproduce",
187
+ ),
188
+ }
189
+
190
+ message = error_messages.get(
191
+ error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}"
192
+ )
193
+
194
+ # Add debug information
195
+ debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n"
196
+ debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n"
197
+ debug_info += (
198
+ f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n"
199
+ )
200
+
201
+ try:
202
+ # Get GPU information using CUDA Python API
203
+ debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n"
204
+ gpu_info = get_device_info()
205
+ debug_info += gpu_info.pretty_str()
206
+
207
+ if target_arch and gpu_info.compatible_archs:
208
+ debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n"
209
+
210
+ if target_arch not in gpu_info.compatible_archs:
211
+ debug_info += (
212
+ f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n"
213
+ f"💡 Please use one of SM ARCHs: "
214
+ f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n"
215
+ )
216
+ elif target_arch != gpu_info.sm_arch:
217
+ debug_info += (
218
+ f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n"
219
+ f"• Current: {target_arch}\n"
220
+ f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n"
221
+ )
222
+ else:
223
+ debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n"
224
+
225
+ except Exception as e:
226
+ debug_info += (
227
+ f"\n{Colors.YELLOW}ℹ️ Could not retrieve GPU info: {str(e)}{Colors.RESET}"
228
+ )
229
+
230
+ return message, debug_info, error_suggestions.get(error_name, "")
231
+
232
+
233
+ class DSLCudaRuntimeError(DSLBaseError):
234
+ """
235
+ Raised when an error occurs during CUDA runtime code generation in the DSL.
236
+ """
237
+
238
+ # Inherits all logic from DSLRuntimeError; override methods if you need
239
+ # specialized behavior or formatting for runtime errors.
240
+ def __init__(self, error_code, error_name) -> None:
241
+ self._error_code = error_code
242
+ self._error_name = error_name
243
+ message, debug_info, suggestion = _get_friendly_cuda_error_message(
244
+ error_code, error_name
245
+ )
246
+
247
+ super().__init__(
248
+ message, error_code=error_code, context=debug_info, suggestion=suggestion
249
+ )
250
+
251
+
252
+ class DSLAstPreprocessorError(DSLBaseError):
253
+ """
254
+ Raised when an error occurs during AST preprocessing or visiting in the DSL.
255
+ """
256
+
257
+ # Same approach: You could override _format_message if you want
258
+ # to emphasize AST node details or anything specific to preprocessing.
259
+ pass
260
+
261
+
262
+ class DSLNotImplemented(DSLBaseError):
263
+ """
264
+ Raised when a feature of the DSL is not implemented yet.
265
+ """
266
+
267
+ # Useful for stubs in your DSL that you plan to implement in the future.
268
+ pass
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides a class that compiles generated IR using MLIR's PassManager
14
+ and executes it using MLIR's ExecutionEngine.
15
+
16
+ """
17
+
18
+ from typing import Sequence, Optional, Tuple
19
+ import os
20
+ import sys
21
+ import inspect
22
+ import argparse
23
+ from .common import DSLRuntimeError
24
+ from .utils.logger import log
25
+
26
+ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
27
+ sys.path.append(_SCRIPT_PATH)
28
+
29
+ from .._mlir import ir
30
+
31
+
32
+ # =============================================================================
33
+ # Compiler Class
34
+ # =============================================================================
35
+
36
+
37
+ class CompilationError(RuntimeError):
38
+ """Custom error class for compilation failures"""
39
+
40
+ # Add ANSI color codes
41
+ RED = "\033[91m"
42
+ YELLOW = "\033[93m"
43
+ BLUE = "\033[94m"
44
+ GREEN = "\033[92m"
45
+ BOLD = "\033[1m"
46
+ RESET = "\033[0m"
47
+
48
+ def __init__(
49
+ self,
50
+ message: str,
51
+ nvvm_error: Optional[str] = None,
52
+ ir_context: Optional[str] = None,
53
+ cuda_toolkit: Optional[str] = None,
54
+ arch: Optional[str] = None,
55
+ ):
56
+ self.nvvm_error = nvvm_error
57
+ self.ir_context = ir_context
58
+ self.cuda_toolkit = cuda_toolkit
59
+ self.arch = arch
60
+ # Call parent with formatted error to avoid showing class name
61
+ super().__init__("") # Empty string to avoid class name
62
+ # Store formatted error for str() representation
63
+ self._formatted_error = self._format_error()
64
+
65
+ def __str__(self) -> str:
66
+ """Override string representation to avoid showing class name"""
67
+ return self._formatted_error
68
+
69
+ def __repr__(self) -> str:
70
+ """Override repr representation to avoid showing class name"""
71
+ return self._formatted_error
72
+
73
+ def _format_error(self) -> str:
74
+ if not self.nvvm_error:
75
+ return str(self.args[0])
76
+
77
+ return f"""NVVM Compilation Error:
78
+ ----------------------
79
+
80
+ {self.BLUE}⚙️ Current Settings:{self.RESET}
81
+ {self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
82
+ - Target Architecture: {self.arch}{self.RESET}
83
+
84
+ IR Context (truncated):
85
+ {self.ir_context}
86
+
87
+ {self.YELLOW}💡 Possible Solutions:{self.RESET}
88
+ {self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
89
+ 2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
90
+ 3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""
91
+
92
+
93
+ class Compiler:
94
+ """Compiler class for compiling and building MLIR modules."""
95
+
96
+ def __init__(self, passmanager, execution_engine):
97
+ self.passmanager = passmanager
98
+ self.execution_engine = execution_engine
99
+
100
+ def __call__(self, module):
101
+ """Convenience application method."""
102
+ self.compile(module)
103
+
104
+ def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
105
+ """Process error message to extract NVVM error and IR context"""
106
+ nvvm_error = None
107
+ ir_msg = ""
108
+
109
+ if "NVVM_ERROR" in error_msg:
110
+ # Extract the specific NVVM error
111
+ nvvm_error = (
112
+ error_msg.split("libNVVM extra log:")[1].strip()
113
+ if "libNVVM extra log:" in error_msg
114
+ else error_msg
115
+ )
116
+
117
+ # Extract IR context
118
+ if "see current operation:" in error_msg:
119
+ # Get the IR section
120
+ ir_section = error_msg.split("see current operation:")[1].strip()
121
+ # Remove duplicate IR section
122
+ ir_section = ir_section.split("error: unknown: Failed translating")[
123
+ 0
124
+ ].strip()
125
+
126
+ # Get first few lines and last few lines of the IR
127
+ ir_lines = ir_section.split("\n")
128
+ if len(ir_lines) > 10:
129
+ ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:])
130
+ else:
131
+ ir_msg = ir_section
132
+
133
+ return nvvm_error, ir_msg
134
+
135
+ def compile(
136
+ self,
137
+ module,
138
+ pipeline: str,
139
+ cuda_toolkit: str = "",
140
+ arch: str = "",
141
+ enable_verifier=False,
142
+ ):
143
+ """Compiles the module by invoking the pipeline."""
144
+ try:
145
+ pm = self.passmanager.PassManager.parse(pipeline)
146
+ pm.enable_verifier(enable_verifier)
147
+ pm.run(module.operation)
148
+ except Exception as e:
149
+ error_msg = str(e)
150
+ nvvm_error, ir_msg = self._process_error(error_msg)
151
+
152
+ if nvvm_error:
153
+ raise CompilationError(
154
+ error_msg,
155
+ nvvm_error=nvvm_error,
156
+ ir_context=ir_msg,
157
+ cuda_toolkit=cuda_toolkit,
158
+ arch=arch,
159
+ ) from e
160
+ raise e
161
+
162
+ def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
163
+ """Wraps the module in a JIT execution engine."""
164
+ return self.execution_engine.ExecutionEngine(
165
+ module, opt_level=opt_level, shared_libs=shared_libs
166
+ )
167
+
168
+ def compile_and_jit(
169
+ self,
170
+ module,
171
+ pipeline: str,
172
+ shared_libs: Sequence[str] = (),
173
+ opt_level: int = 2,
174
+ cuda_toolkit: str = "",
175
+ arch: str = "",
176
+ ):
177
+ """Compiles and jits the module."""
178
+ self.compile(
179
+ module,
180
+ pipeline,
181
+ cuda_toolkit,
182
+ arch,
183
+ )
184
+ return self.jit(module, opt_level, shared_libs)
185
+
186
+
187
+ class CompileOptions:
188
+ def __init__(self, options: str = ""):
189
+ """
190
+ This class encapsulates all compilation options relevant to function compilation.
191
+ It provides a convenient way to manage and pass compilation options,
192
+ particularly for controlling compilation settings.
193
+ By centralizing these options, it ensures consistent and flexible configuration of
194
+ compilation parameters such as optimization level, debugging control, etc.
195
+
196
+ :param options: The options for the function. Will be parsed by argparse.
197
+ :type options: str
198
+ """
199
+ if not isinstance(options, str):
200
+ raise DSLRuntimeError(
201
+ f"Invalid compilation `options`: {options}, it should be a string"
202
+ )
203
+ self._parser = argparse.ArgumentParser()
204
+ self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
205
+ self._parser.add_argument(
206
+ "--enable-device-assertions", action="store_true", default=False
207
+ )
208
+ self._parser.add_argument("--link-libraries", type=str, default="")
209
+
210
+ try:
211
+ self._options = self._parser.parse_args(options.split())
212
+ except SystemExit as e:
213
+ # catch argparse error and raise as DSLRuntimeError
214
+ raise DSLRuntimeError(
215
+ f"Invalid compile options: '{options}'. Please check the option values and format."
216
+ )
217
+ log().info("`cute.compile` CompileOptions: options=" + options)
218
+
219
+ def to_str(self):
220
+ """
221
+ Generate a string representation of all compilation options
222
+ which will be used in pipeline options.
223
+ """
224
+ option_strings = []
225
+ for key, value in vars(self._options).items():
226
+ hyphen_key = key.replace("_", "-")
227
+ if isinstance(value, bool):
228
+ formatted_value = "true" if value else "false"
229
+ else:
230
+ formatted_value = str(value)
231
+ option_strings.append(f"{hyphen_key}={formatted_value}")
232
+
233
+ return " ".join(option_strings)
234
+
235
+
236
+ def compile(func, *args, **kwargs):
237
+ """
238
+ This function is used to compile a `cute.jit` decorated function.
239
+ It will process the compile options and input parameters, do explicit compilation and return the jit executor.
240
+
241
+ :param func: The function to compile. It can be a regular function, a method or a class instance.
242
+ :param args: The arguments to pass to the function.
243
+ :param kwargs: The keyword arguments to pass to the function. It can contain `options` like
244
+ `opt_level` to control the compilation flags.
245
+
246
+ :return: The jit executor.
247
+
248
+ :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
249
+ """
250
+ if func is None:
251
+ raise DSLRuntimeError("Function is not set or invalid.")
252
+
253
+ if not callable(func):
254
+ raise DSLRuntimeError("Object is not callable.")
255
+
256
+ kwargs["compile_only"] = True
257
+ kwargs["no_cache"] = True
258
+
259
+ if inspect.isfunction(func):
260
+ # regular function
261
+ pass
262
+ elif inspect.ismethod(func):
263
+ # if it's a method, add the instance to the first argument
264
+ args = [func.__self__] + list(args)
265
+ func = func.__func__
266
+ elif inspect.isclass(type(func)) and hasattr(func, "__call__"):
267
+ # If it's a class instance, get the class's __call__ method
268
+ args = [func] + list(args)
269
+ # Get the actual function from the class definition
270
+ func = func.__call__.__func__
271
+ else:
272
+ raise DSLRuntimeError(
273
+ "Invalid function type, only function, method and module are supported, but got",
274
+ func,
275
+ )
276
+
277
+ # If it's a wrapped function created by jit decorator, get the original function
278
+ if hasattr(func, "__wrapped__"):
279
+ func = func.__wrapped__
280
+
281
+ if not hasattr(func, "_dsl_object"):
282
+ raise DSLRuntimeError("Function is not decorated with jit decorator.")
283
+
284
+ # process compile options, extract the options and remove them from the kwargs
285
+ options = kwargs.pop("options", "")
286
+ func._dsl_object.compile_options = CompileOptions(options)
287
+ fcn_ptr = func._dsl_object._preprocess_and_execute(func)
288
+ return func._dsl_object._func(fcn_ptr, *args, **kwargs)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py ADDED
@@ -0,0 +1,1686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides a main DSL class for any Dialect.
14
+ The DSL should be inherited as a new class, and its initialization requires dialects.
15
+ It handles most of the mechanics for the DSL in an agnostic way,
16
+ for example, it can handle various dialect-specific tasks.
17
+ """
18
+
19
+
20
+ # Standard library imports
21
+ from dataclasses import dataclass, field
22
+ import atexit
23
+ import os
24
+ import io
25
+ import sys
26
+ import errno
27
+ import ctypes
28
+ import re
29
+ import inspect
30
+ import argparse
31
+ import hashlib
32
+ from functools import lru_cache, wraps
33
+ from collections import namedtuple
34
+ from abc import ABC, abstractmethod
35
+ from typing import Any, Union, Tuple, get_origin, get_args, List
36
+ from types import FunctionType, SimpleNamespace
37
+ import warnings
38
+
39
+ from . import typing as t
40
+ from .env_manager import EnvironmentVarManager
41
+ from .compiler import CompileOptions
42
+ from .ast_helpers import DSLOptimizationWarning
43
+
44
+ # =============================================================================
45
+ # CUDA Python
46
+ # =============================================================================
47
+
48
+ from ..base_dsl._mlir_helpers.arith import const
49
+
50
+ # =============================================================================
51
+ # Local module imports
52
+ # =============================================================================
53
+
54
+ from .cache_helpers import *
55
+ from .jit_executor import JitExecutor
56
+ from .utils.timer import timer
57
+ from .utils.logger import setup_log, log
58
+ from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
59
+ from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
60
+
61
+ from .ast_preprocessor import DSLPreprocessor
62
+ from .common import *
63
+ from .typing import (
64
+ get_c_pointers,
65
+ get_mlir_types,
66
+ )
67
+
68
+ # =============================================================================
69
+ # MLIR modules
70
+ # =============================================================================
71
+
72
+ from .._mlir import ir
73
+ from .._mlir import runtime as rt
74
+ from .._mlir.extras import types as T
75
+ from .._mlir.dialects import arith, math, func
76
+
77
+ # =============================================================================
78
+ # Global Variables
79
+ # =============================================================================
80
+
81
+ MLIR_DYNAMIC = -9223372036854775808
82
+
83
+ # =============================================================================
84
+ # Codegen Utils
85
+ # =============================================================================
86
+
87
+
88
+ def _numpy_type_to_mlir_type(dtype):
89
+ if dtype == np.float64:
90
+ return T.f64()
91
+ if dtype == np.float16:
92
+ return T.f16()
93
+ if dtype == np.float32:
94
+ return T.f32()
95
+ if dtype == np.int64:
96
+ return T.i64()
97
+ if dtype == np.int32:
98
+ return T.i32()
99
+ if dtype == np.int16:
100
+ return T.i16()
101
+ if dtype == np.int8:
102
+ return T.i8()
103
+ if dtype == np.uint64:
104
+ return T.ui64()
105
+ if dtype == np.uint32:
106
+ return T.ui32()
107
+ if dtype == np.uint16:
108
+ return T.ui16()
109
+ if dtype == np.uint8:
110
+ return T.ui8()
111
+ if dtype == np.bool_:
112
+ return T.bool()
113
+ if dtype == f8E5M2:
114
+ return T.f8E5M2()
115
+ if dtype == f8E4M3FN:
116
+ return T.f8E4M3FN()
117
+ if dtype == f8E8M0FNU:
118
+ return T.f8E8M0FNU()
119
+ if dtype == f6E3M2FN:
120
+ return T.f6E3M2FN()
121
+ if dtype == f6E2M3FN:
122
+ return T.f6E2M3FN()
123
+ if dtype == f4E2M1FN:
124
+ return T.f4E2M1FN()
125
+ assert False, f"Unknown type {type}"
126
+
127
+
128
+ def _mlir_type_to_numpy_type(type):
129
+ if type == T.f64():
130
+ return np.float64
131
+ if type == T.f16():
132
+ return np.float16
133
+ if type == T.f32():
134
+ return np.float32
135
+ if type == T.i64():
136
+ return np.int64
137
+ if type == T.i32():
138
+ return np.int32
139
+ if type == T.i16():
140
+ return np.int16
141
+ if type == T.i8():
142
+ return np.int8
143
+ if type == T.ui64():
144
+ return np.uint64
145
+ if type == T.ui32():
146
+ return np.uint32
147
+ if type == T.ui16():
148
+ return np.uint16
149
+ if type == T.ui8():
150
+ return np.uint8
151
+ if type == T.bool():
152
+ return np.bool_
153
+ assert False, f"Unknown type {type}"
154
+
155
+
156
+ # =============================================================================
157
+ # Main DSL Class
158
+ # =============================================================================
159
+
160
+
161
+ def is_dynamic_expression(value):
162
+ """
163
+ Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value
164
+ """
165
+ if isinstance(value, (tuple, list)):
166
+ for x in value:
167
+ if is_dynamic_expression(x):
168
+ return True
169
+ elif isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr(
170
+ value, "__extract_mlir_values__"
171
+ ):
172
+ return True
173
+ return False
174
+
175
+
176
+ def extract_mlir_values(obj):
177
+ """
178
+ Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values
179
+ """
180
+ res = []
181
+ if hasattr(obj, "__extract_mlir_values__"):
182
+ res = obj.__extract_mlir_values__()
183
+ elif isinstance(obj, (tuple, list)):
184
+ res = sum((extract_mlir_values(x) for x in obj), [])
185
+ elif isinstance(obj, SimpleNamespace):
186
+ res = []
187
+ for k, v in obj.__dict__.items():
188
+ res.extend(extract_mlir_values(v))
189
+ # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values
190
+ elif isinstance(obj, set):
191
+ raise DSLRuntimeError(
192
+ "Sets are not supported in extract_mlir_values to ensure order preservation",
193
+ context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
194
+ suggestion="Consider using a list or tuple instead",
195
+ )
196
+ elif isinstance(obj, ir.Value):
197
+ res = [obj]
198
+ elif isinstance(obj, ir.BlockArgumentList):
199
+ res = list(obj) # type: ignore
200
+
201
+ return res
202
+
203
+
204
def new_from_mlir_values(obj, values):
    """
    Create a new python object by populating containing MLIR values with list of new values.

    Recursively rebuilds `obj`, replacing each dynamic (MLIR) value it contains
    with the corresponding entry of `values`.  Values are consumed in the same
    deterministic order that `extract_mlir_values` produces them, which is why
    unordered containers (sets) are rejected.
    """
    if hasattr(obj, "__new_from_mlir_values__"):
        # Object implements the DynamicExpression protocol: let it rebuild itself.
        return obj.__new_from_mlir_values__(values)
    elif isinstance(obj, (tuple, list)):
        res = []
        for x in obj:
            # Each element consumes exactly as many values as it has MLIR types.
            n_items = len(get_mlir_types(x))
            res.append(new_from_mlir_values(x, values[:n_items]))
            values = values[n_items:]
        # Preserve the concrete container type (tuple vs list, incl. subclasses).
        obj_ty = type(obj)
        return obj_ty(res)
    elif isinstance(obj, SimpleNamespace):
        res = SimpleNamespace()
        for k, v in obj.__dict__.items():
            n_items = len(get_mlir_types(v))
            res.__dict__[k] = new_from_mlir_values(v, values[:n_items])
            values = values[n_items:]
        return res
    elif isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in new_from_mlir_values to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    elif is_dynamic_expression(obj):
        if len(values) == 0:
            # Nothing to substitute; keep the original dynamic expression.
            return obj

        # A single dynamic expression maps to exactly one replacement value.
        assert len(values) == 1
        return values[0]
    else:
        # Static (compile-time) object: it must not consume any values.
        assert len(values) == 0, f"{obj} expects 0 values, but got {values}"
        return obj
241
+
242
+
243
class DSLCallable:
    """
    One-shot wrapper around a callable used within the DSL.

    Wraps a function and exposes introspection helpers (argument signature,
    name) while guaranteeing that the wrapped function is invoked at most
    once: after the first call the stored reference is dropped, and any
    further use trips an assertion.  This matches the DSL execution model,
    where a materialized function body must only be executed a single time.

    Attributes:
        func (callable): The wrapped function; reset to None after the call.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        result = self.__func__(*args, **kwargs)
        # Drop the reference so the wrapped function cannot run a second time.
        self.func = None
        return result

    @property
    def __func__(self):
        # Guard every access: once called, the wrapper is spent.
        assert self.func is not None, "DSLCallable is already called"
        return self.func

    @property
    def __signature__(self):
        return inspect.signature(self.__func__)

    @property
    def __name__(self):
        return self.__func__.__name__
281
+
282
+
283
+ class BaseDSL:
284
+ gpu_module = None
285
+
286
    def __init__(
        self,
        *,
        name: str,
        dsl_package_name: List[str],
        compiler_provider: Any,
        pass_sm_arch_name: str,
        device_compilation_only=False,
        preprocess=False,
    ):
        """
        Constructor for initializing the class with required providers and environment settings.

        Parameters:
        - name (str): Name of DSL, used for environment variables and logging.
        - dsl_package_name (List[str]): Name of the DSL package, forwarded to the preprocessor.
        - compiler_provider (MLIR dialect): Provider for compiler.
        - pass_sm_arch_name (str): The keyword name of the SM.
        - device_compilation_only (bool): Only device code, and call it via cuda driver.
        - preprocess (bool): Enable AST transformation.

        This constructs a DSL instance and sets up environment management,
        warning configurations, and logging functionalities. It reads
        environment variables using `EnvironmentVarManager` and configures
        a logger with settings from the environment. If environment warnings
        are detected, they are escalated to errors to ensure strict handling.

        Raises:
            DSLRuntimeError: if any of `name`, `compiler_provider`, or
                `pass_sm_arch_name` is missing/empty.
        """
        # Enforcing initialization of instance variables
        if not all([name, compiler_provider, pass_sm_arch_name]):
            raise DSLRuntimeError(
                "All required parameters must be provided and non-empty"
            )

        self.name = name
        self.compiler_provider = compiler_provider
        self.pass_sm_arch_name = pass_sm_arch_name
        # Caller frame of the decoration site; set during preprocessing, else None.
        self.frame = None
        self.no_cache = False
        self.device_compilation_only = device_compilation_only
        self.num_kernels = 0
        # Read environment variables
        self.envar = EnvironmentVarManager(self.name)
        self.enable_preprocessor = preprocess
        # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default
        self.jit_cache = (
            dict()
            if self.envar.disable_file_caching
            else load_cache_from_path(self.name, self.envar.file_caching_capacity)
        )
        self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}"
        self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"

        # set warning
        if not self.envar.enable_optimization_warnings:
            # By default, optimization warnings are disabled
            warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
        if self.envar.warnings_as_errors:
            warnings.filterwarnings("error")
        if self.envar.warnings_ignore:
            warnings.filterwarnings("ignore")

        # Initialize logger
        if self.envar.log_to_console == False and self.envar.jitTimeProfiling:
            # Profiling output is emitted through the logger, so force console
            # logging on at info level when profiling is requested.
            self.envar.log_to_console = True
            self.envar.log_level = 20  # info level
        setup_log(
            self.name,
            self.envar.log_to_console,
            self.envar.log_to_file,
            f"{self.name}.log",
            self.envar.log_level,
        )

        # kernel symbols are temporary symbol string variables, their values are valid until the compilation is done.
        self.kernel_symbols = []
        # used to generate unique name for gpu.launch
        self.launch_inner_count = 0
        # initialize default compile options
        self.compile_options = CompileOptions()

        if preprocess:
            self.preprocessor = DSLPreprocessor(dsl_package_name)
        log().info(f"Initializing {name} DSL")
        log().debug(f"Logger initialized for {self.name}")

        # Hook excepthook
        if self.envar.filterStacktrace:
            origin_excepthook = sys.excepthook
            module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__)))

            def excepthook(excep_type, value, traceback):
                # Strip DSL-internal frames so tracebacks point at user code.
                filter_exception(value, module_dir)
                if hasattr(value, "__traceback__"):
                    origin_excepthook(excep_type, value, value.__traceback__)
                else:
                    origin_excepthook(
                        excep_type, value, filter_stackframe(traceback, module_dir)
                    )

            sys.excepthook = excepthook

            # Restore original excepthook
            def restore_excepthook(hook):
                sys.excepthook = hook

            atexit.register(restore_excepthook, origin_excepthook)
392
+
393
+ def dump_cache(self):
394
+ if not self.envar.disable_file_caching:
395
+ dump_cache_to_path(
396
+ self.name, self.jit_cache, self.envar.file_caching_capacity
397
+ )
398
+
399
+ @lru_cache(maxsize=1)
400
+ def print_warning_once(self, message):
401
+ log().warning(f"Warning: {message}")
402
+ warnings.warn(message, UserWarning)
403
+
404
    def print_warning(self, message):
        """Log *message* as a warning and also emit it as a Python UserWarning."""
        log().warning(f"Warning: {message}")
        warnings.warn(message, UserWarning)
407
+
408
    @classmethod
    @lru_cache(maxsize=1)
    def _get_dsl(cls):
        """
        Return the singleton DSL instance for this class.

        The `lru_cache` keyed on `cls` makes this a one-time constructor;
        subsequent calls return the same instance.
        """
        # Instantiate the DSL Class once
        main_dsl = cls()
        if not main_dsl.no_cache:
            # register atexit callback so the JIT cache is flushed to disk on exit
            atexit.register(main_dsl.dump_cache)
        return main_dsl
417
+
418
+ @staticmethod
419
+ def _can_preprocess(**dkwargs):
420
+ """
421
+ Check if AST transformation is enabled or not for `jit` and `kernel` decorators.
422
+ """
423
+ return dkwargs.pop("preprocess", True)
424
+
425
    @staticmethod
    def _get_original_function(fcn_ptr, name):
        """
        Get the original function from the decorated function.

        Walks the decorator chain until a function whose ``__name__`` matches
        `name` is found, following either functools-style ``__wrapped__``
        attributes or the first closure cell of manually written wrappers.

        Raises:
            DSLRuntimeError: if the chain cannot be followed any further.
        """
        while fcn_ptr.__name__ != name:
            # If the function is wrapped with functools, get from __wrapped__
            if hasattr(fcn_ptr, "__wrapped__"):
                fcn_ptr = fcn_ptr.__wrapped__
            # If the function is wrapped manually, it's the first in closure
            elif callable(fcn_ptr.__closure__[0].cell_contents):
                fcn_ptr = fcn_ptr.__closure__[0].cell_contents
            else:
                raise DSLRuntimeError(
                    f"Cannot find the original function {name} in the closure chain"
                )
        return fcn_ptr
442
+
443
    @staticmethod
    def _preprocess_and_execute(func):
        """
        Run ast transformation and return the materialized function pointer.

        Functions decorated with preprocessing enabled carry a
        ``_transformed_ast`` attribute (initially None).  On first call the
        preprocessor runs lazily; if it declines (returns None) the marker is
        removed and the original function is used unchanged.  Otherwise the
        materialized function is unwrapped back to the user's function and
        returned as a single-use DSLCallable.
        """
        if hasattr(func, "_transformed_ast"):
            # If the function ptr is already materialized, use the existing one
            # The preprocessor needs the decoration-site frame for name lookup.
            func._dsl_object.frame = func._decorator_frame
            if func._transformed_ast is None:
                func._transformed_ast = func._dsl_object.run_preprocessor(func)
                if func._transformed_ast is None:
                    # Preprocessor opted out: remove the marker and fall through
                    # to the original, untransformed function.
                    del func._transformed_ast
                    func._dsl_object.frame = None
                    return func

            fcn_ptr = func._dsl_object.get_function_ptr(func)
            # If the function is decorated, de-decorate it
            fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
            func._dsl_object.frame = None
            return DSLCallable(fcn_ptr)
        return func
464
+
465
    def jit_runner(self, executor, frame, *dargs, **dkwargs):
        """
        Decorator to mark a function for JIT compilation.

        Parameters:
        - executor: callable invoked with the (possibly preprocessed) function
          plus the call-site arguments (e.g. `_func` for host, `_kernel_helper`
          for device).
        - frame: the caller's frame captured at decoration time, needed later
          when materializing the transformed AST.
        - dargs/dkwargs: decorator arguments; supports both bare `@jit` and
          parameterized `@jit(...)` usage.
        """
        log().info("jit_runner")

        def jit_runner_decorator(func):
            func._dsl_object = self
            # Run preprocessor that alters AST
            if self.enable_preprocessor and BaseDSL._can_preprocess(**dkwargs):
                # For an annotated function, add some DSL attributes
                # When materializing the AST, we need decorator's frame
                func._decorator_frame = frame
                # No transformed ast at this point
                func._transformed_ast = None

            @wraps(func)
            def jit_wrapper(*args, **kwargs):
                # Preprocessing is deferred to the first actual call.
                func_ptr = BaseDSL._preprocess_and_execute(func)
                return executor(func_ptr, *args, **kwargs)

            return jit_wrapper

        # Bare decorator usage: @jit with no parentheses.
        if len(dargs) == 1 and callable(dargs[0]):
            return jit_runner_decorator(dargs[0])
        else:
            return jit_runner_decorator
492
+
493
+ @classmethod
494
+ def jit(cls, *dargs, **dkwargs):
495
+ """
496
+ Decorator to mark a function for JIT compilation for Host code.
497
+ """
498
+ frame = inspect.currentframe().f_back
499
+ # Instantiate the DSL Class
500
+ main_dsl = cls._get_dsl()
501
+ return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
502
+
503
+ @classmethod
504
+ def kernel(cls, *dargs, **dkwargs):
505
+ """
506
+ Decorator to mark a function for JIT compilation for GPU.
507
+ """
508
+ frame = inspect.currentframe().f_back
509
+ # Instantiate the DSL Class
510
+ main_dsl = cls._get_dsl()
511
+ return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
512
+
513
    @abstractmethod
    def _kernel_helper(self, func, *args, **kwargs):
        """
        Helper function to handle kernel generation logic.

        Subclasses implement device-side (GPU) code generation here; used as
        the executor for the `kernel` decorator.
        """
        pass
519
+
520
    @abstractmethod
    def _build_gpu_module(self, attrs):
        """
        Build the module op that contains the kernels.

        `attrs` carries the attributes to attach to the GPU module op.
        """
        pass
526
+
527
+ @abstractmethod
528
+ def _get_pipeline(self, pipeline):
529
+ """
530
+ Get the pipeline from the other configuration options.
531
+ """
532
+ if pipeline != None:
533
+ return pipeline
534
+ return None
535
+
536
+ @staticmethod
537
+ def log_additions(func_type, operands=None, types=None, arg_attrs=None):
538
+ if operands is not None and operands != []:
539
+ log().debug(
540
+ f"Added {func_type} operands: [%s]", ", ".join(map(str, operands))
541
+ )
542
+ if types is not None:
543
+ log().debug(
544
+ f"Added {func_type} arg_types: [%s]", ", ".join(map(str, types))
545
+ )
546
+ if arg_attrs is not None:
547
+ log().debug(
548
+ f"Added {func_type} arg_attrs: [%s]", ", ".join(map(str, arg_attrs))
549
+ )
550
+
551
    def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec):
        """
        Does simple name mangling.

        Appends a textual encoding of the compile-time (annotated, non-IR)
        argument values to `function_name`, then sanitizes the result into a
        filesystem/symbol-safe identifier capped at 180 characters.
        """

        for spec_arg, arg in zip(args_spec.args, args):
            spec_ty = args_spec.annotations.get(spec_arg, None)
            if spec_ty != None:
                # Skip anything that is a runtime IR value/type: such arguments
                # do not contribute to the mangled (compile-time) identity.
                if issubclass(type(spec_ty), (t.IRValue, t.IRVariadic)):
                    continue
                if isinstance(spec_ty, (ir.Type, ir.Value)):
                    continue
                if isinstance(arg, (ir.Type, ir.Value, ir.OpResult)):
                    continue
                if isinstance(type(arg), (ir.Type, ir.Value, ir.OpResult)):
                    continue
                if self._is_tensor_descriptor(arg):
                    continue
                if inspect.isclass(spec_ty):
                    class_name = str(arg).replace("class", "")
                    class_name = class_name.replace(" ", "")
                    function_name = f"{function_name}_{class_name}"
                elif isinstance(arg, (list, tuple)):
                    function_name = f"{function_name}_{'_'.join(map(str, arg))}"
                else:
                    function_name = f"{function_name}_{arg}"
        # we would need a dedicated MR to follow up
        # Strip characters that are not valid in symbol names.
        unwanted_chars = r"'-![]#,.<>()\":{}=%?@;"
        translation_table = str.maketrans("", "", unwanted_chars)
        function_name = function_name.translate(translation_table)
        # identify address and drop (object reprs embed 0x... addresses that
        # would make the name — and hence the cache key — non-deterministic)
        function_name = re.sub(r"0x[a-f0-9]{8,16}", "", function_name)
        function_name = re.sub(r"\s+", " ", function_name)
        function_name = function_name.replace(" ", "_")
        function_name = function_name.replace("\n", "_")
        # max fname is 256 character, leave space
        function_name = function_name[:180]
        log().info(f"Final mangled function name: {function_name}")
        return function_name
588
+
589
    def _generate_execution_arguments_for_known_types(
        self, arg, arg_spec, arg_name, i, fop_args, iv_block_args
    ):
        """
        Generate MLIR arguments for known types.

        Sub-DSLs can override this method to handle types that are not
        natively supported by the Base DSL.

        Returns a (ir_arg, iv_block_args) pair: the list of generated
        arguments (possibly empty) and the updated block-argument cursor.
        """
        ir_arg = []
        # NOTE(review): `func` is not defined in this scope — it is neither a
        # parameter nor a visible local — so this base implementation would
        # raise NameError if this line is reached.  Presumably overriding
        # sub-DSLs shadow this path; confirm before relying on the base class.
        if is_argument_constexpr(arg, arg_spec, arg_name, i, func):
            # Constexpr arguments are passed through as-is (no MLIR value).
            ir_arg.append(arg)

        return ir_arg, iv_block_args
603
+
604
    def generate_execution_arguments(
        self,
        args,
        kwargs,
        fop,
        args_spec: inspect.FullArgSpec,
    ):
        """
        Create list of arguments that will be passed to MLIR's func.func op.

        Maps each Python-level call argument onto the entry-block arguments of
        `fop`, either via the known-type hook, or by rebuilding the object
        around block arguments with `new_from_mlir_values` (after an optional
        JIT argument adapter conversion).
        """

        def gen_exec_args(input_args, arg_names, annotations, fop_args):
            assert len(input_args) == len(arg_names)

            ir_args = []
            # Cursor into fop_args: how many block arguments are consumed so far.
            iv_block_args = 0
            for i, arg in enumerate(input_args):
                arg_name = arg_names[i]
                arg_spec = annotations.get(arg_name, None)
                log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec)

                # Implicit cast to NumericMeta
                if isinstance(arg_spec, t.NumericMeta) and not isinstance(
                    arg, arg_spec
                ):
                    arg = t.cast(arg, arg_spec)

                ir_arg, iv_block_args = (
                    self._generate_execution_arguments_for_known_types(
                        arg, arg_spec, arg_name, i, fop_args, iv_block_args
                    )
                )

                if not ir_arg:
                    # If it's not a known type, try JIT argument adapter
                    # to convert the argument if possible
                    adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
                    arg = adapter(arg) if adapter else arg

                    # Rebuild the object around its slice of block arguments.
                    n_args = len(get_mlir_types(arg))
                    blk_args = fop_args[iv_block_args : iv_block_args + n_args]
                    ir_arg.append(new_from_mlir_values(arg, blk_args))
                    iv_block_args += n_args

                # NOTE(review): ir_arg is passed as log_additions' func_type
                # positional parameter — looks like the operands= keyword was
                # intended; confirm against log output expectations.
                self.log_additions(ir_arg)
                ir_args.extend(ir_arg)

            return ir_args, iv_block_args

        fop_args = list(fop.regions[0].blocks[0].arguments)
        ir_args, iv_block_args = gen_exec_args(
            args, args_spec.args, args_spec.annotations, fop_args
        )
        # Keyword-only arguments consume the remaining block arguments.
        ir_kwargs, _ = gen_exec_args(
            [kwargs[arg] for arg in args_spec.kwonlyargs],
            args_spec.kwonlyargs,
            args_spec.annotations,
            fop_args[iv_block_args:],
        )
        ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)}

        log().debug("execution args: %s", ", ".join(map(str, ir_args)))
        log().debug("execution kwargs: %s", ", ".join(map(str, ir_kwargs)))
        return ir_args, ir_kwargs
666
+
667
    @abstractmethod
    def _generate_mlir_type_for_tensor_descriptor(self, tensor):
        """
        Generate MLIR type for the tensor descriptor.
        """
        pass
673
+
674
    @abstractmethod
    def _generate_executable_arg_for_tensor_descriptor(
        self, mlir_value=None, ptr_tensor_ty=None, tensor=None
    ):
        """
        Generates executable value for the given tensor descriptor.
        """
        pass
682
+
683
+ def _get_globals(self):
684
+ """
685
+ Combines global and local variables from the current context and the
686
+ caller's frame comes. This includes the current module's globals, the
687
+ global variables from the caller's frame, and the local variables from
688
+ the caller's frame.
689
+
690
+ "self.frame" is used to fetch the caller's frame.
691
+
692
+ AST preprocessor generates a new python code, so the resulting globals
693
+ dictionary is used to execute the python code.
694
+ """
695
+ all_globals = {}
696
+ if self.frame:
697
+ all_globals.update(self.frame.f_globals)
698
+ all_globals.update(self.frame.f_locals)
699
+ return all_globals
700
+
701
    @abstractmethod
    def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
        """Return True if the given object is a tensor descriptor for this DSL."""
        pass
704
+
705
    @abstractmethod
    def _handle_tensor_descriptor(
        self, maybe_tensor, arg_name: str, need_gpu_memory: bool
    ) -> Any:
        """Process a tensor-descriptor argument (subclass-specific handling)."""
        pass
710
+
711
    def _validate_arg(self, arg, arg_index, arg_name, arg_spec):
        """
        Validates if the arg is really of the annotated type for type safety.

        The default implementation is empty. Subclasses can override this method to add more validation logic.
        Returns None if validation passes, otherwise returns an error derived from DSLBaseError.
        """
        pass
719
+
720
+ def _generate_jit_func_args_for_known_types(
721
+ self,
722
+ func,
723
+ arg,
724
+ arg_name,
725
+ arg_spec,
726
+ arg_index,
727
+ *,
728
+ is_host=True,
729
+ ):
730
+ """
731
+ Generate JIT function arguments for known types.
732
+
733
+ Sub-DSLs can override this method to handle types that are not
734
+ natively supported by the Base DSL.
735
+ """
736
+
737
+ jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], []
738
+ default_attr = ir.DictAttr.get({})
739
+
740
+ if is_argument_constexpr(arg, arg_spec, arg_name, arg_index, func):
741
+ jit_exec_arg = jit_arg_type = jit_arg_attr = None
742
+
743
+ return jit_exec_arg, jit_arg_type, jit_arg_attr
744
+
745
    def _generate_jit_func_args(
        self,
        func,
        function_name,
        args,
        kwargs,
        args_spec: inspect.FullArgSpec,
    *,
        is_host=True,
    ):
        """
        Generate JIT function arguments.

        For each call-site argument, produces the execution values (C pointers
        on the host path, MLIR values on the device path), their MLIR types,
        and per-argument attribute dicts.  Constexpr arguments are skipped;
        unknown types go through the JIT argument adapter registry.

        Returns (jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args).
        """

        assert len(args) == len(args_spec.args) and len(kwargs) == len(
            args_spec.kwonlyargs
        ), (
            f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
            f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
        )

        jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
        # Adapter-converted arguments are kept alive and returned to the caller.
        jit_adapted_args = []
        default_attr = ir.DictAttr.get({})

        input_args = [*args, *kwargs.values()]
        input_arg_names = [*args_spec.args, *args_spec.kwonlyargs]
        for i, (arg_name, arg) in enumerate(zip(input_arg_names, input_args)):
            spec_ty = args_spec.annotations.get(arg_name, None)
            log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty)

            # Implicitly convert into Numeric type if possible
            if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty):
                arg = t.cast(arg, spec_ty)

            # Type safety check
            if spec_ty is not None:
                err = self._validate_arg(arg, i, arg_name, spec_ty)
                if err is not None:
                    raise err

            jit_exec_arg, jit_arg_type, jit_arg_attr = (
                self._generate_jit_func_args_for_known_types(
                    func,
                    arg,
                    arg_name,
                    spec_ty,
                    i,
                    is_host=is_host,
                )
            )

            # None means Constexpr (skip); empty list means "not a known type".
            if jit_arg_type is not None and len(jit_arg_type) == 0:
                # If not any known type, try JIT argument adapter
                # to convert the argument
                adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
                if adapter:
                    arg = adapter(arg)
                    jit_adapted_args.append(arg)

                if is_host:
                    # Host path: pass C pointers, typed by the MLIR types.
                    jit_exec_arg.extend(get_c_pointers(arg))
                    jit_arg_type.extend(get_mlir_types(arg))
                else:
                    # Device path: pass the extracted MLIR values directly.
                    dyn_vals = extract_mlir_values(arg)
                    jit_exec_arg.extend(dyn_vals)
                    jit_arg_type.extend([v.type for v in dyn_vals])

                if not jit_arg_type or not jit_exec_arg:
                    # Objects implementing the relevant protocol may legally
                    # contribute zero values; anything else is an error.
                    if (is_host and hasattr(arg, "__c_pointers__")) or (
                        not is_host
                        and hasattr(arg, "__extract_mlir_values__")
                        and hasattr(arg, "__new_from_mlir_values__")
                    ):
                        pass
                    else:
                        raise DSLRuntimeError(
                            f"failed to generate argument #{i+1} ({arg_name}) for JIT function '{function_name}'.",
                            context={
                                f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.",
                                f"Call-site argument value": arg,
                                f"Call-site argument type": type(arg),
                            },
                            suggestion=f"Consider annotating the argument with `{arg_name} : Constexpr` "
                            "if it's a value known at compile-time. "
                            f"Otherwise, implement the {'`JitArgument`' if is_host else '`DynamicExpression`'} "
                            f"protocol or register a custom JIT argument adapter for type `{type(arg)}` to "
                            "enable dynamic value conversion at runtime.",
                        )

                # Adapter-produced types all get the empty attribute dict.
                jit_arg_attr.extend([default_attr] * len(jit_arg_type))

            if jit_arg_type is not None:
                jit_exec_args.extend(jit_exec_arg)
                jit_arg_types.extend(jit_arg_type)
                jit_arg_attrs.extend(jit_arg_attr)

        return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args
841
+
842
    def generate_mlir_function_types(
        self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec
    ):
        """
        Convert input arguments to MLIR function signature also convert numpy arrays to memref.

        Returns (exe_args, types, adapted_args); the per-argument attribute
        dicts produced by `_generate_jit_func_args` are intentionally dropped
        on this host path.
        """

        exe_args, types, attrs, adapted_args = self._generate_jit_func_args(
            func, function_name, input_args, kwargs, args_spec, is_host=True
        )

        log().debug("Execution Arguments: %s", ", ".join(map(str, exe_args)))
        log().debug("Types: %s", ", ".join(map(str, types)))

        assert len(exe_args) == len(
            types
        ), "expects the same number of arguments and function parameters"

        return exe_args, types, adapted_args
859
+
860
    @dataclass
    class LaunchConfig:
        """GPU kernel launch configuration: grid/block/cluster shapes and shared memory."""

        # 3d cluster shape; None means "no cluster" (normalized in __post_init__).
        cluster: list = None
        # 3d grid shape.
        grid: list = field(default_factory=lambda: [1, 1, 1])
        # 3d block shape.
        block: list = field(default_factory=lambda: [1, 1, 1])
        # Dynamic shared memory in bytes; None requests automatic sizing.
        smem: int = None
        # Async token dependencies for the launch.
        async_deps: list = field(default_factory=list)
        # Derived in __post_init__: whether a cluster shape was provided.
        has_cluster: bool = False
        min_blocks_per_mp: int = 0
        # Derived in __post_init__: True when smem was left unspecified.
        auto_smem: bool = False

        def __post_init__(self):
            # Validate shapes and normalize the optional fields.
            if len(self.grid) != 3:
                raise DSLRuntimeError(f"Expect 3d grid!")
            if len(self.block) != 3:
                raise DSLRuntimeError(f"Expect 3d block!")

            if self.smem is None:
                self.smem = 0
                self.auto_smem = True

            self.has_cluster = self.cluster is not None
            if self.cluster is None:
                # Placeholder so downstream code can always index 3 entries.
                self.cluster = [None, None, None]
            elif len(self.cluster) != 3:
                raise DSLRuntimeError(f"Expect 3d cluster!")
886
+
887
    def diagnostic(self):
        """
        Check command line parameters and enables diagnostic.

        Attaches a diagnostic print handler to the current MLIR context, and
        if `-diagnostic` was passed on the command line, turns on MLIR's
        global debug flag filtered to the requested diagnostic category.
        """
        # Check command line arguments "-diagnostic"
        parser = argparse.ArgumentParser(description="Process diagnostic status.")
        parser.add_argument(
            "-diagnostic",
            nargs="?",
            const="all",
            choices=["all", "fail", "success", "info", "suggestion"],
            help="Set diagnostic status (fail, success, info, suggestion).",
        )

        # parse_known_args so unrelated application flags are left untouched.
        args, _ = parser.parse_known_args()
        ctx = ir.Context.current

        def callback(d):
            print(f" [{self.name} Diagnostic] : {d.message}")

        ctx.attach_diagnostic_handler(callback)

        # Early return, don't enable diagnostics
        if args.diagnostic is None:
            return

        # Enable MLIR Flags
        ctx.emit_error_diagnostics = True
        ir._GlobalDebug.flag = True
        if args.diagnostic == "all":
            ir._GlobalDebug.set_types("diagnostic")
        else:
            ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}")
918
+
919
+ def get_location(self):
920
+ """
921
+ Get python location information and generate MLIR location
922
+ """
923
+
924
+ if self.frame is None:
925
+ log().debug("Frame is None")
926
+ return None
927
+
928
+ file_loc = ir.Location.file(
929
+ self.frame.f_code.co_filename, self.frame.f_lineno, 0
930
+ )
931
+
932
+ loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
933
+ return loc
934
+
935
    def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
        """
        Compile and JIT an MLIR module.

        stdout/stderr are temporarily redirected so compiler output can be
        replayed after the streams are restored; any failure is wrapped in a
        DSLRuntimeError (internal compiler error).
        """

        try:
            self.diagnostic()

            # Capture all compiler output; the originals are restored in the
            # inner `finally` before the captured text is replayed below.
            orig_stdout = sys.stdout
            orig_stderr = sys.stderr
            sys.stderr = redirect_stderr = io.StringIO()
            sys.stdout = redirect_stdout = io.StringIO()

            try:
                kernel = self.compiler_provider.compile_and_jit(
                    module,
                    pipeline,
                    shared_libs=shared_libs,
                    cuda_toolkit=self.envar.cuda_toolkit,
                    arch=self.envar.arch,
                )

            finally:
                sys.stdout = orig_stdout
                sys.stderr = orig_stderr
                # Diagnostics (if enabled above) are one-shot per compilation.
                ir._GlobalDebug.flag = False

            # Print captured output.
            print(redirect_stdout.getvalue(), file=sys.stdout, end="")
            print(redirect_stderr.getvalue(), file=sys.stderr, end="")

            return kernel

        except Exception as e:
            raise DSLRuntimeError("🧊🧊🧊 ICE 🧊🧊🧊", cause=e)
        finally:
            # NOTE(review): this outer `finally: pass` is a no-op and could be
            # removed; kept byte-identical here.
            pass
972
+
973
    def preprocess_pipeline(self, pipeline, arch) -> str:
        """
        Splice environment-derived options (CUDA toolkit path, SM arch) into
        the pass pipeline string.

        If the pipeline already carries a `{...}` options group the new options
        are merged into it; otherwise a fresh group is appended before the
        closing parenthesis.
        """

        if self.envar.cuda_toolkit is None:
            self.print_warning(
                "CUDA_TOOLKIT_PATH environment variable is not set. Cannot set toolkitPath."
            )

        options = {
            "toolkitPath": self.envar.cuda_toolkit if self.envar.cuda_toolkit else None,
            self.pass_sm_arch_name: arch,
        }

        opt_str = ""
        for k, v in options.items():
            if v:
                opt_str += f"{k}={v} "

        if opt_str:
            # Automatically append the pipeline options if any is specified through env var
            pattern = re.compile(r"{(.+)}")
            match = pattern.search(pipeline)
            if match:
                # Merge into the existing options group.
                opt_str = f"{{{match[1]} {opt_str}}}"
                pipeline = re.sub(r"{.+}", opt_str, pipeline)
            else:
                # No options group yet: insert one before the trailing ')'.
                pipeline = pipeline.rstrip(")") + f"{{{opt_str}}})"
        log().debug(f"Using pipeline = {pipeline}")
        return pipeline
1001
+
1002
+ def get_shared_libs(self) -> list:
1003
+ shared_libs = []
1004
+ support_libs = self.envar.shared_libs
1005
+ if support_libs is not None:
1006
+ _libs = support_libs.split(":")
1007
+ for lib in _libs:
1008
+ if not os.path.exists(lib):
1009
+ raise FileNotFoundError(
1010
+ errno.ENOENT, os.strerror(errno.ENOENT), lib
1011
+ )
1012
+ shared_libs.append(lib)
1013
+ else:
1014
+ self.print_warning(f"{self.name}_LIBS environment variable is not set")
1015
+
1016
+ return shared_libs
1017
+
1018
    @lru_cache(maxsize=1)
    def get_version(self):
        """
        Return the (cached) base hash object encoding the DSL version.

        Callers must `.copy()` the returned hashlib object before updating it,
        as `get_module_hash` does, since the cached instance is shared.

        NOTE(review): `lru_cache` on an instance method keys on `self` and
        keeps the instance alive in the cache (ruff B019) — acceptable for a
        singleton DSL, but confirm if instances are ever short-lived.
        """
        version_hash = hashlib.sha256()

        return version_hash
1023
+
1024
    def get_module_hash(self, module, function_name):
        """
        Compute the cache key for a module: SHA-256 over the module bytecode,
        all non-None environment settings, and the compile options.
        """
        s = io.BytesIO()
        module.operation.write_bytecode(s)
        # Environment changes must invalidate the cache, so fold them in.
        for attr, value in self.envar.__dict__.items():
            if value is not None:
                s.write(str(value).encode())
        # Add compile options to the hash
        s.write(self.compile_options.to_str().encode())
        # Copy the shared version hash before updating (see get_version).
        module_hash = self.get_version().copy()
        module_hash.update(s.getvalue())
        module_hash = module_hash.hexdigest()

        log().debug("Bytecode=[%s]", s.getvalue().hex())
        log().debug("Version=[%s]", self.get_version().hexdigest())
        log().info(
            "Function=[%s] Computed module_hash=[%s]", function_name, module_hash
        )
        return module_hash
1042
+
1043
    def build_module(self, module, function_name: str):
        """
        Build the MLIR module, verify and return the module.

        Depending on environment flags, also saves the IR to a file and/or
        prints it to stdout before verification.

        Raises:
            DSLRuntimeError: if IR verification fails (internal error).
        """

        # Save IR in a file
        if self.envar.keepIR:
            save_ir(self.name, module, function_name)

        if self.envar.printIR:
            print("\n//===--- ------ Generated IR ------ ---====\n")
            module.operation.print(
                enable_debug_info=self.envar.generate_source_location
            )
            print("\n//===--- --- End of Generated IR -- ---====\n")

        # Verify the module
        try:
            module.operation.verify()
        except Exception as e:
            raise DSLRuntimeError(f"🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊", cause=e)

        return module
1066
+
1067
    def generate_original_ir(
        self,
        ir,
        func,
        funcBody,
        kwargs,
        function_name,
        func_types,
        gpu_module_attrs,
        args,
        args_spec,
    ):
        """
        Build the original (pre-compilation) IR module for `funcBody`.

        Creates a module containing the GPU container module and a host
        func.func, executes the user's function body to trace IR into it, and
        returns (module, module_hash, result).
        """
        # This location is set to None for now; otherwise, calls to the same
        # function on different lines would produce different line numbers,
        # which would break the cache.
        loc = None  # self.get_location()

        def build_ir_module():
            module = ir.Module.create(loc=loc)
            unit_attr = ir.UnitAttr.get()
            module.operation.attributes["gpu.container_module"] = unit_attr

            with ir.InsertionPoint(module.body):
                # Always generate gpu module. It's canonicalized by the compiler when it's not used.
                self._build_gpu_module(gpu_module_attrs)

                fop = func.FuncOp(function_name, (func_types, []), loc=loc)
                fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
                log().debug("Generated Function OP [%s]", fop)
                with ir.InsertionPoint(fop.add_entry_block()):
                    ir_args, ir_kwargs = self.generate_execution_arguments(
                        args, kwargs, fop, args_spec
                    )
                    # Call user function body
                    try:
                        result = funcBody(*ir_args, **ir_kwargs)
                        func.ReturnOp([])
                    except NameError as name_error:
                        raise DSLRuntimeError(
                            f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥",
                            cause=name_error,
                            suggestion="Using variables defined in dynamic control flow is not supported. Please give an initial value before control flow.",
                        )
                    except DSLRuntimeError as dsl_error:
                        # Throw it's already a DSL error
                        raise dsl_error
            return module, result

        # Build IR module
        profiler = timer(enable=self.envar.jitTimeProfiling)
        module, result = profiler(build_ir_module)()
        # Hash is taken over the ORIGINAL IR so the cache key is stable
        # regardless of later transformations.
        module_hash = self.get_module_hash(module, function_name)

        module = self.build_module(module, function_name)

        return module, module_hash, result
1123
+
1124
    def compile_and_cache(
        self, module, module_hash, function_name, pipeline, args_spec, no_cache
    ):
        """
        Compile `module` (or re-JIT the cached compiled module) and wrap the
        result in a JitExecutor, storing it in the JIT cache unless caching is
        disabled.
        """
        arch = self.envar.arch
        pipeline = self.preprocess_pipeline(self._get_pipeline(pipeline), arch)
        shared_libs = self.get_shared_libs()
        profiler = timer(enable=self.envar.jitTimeProfiling)
        if (
            no_cache
            or module_hash not in self.jit_cache
            or self.jit_cache[module_hash].ir_module is None
        ):
            log().info(
                "JIT cache miss function=[%s] module_hash=[%s]",
                function_name,
                module_hash,
            )
            # Compile and JIT MLIR module
            engine = profiler(self.compile_and_jit)(
                module, pipeline, shared_libs, function_name=function_name
            )
        else:
            log().info(
                "JIT cache hit IN-FILE function=[%s] module_hash=[%s]",
                function_name,
                module_hash,
            )
            # The cached module is already compiled: only re-create the engine.
            module = self.jit_cache[module_hash].ir_module
            engine = self.compiler_provider.jit(module, shared_libs=shared_libs)
        capi_func = profiler(engine.lookup)(function_name)
        jit_executor = JitExecutor(
            self,
            engine,
            capi_func,
            module,
            args_spec,
            function_name,
            jit_time_profiling=self.envar.jitTimeProfiling,
        )
        jit_executor = jit_executor.update_jit_cuda_modules(self.kernel_symbols)

        if not no_cache:
            # module stored in cache is compiled.
            self.jit_cache[module_hash] = jit_executor

        return jit_executor
1170
+
1171
def post_compilation_cleanup(self):
    """Reset per-compilation bookkeeping once a compilation has finished."""
    # The next compilation starts from a fresh set of options.
    self.compile_options = CompileOptions()
    # Drop the device-kernel symbols accumulated during this run.
    self.kernel_symbols = []
    # Counters restart from zero for the next compilation.
    self.num_kernels = 0
    self.launch_inner_count = 0
1180
+
1181
def generate_mlir(
    self,
    funcBody,
    kwargs,
    function_name,
    gpu_module_attrs,
    args,
    args_spec,
    pipeline,
    no_cache,
    compile_only,
    loc=None,
):
    """Generate the MLIR module for `funcBody`, then compile and execute it.

    Pipeline:
      1. Translate the Python arguments into MLIR function types.
      2. Emit the original IR module and hash it for cache lookup.
      3. On cache miss (or `no_cache`), compile + JIT via compile_and_cache.
      4. Run the compiled program unless `compile_only` is set.

    Returns the Python-level result of tracing `funcBody`, or the
    JitExecutor when `compile_only` is True, or the trace result alone
    when the dryrun environment option is on.
    """
    with ir.Context(), ir.Location.unknown():
        # Convert input arguments to MLIR arguments
        exe_args, func_types, adapted_args = self.generate_mlir_function_types(
            funcBody, function_name, args, kwargs, args_spec
        )

        # Generate original ir module and its hash value.
        module, module_hash, result = self.generate_original_ir(
            ir,
            func,
            funcBody,
            kwargs,
            function_name,
            func_types,
            gpu_module_attrs,
            args,
            args_spec,
        )

        # dryrun is used to only generate IR
        if self.envar.dryrun:
            return result

        if (
            no_cache
            or module_hash not in self.jit_cache
            or self.jit_cache[module_hash].capi_func is None
        ):
            # no cache or cache miss, do ir generation/compilation/jit engine
            jit_executor = self.compile_and_cache(
                module, module_hash, function_name, pipeline, args_spec, no_cache
            )
        else:
            # cache hit
            log().info(
                "JIT cache hit IN-MEMORY function=[%s] module_hash=[%s]",
                function_name,
                module_hash,
            )
            jit_executor = self.jit_cache[module_hash]

        # Reset per-compilation state regardless of cache hit/miss.
        self.post_compilation_cleanup()
        # If compile_only is set, bypass execution return the jit_executor directly
        if compile_only:
            return jit_executor
        # Run the compiled program
        jit_executor.run_compiled_program(exe_args)

        return result
1244
+
1245
def run_preprocessor(self, funcBody):
    """Run the AST preprocessor over `funcBody` once.

    Returns the transformed AST, or None when the function has already been
    preprocessed (marked via the `_preprocessed` attribute).
    """
    if hasattr(funcBody, "_preprocessed"):
        # Already transformed by an earlier call: nothing to do.
        return None

    function_name = funcBody.__name__
    self.funcBody = funcBody
    log().info("Started preprocessing [%s]", function_name)

    transformed = self.preprocessor.transform(funcBody, self._get_globals())
    if self.envar.print_after_preprocessor:
        log().info(
            f"# Printing unparsed AST after preprocess of func=`{function_name}` id=`{id(funcBody)}`"
        )
        DSLPreprocessor.print_ast(transformed)

    # Mark the function so repeated calls skip the transform.
    funcBody._preprocessed = True
    return transformed
1260
+
1261
def get_function_ptr(self, original_function):
    """Compile the function's preprocessed AST and return the executable function."""
    # Compile against the original source file so tracebacks still point at
    # the user's code.
    src_file = inspect.getsourcefile(original_function)
    code_obj = compile(
        original_function._transformed_ast, filename=src_file, mode="exec"
    )
    exec_globals = self._get_globals()
    return self.preprocessor.exec(
        original_function.__name__, original_function, code_obj, exec_globals
    )
1272
+
1273
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
    """
    Binds provided arguments to a function's signature and applies default values.

    E.g. given a function signature `def foo(a, b=2, c=3)`, and at call-site if we do
    `foo(a=1, c=4)`, the returned BoundArguments object will have args = `(1, 2, 4)`
    and kwargs = `{}` — once every positional-or-keyword parameter has a value,
    `BoundArguments.args` reports them all positionally.

    An exception will be raised if binding fails.
    """
    try:
        # bind_partial tolerates missing parameters; completeness is enforced
        # separately (see _check_arg_count).
        bound_args = sig.bind_partial(*args, **kwargs)
        bound_args.apply_defaults()
    except Exception as e:
        raise DSLRuntimeError(
            f"Failed to bind arguments to function `{func_name}` with signature `{sig}`",
            cause=e,
        )
    return bound_args
1292
+
1293
def _canonicalize_args(self, sig, *args, **kwargs):
    """
    Normalize call-site arguments: the returned args tuple holds only
    positional values and the returned kwargs dict only keyword values.
    """
    bound = self._get_function_bound_args(
        sig, self.funcBody.__name__, *args, **kwargs
    )
    return bound.args, bound.kwargs
1303
+
1304
def _check_arg_count(self, *args, **kwargs):
    """
    Validate the call-site arguments against the stored function body.

    Raises DSLRuntimeError when no function body is set, when binding fails,
    or when a parameter without a default value received no argument.
    Returns the inspected signature on success.
    """
    if not self.funcBody:
        raise DSLRuntimeError("Function body is not set.")

    # inspect.signature needs the actual function object.
    sig = inspect.signature(self.funcBody)
    function_name = self.funcBody.__name__
    bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)

    # Find the first parameter that has no default and received no value.
    missing = next(
        (
            param.name
            for param in sig.parameters.values()
            if param.default is inspect.Parameter.empty
            and param.name not in bound_args.arguments
        ),
        None,
    )
    if missing is not None:
        raise DSLRuntimeError(
            f"Missing required argument in `{function_name}`: '{missing}'"
        )

    return sig
1326
+
1327
def _func(self, funcBody, *args, **kwargs):
    """Decorator for MLIR functions.
    It cuts the boilerplate code, does the following:
    1. Generates `func.func`
    2. Types translation (numpy arrays -> cute.memref, float -> <f32>, etc.)
    3. Compiles and JITs the MLIR module
    4. Invokes the generated function
    5. Operator overloading (a + b --> arith.addi a, b)
    6. Generates GPU kernel function with GPU module and kernel attributes baked

    Keyword arguments `pipeline`, `gpu_module_attrs`, `no_cache` and
    `compile_only` are consumed here and are NOT forwarded to `funcBody`.
    """
    if ir.Context.current is None:
        pass
    elif ir.InsertionPoint.current is not None:
        # Already inside IR generation: inline the call rather than
        # emitting a new top-level function.
        return funcBody(*args, **kwargs)

    function_name = funcBody.__name__
    self.funcBody = funcBody

    pipeline = kwargs.pop("pipeline", None)
    gpu_module_attrs = kwargs.pop("gpu_module_attrs", {})

    # Disable cache
    no_cache = kwargs.pop("no_cache", False)

    # Always compile(disable cache) and return the result jit_executor
    compile_only = kwargs.pop("compile_only", False)

    if not no_cache and compile_only:
        no_cache = True
        self.print_warning("Cache is disabled as user wants to compile only.")

    # Check the number of arguments
    sig = self._check_arg_count(*args, **kwargs)

    args_spec = inspect.getfullargspec(funcBody)

    # Canonicalize the input arguments
    canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
        sig, *args, **kwargs
    )

    # Simple name mangling
    function_name = self.mangle_name(function_name, canonicalized_args, args_spec)

    # Generate MLIR Context and start generating IR
    log().debug(f"Generating MLIR for function '{function_name}'")
    result = self.generate_mlir(
        funcBody,
        canonicalized_kwargs,
        function_name,
        gpu_module_attrs,
        canonicalized_args,
        args_spec,
        pipeline,
        no_cache,
        compile_only,
    )

    return result
1386
+
1387
class _KernelGenHelper(ABC):
    """Abstract interface for backend-specific kernel IR generation.

    Concrete subclasses supply the ops for declaring a kernel function,
    returning from it, and launching it from the host side.
    """

    def __init__(self):
        # Kernel function op; expected to be populated by generate_func_op.
        self.func_op = None
        # Kernel function type; expected to be populated by subclasses.
        self.func_type = None

    @abstractmethod
    def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
        """Create the kernel function op; this base just validates inputs."""
        assert arg_types is not None, "Invalid arg_types!"
        assert kernel_name is not None, "kernel name is empty"
        pass

    @abstractmethod
    def generate_func_ret_op(self):
        """Emit the kernel function's return/terminator op."""
        pass

    @abstractmethod
    def generate_launch_op(self, *args, **kwargs):
        """Emit the host-side op that launches the kernel."""
        pass

    @abstractmethod
    def get_func_body_start(self):
        """Return the insertion point at the start of the kernel body."""
        pass

    @abstractmethod
    def enter_gpu_module(module):
        # NOTE(review): declared without `self`/@staticmethod — confirm
        # subclass overrides intentionally use this shape.
        """Compute the insertion point into the given module."""
        pass
1414
+
1415
@lru_cache(maxsize=1)
def _get_default_stream(self):
    """Returns the default stream 0"""
    from .runtime import cuda as cuda_helpers

    # NOTE(review): lru_cache on an instance method keys on `self` and keeps
    # the instance alive for the cache's lifetime; with maxsize=1 a second
    # instance would evict the first's cached stream — confirm intended.
    return cuda_helpers.stream_create()
1421
+
1422
def _execute_cuda(
    self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None
):
    """
    Executes a specified CUDA kernel from a cubin file, handling module loading,
    kernel retrieval, stream creation, kernel launch, and synchronization.

    When `stream` is None, the lazily created default stream is used and the
    call blocks until the kernel finishes; with an explicit stream the launch
    is left asynchronous and the caller owns synchronization.
    Kernel arguments are taken from `self.exe_args`.
    """
    from .runtime import cuda as cuda_helpers

    # Step 1. Load CUDA Module
    module = cuda_helpers.load_cubin_module(fname_cubin)
    # Step 2. Find CUDA function
    kernel_ptr = cuda_helpers.get_kernel_function(module, kernel_name)

    sync_execution_default = False
    if stream is None:
        # Step 3. No caller-provided stream: use the default one and
        # remember to synchronize after the launch.
        stream = self._get_default_stream()
        sync_execution_default = True

    # Step 4. Launch the kernel
    cuda_helpers.launch_kernel(
        kernel_ptr,
        grid_size,
        block_size,
        stream,
        smem_size=smem_size,
        kernel_args=self.exe_args,
    )

    if sync_execution_default:
        # Step 5. Optional Sync cuda stream
        cuda_helpers.stream_sync(stream)
1454
+
1455
def _execute_by_cuda_driver(
    self,
    kernel_generator,
    generate_cubin,
    grid_size,
    block_size,
    smem_size,
    stream=None,
):
    """
    This function builds IR and execute the module using cuda driver.
    It doesn't use mlir's cuda runtime

    `kernel_generator` emits the kernel IR and returns (ret, kernel_name);
    `generate_cubin` lowers the finished module to a cubin file path.
    Returns the kernel generator's `ret` value (also when dryrun stops
    before execution).
    """
    ret = None

    # Step 1. Build IR
    with ir.Context(), ir.Location.unknown():
        loc = self.get_location()
        module = ir.Module.create(loc=loc)
        unit_attr = ir.UnitAttr.get()
        # Mark the module as a GPU container so it may hold gpu.module ops.
        module.operation.attributes["gpu.container_module"] = unit_attr
        with ir.InsertionPoint(module.body):
            self._build_gpu_module()
            ret, kernel_name = kernel_generator()
            log().debug(
                f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}"
            )

        module = self.build_module(module, kernel_name)

        # dryrun is used to only generate IR
        if self.envar.dryrun:
            return ret

        # Generate cubin
        fname_cubin = generate_cubin(module, kernel_name)

        # Execute a cuda kernel from cubin
        self._execute_cuda(
            fname_cubin, kernel_name, grid_size, block_size, smem_size, stream
        )

    return ret
1498
+
1499
def generate_kernel_operands_and_types(
    self, kernel_func, kernel_name, args_spec, args, kwargs
):
    """
    Generate the operands and types for the kernel function

    Returns a triple (kernel_operands, kernel_arg_types, kernel_arg_attrs),
    all the same length. In device-only compilation mode the host-side
    operands are not materialized, so three empty lists are returned.
    """

    kernel_operands, kernel_arg_types, kernel_arg_attrs = [], [], []

    log().debug(
        "Processing GPU kernel call in [%s] mode",
        (
            f"Only {self.device_jit_decorator_name}"
            if self.device_compilation_only
            else f"{self.host_jit_decorator_name} + {self.device_jit_decorator_name}"
        ),
    )

    if self.device_compilation_only:
        # No host function exists to supply operands in device-only mode.
        return kernel_operands, kernel_arg_types, kernel_arg_attrs

    kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = (
        self._generate_jit_func_args(
            kernel_func, kernel_name, args, kwargs, args_spec, is_host=False
        )
    )

    log().debug("Final kernel_operands: %s", ", ".join(map(str, kernel_operands)))
    log().debug("Final kernel_arg_types: %s", ", ".join(map(str, kernel_arg_types)))
    log().debug("Final kernel_arg_attrs: %s", ", ".join(map(str, kernel_arg_attrs)))

    assert (
        len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs)
    ), "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal"

    return kernel_operands, kernel_arg_types, kernel_arg_attrs
1535
+
1536
def kernel_launcher(self, *dargs, **dkwargs):
    # Decorator factory: usable both as `@kernel_launcher` (bare) and as
    # `@kernel_launcher(...)` with configuration keywords (see docstring).
    def decorator(funcBody):
        @wraps(funcBody)
        def kernel_wrapper(*args, **kwargs):
            """
            Base decorator for generating kernel function

            This decorator provides a template for kernel function generation
            including kernel function header/body and kernel launch op at call site

            Optional arguments (with default value in <>):
            - requiredArgs <[]>: specifies the mandatory arguments that must present in kernel function signature
                the args will be validated and collected as a namedtuple
            - optionalArgs <[]>: specifies the optional arguments that might present in kernel function signature
                the args will be collected (if present) as a namedtuple
            - unitAttrNames <[]>: specifies the name(s) of ir.UnitAttr to be set for kernel function op
            - valueAttrDict <{}>: specifies the name(s) and value(s) of ir.Attribute to be set for kernel function op
            - kernelGenHelper <None>: specifies the mandatory customized kernel generation helper class (derived from _KernelGenHelper)

            Return value:
                A namedtuple "KernelReturns" is returned with following fields:
                - kernel_func_ret: the return of the kernel function
                - launch_op_ret: the return of the launch op
            """

            requiredArgs = dkwargs.get("requiredArgs", [])
            optionalArgs = dkwargs.get("optionalArgs", [])
            unitAttrNames = dkwargs.get("unitAttrNames", [])
            valueAttrDict = dkwargs.get("valueAttrDict", {})
            kernelGenHelper = dkwargs.get("kernelGenHelper", None)

            kernel_name = funcBody.__name__
            args_spec = inspect.getfullargspec(funcBody)
            self.funcBody = funcBody

            # Give each kernel a unique name. (The same kernel may be
            # called multiple times, resulting in multiple kernel traces.)
            # The mangled name of Python function is part of the name to
            # improve readability.
            kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}"
            self.num_kernels += 1

            # Step 0. Preprocess the arguments
            def extract_args(argNames, assertIfNone=False) -> list:
                # Pop launcher-level arguments out of `kwargs` so they are
                # not forwarded to the user's kernel body.
                extracted = []
                for name in argNames:
                    value = kwargs.pop(name, None)
                    if assertIfNone and value is None:
                        raise DSLRuntimeError(
                            f"{name} is required for {kernel_name}"
                        )
                    extracted.append(value)

                return extracted

            RequiredArgs = namedtuple("RequiredArgs", requiredArgs)
            req_args = (
                RequiredArgs._make(extract_args(requiredArgs, assertIfNone=True))
                if requiredArgs
                else None
            )
            OptionalArgs = namedtuple("OptionalArgs", optionalArgs)
            opt_args = (
                OptionalArgs._make(extract_args(optionalArgs))
                if optionalArgs
                else None
            )
            assert (
                kernelGenHelper is not None
            ), "kernelGenHelper should be explicitly specified!"

            # check arguments
            sig = self._check_arg_count(*args, **kwargs)

            # Canonicalize the input arguments
            canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
                sig, *args, **kwargs
            )

            # Step 1. Collect host-side operands/types for the launch op.
            kernel_operands, kernel_types, kernel_arg_attrs = (
                self.generate_kernel_operands_and_types(
                    funcBody,
                    kernel_name,
                    args_spec,
                    canonicalized_args,
                    canonicalized_kwargs,
                )
            )

            # Step 2. Emit the kernel function inside the GPU module.
            with self._enter_gpu_module():
                log().debug("Generating device kernel")
                if self.device_compilation_only:
                    log().debug("Generating cuda-python arguments")
                    # Convert input arguments to MLIR arguments
                    self.exe_args, kernel_types, _ = (
                        self.generate_mlir_function_types(
                            funcBody,
                            kernel_name,
                            canonicalized_args,
                            canonicalized_kwargs,
                            args_spec,
                        )
                    )

                helper = kernelGenHelper()
                loc = self.get_location()
                fop = helper.generate_func_op(
                    kernel_types, kernel_arg_attrs, kernel_name, loc
                )
                log().debug(f"Kernel function op: {fop}")
                for attr in unitAttrNames:
                    fop.attributes[attr] = ir.UnitAttr.get()
                for key, val in valueAttrDict.items():
                    fop.attributes[key] = val

                fop.sym_visibility = ir.StringAttr.get("public")
                with ir.InsertionPoint(helper.get_func_body_start()):
                    ir_args, ir_kwargs = self.generate_execution_arguments(
                        canonicalized_args, canonicalized_kwargs, fop, args_spec
                    )
                    log().debug(
                        f"IR arguments - args: {ir_args} ; kwargs: {ir_kwargs}"
                    )
                    # Call user function body
                    kernel_ret = funcBody(*ir_args, **ir_kwargs)
                    helper.generate_func_ret_op()

            # Step 3. Generate call site `launch_func`
            kernel_sym = ir.SymbolRefAttr.get(["kernels", kernel_name])
            launch_ret = helper.generate_launch_op(
                kernelSym=kernel_sym,
                kernelOperands=kernel_operands,
                requiredArgs=req_args,
                optionalArgs=opt_args,
            )

            KernelReturns = namedtuple(
                "KernelReturns", ["kernel_func_ret", "launch_op_ret"]
            )
            result = KernelReturns(
                kernel_func_ret=kernel_ret, launch_op_ret=launch_ret
            )
            log().debug(f"Kernel result: {result}, kernel name: {kernel_name}")
            return result, kernel_name

        return kernel_wrapper

    if len(dargs) == 1 and callable(dargs[0]):
        # Bare usage: the decorated function is the single positional arg.
        return decorator(dargs[0])
    else:
        return decorator
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides utilities for the environment variables setup.
14
+
15
+ It provides an EnvironmentVarManager, which reads environment variables for the DSL
16
+ and caches them for efficient access.
17
+
18
+ It also provides utilities to automatically setup a subset of environment variables
19
+ based on heuristics.
20
+ """
21
+
22
+ import os
23
+ import sys
24
+ import shutil
25
+ import glob
26
+ from pathlib import Path
27
+ from functools import lru_cache
28
+ from typing import Any
29
+
30
+ from ..base_dsl.runtime.cuda import get_compute_capability_major_minor
31
+ from .utils.logger import log
32
+
33
+ IS_WINDOWS = sys.platform == "win32"
34
+ CLIB_EXT = ".dll" if IS_WINDOWS else ".so"
35
+
36
+ # =============================================================================
37
+ # Environment Variable Helpers
38
+ # =============================================================================
39
+
40
+
41
@lru_cache(maxsize=None)
def get_str_env_var(var_name, default_value=None):
    """Read an environment variable as a string, or `default_value` when unset.

    Results are memoized, so later changes to the environment are not observed.
    """
    return os.environ.get(var_name, default_value)
45
+
46
+
47
@lru_cache(maxsize=None)
def get_bool_env_var(var_name, default_value=False):
    """Interpret an environment variable as a boolean (memoized).

    Unset -> `default_value`; the literal strings "False", "0" and "" are
    False; every other value is True.
    """
    raw = get_str_env_var(var_name)
    if raw is None:
        return default_value
    falsy_values = {"False", "0", ""}
    return raw not in falsy_values
53
+
54
+
55
@lru_cache(maxsize=None)
def get_int_env_var(var_name, default_value=0):
    """Interpret an environment variable as a non-negative integer (memoized).

    Falls back to `default_value` when the variable is unset, empty, or not a
    plain digit string (note: signed values like "-5" therefore also fall back).
    """
    raw = get_str_env_var(var_name)
    if raw and raw.isdigit():
        return int(raw)
    return default_value
59
+
60
+
61
@lru_cache(maxsize=None)
def has_env_var(var_name):
    """Return True when `var_name` is present in the environment (memoized)."""
    return var_name in os.environ
64
+
65
+
66
def detect_gpu_arch(prefix):
    """
    Attempts to detect the machine's GPU architecture.

    The `prefix` parameter is currently unused here (kept for call-site
    symmetry with the other `{prefix}_*` helpers).

    Returns:
        A string representing the GPU architecture (e.g. "70" for compute capability 7.0),
        or a default value(e.g. "sm_100") if the GPU architecture cannot be determined.
    """
    arch = (None, None)
    try:
        arch = get_compute_capability_major_minor()
    except Exception as e:
        # Detection is best-effort; fall through to the default below.
        log().info(f"Failed to get CUDA compute capability: {e}")

    if arch == (None, None):
        # default to sm_100
        arch = (10, 0)

    major, minor = arch
    suffix = ""
    if major >= 9:
        # Major 9+ targets get the "a" variant (e.g. sm_90a, sm_100a).
        suffix = "a"

    return f"sm_{major}{minor}{suffix}"
90
+
91
+
92
def find_libs_in_ancestors(start, target_libs, lib_folder_guesses):
    """
    Search ancestor directories for a candidate library folder containing all required libraries.

    Starting from the given path, walk up through each parent directory. For
    every ancestor, check each candidate subdirectory (lib_folder_guesses) for
    files with the required library extension (CLIB_EXT). File names are
    canonicalized by stripping the "lib" prefix from the stem. The first
    candidate directory containing every library in `target_libs` wins.

    Parameters:
        start (str or Path): The starting directory from which to begin the search.
        target_libs (iterable of str): Required library names (without the "lib" prefix).
        lib_folder_guesses (iterable of str): Relative paths from an ancestor that may hold the libraries.

    Returns:
        list[str] or None: Resolved paths to the library files if found; otherwise None.
    """
    wanted = set(target_libs)
    for ancestor in Path(start).resolve().parents:
        for rel_path in lib_folder_guesses:
            candidate_dir = ancestor / rel_path
            if not candidate_dir.is_dir():
                continue

            # Map canonical library name -> resolved path, in directory order.
            found = {}
            for entry in candidate_dir.iterdir():
                if entry.suffix != CLIB_EXT:
                    continue
                canonical = entry.stem.removeprefix("lib")
                if canonical in wanted and canonical not in found:
                    found[canonical] = str(entry.resolve())

            # Success only when every required library is in this one folder.
            if len(found) == len(wanted):
                return list(found.values())

    return None
142
+
143
+
144
def _find_cuda_home():
    """Find the CUDA installation path using a series of heuristic methods.
    Methods below are checked in order, and the function returns on first match:
    1. Checking the environment variables CUDA_HOME and CUDA_PATH.
    2. Searching for the 'nvcc' compiler in the system PATH and deriving the path of cuda.
    3. Scanning common installation directories based on the operating system.
        - On Windows systems (when IS_WINDOWS is True), it searches in:
            C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*
        - On Unix-like systems, it searches in:
            /usr/local/cuda*

    Returns:
        Optional[str]: The absolute CUDA installation path if found; otherwise, None.

    Note:
        The variable IS_WINDOWS is defined in the module scope.
        NOTE(review): only the guess-#3 result is existence-checked (the ""
        sentinel for "no glob match" collapses to None there); confirm that
        nvcc-derived paths are intentionally trusted as-is.
    """
    # Guess #1
    cuda_home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH")
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            # nvcc lives in <cuda_home>/bin, so strip two path components.
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3
            if IS_WINDOWS:
                glob_pat = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
            else:
                glob_pat = "/usr/local/cuda*"
            cuda_homes = glob.glob(glob_pat)
            if len(cuda_homes) == 0:
                cuda_home = ""
            else:
                cuda_home = cuda_homes[0]
            # "" (no match) never exists, so it becomes None here.
            if not os.path.exists(cuda_home):
                cuda_home = None
    return cuda_home
182
+
183
+
184
def get_cuda_toolkit_path():
    """
    Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if
    set. Otherwise, attempts to discover a valid CUDA toolkit location and
    return. If not found, return None.
    """
    # Check if the environment variable is already set, if so, return it immediately.
    try:
        cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH")
        if cuda_toolkit_path_existing:
            return cuda_toolkit_path_existing

        found_cuda_home = _find_cuda_home()
        if found_cuda_home:
            return found_cuda_home
    except Exception as e:
        # Discovery is best-effort: log and fall through to None.
        # Fix: the original passed `e` as a printf-style argument to a message
        # with no placeholder, which made the logging call itself fail to
        # format; use a lazy %s placeholder instead.
        log().info("default_env: exception on get_cuda_toolkit_path: %s", e)
    return None
202
+
203
+
204
def get_prefix_dsl_libs(prefix: str):
    """
    Returns get_str_env_var('{prefix}_LIBS') if set.
    Otherwise, attempts to discover libs based on heuristics and return
    If not found, return None.

    Discovery looks for the three MLIR runner shared libraries in a "lib"
    folder among the ancestors of this file (install layout) and, failing
    that, of the package parent directory (build layout). Found paths are
    joined with ':'.
    """
    # Check if the environment variable is already set, if so, return it immediately.
    try:
        prefix_libs_existing = get_str_env_var(f"{prefix}_LIBS")
        if prefix_libs_existing:
            return prefix_libs_existing

        def get_libs_cand(start):
            # All three MLIR runtime libraries must be found in one folder.
            target_libs = {
                "mlir_c_runner_utils",
                "mlir_runner_utils",
                "mlir_cuda_runtime",
            }
            lib_folder_guesses = [
                "lib",
            ]

            libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
            if libs_cand:
                return ":".join(libs_cand)
            return None

        # find from install folder
        dsl_libs = get_libs_cand(__file__)

        if not dsl_libs:
            # try to find from build folder structure
            dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve())

        return dsl_libs

    except Exception as e:
        # Fix: the original passed `e` as a printf-style extra argument to an
        # f-string message with no placeholder, which made the logging call
        # itself fail to format; use a lazy %s placeholder instead.
        log().info("default_env: exception on get_prefix_dsl_libs: %s", e)
        return None
245
+
246
+
247
class EnvironmentVarManager:
    """Manages environment variables for configuration options.

    Printing options:
        - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False)
        - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False)
        - [DSL_NAME]_PRINT_IR: Print generated IR (default: False)
        - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True)
    File options:
        - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False)
        - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False)
    Other options:
        - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1).
        - [DSL_NAME]_DRYRUN: Generates IR only (default: False)
        - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
        - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
        - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
        - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
        - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
        - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
        - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
        - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None)
        - [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False)
    """

    def __init__(self, prefix="DSL"):
        # All variables are namespaced by this prefix, e.g. "CUTE_PRINT_IR".
        self.prefix = prefix  # change if needed

        # Printing options
        self.print_after_preprocessor = get_bool_env_var(
            f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
        )
        self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False)
        self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
        # File options
        self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
        # Logging options
        self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
        self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
        # Warn when a log level is set but no logging sink is enabled,
        # since the level would silently have no effect.
        if (
            has_env_var(f"{prefix}_LOG_LEVEL")
            and not self.log_to_console
            and not self.log_to_file
        ):
            log().warning(
                f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!"
            )
        self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)

        # Other options
        self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
        # Explicit arch override wins; otherwise probe the local GPU.
        self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
        self.warnings_as_errors = get_bool_env_var(
            f"{prefix}_WARNINGS_AS_ERRORS", False
        )
        self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
        self.enable_optimization_warnings = get_bool_env_var(
            f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False
        )
        self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
        self.disable_file_caching = get_bool_env_var(
            f"{prefix}_DISABLE_FILE_CACHING", False
        )
        self.file_caching_capacity = get_int_env_var(
            f"{prefix}_FILE_CACHING_CAPACITY", 1000
        )
        # Note the inversion: the env var disables, the attribute enables.
        self.generate_source_location = not get_bool_env_var(
            f"{prefix}_NO_SOURCE_LOCATION", False
        )
        # set cuda
        self.cuda_toolkit = get_cuda_toolkit_path()

        # set mlir shared libraries
        self.shared_libs = get_prefix_dsl_libs(prefix)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides jit executor related classes
14
+ """
15
+ import ctypes
16
+ import inspect
17
+ import io
18
+ from typing import get_origin
19
+
20
+ import numpy as np
21
+
22
+ # MLIR modules imports
23
+ from .._mlir import ir
24
+
25
+ # Local modules imports
26
+ from . import typing as t
27
+ from .common import DSLRuntimeError
28
+ from .runtime import cuda as cuda_helpers
29
+ from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
30
+ from .typing import get_c_pointers
31
+ from .utils.logger import log
32
+ from .utils.timer import timer
33
+
34
+
35
class CudaSingleModule:
    """A loaded CUDA module paired with the kernel function pointer resolved from it."""

    def __init__(self, cuda_module, kernel_ptr):
        # Kept together so the module can be unloaded once the kernel
        # pointer is no longer needed.
        self.cuda_module, self.kernel_ptr = cuda_module, kernel_ptr
+
41
class CudaModules:
    """The CUDA modules used by one compiled program plus the extra
    kernel-pointer arguments appended at launch time."""

    def __init__(self, modules, args):
        # `modules`: iterable of CudaSingleModule.
        # `args`: extra kernel pointer arguments forwarded at invocation.
        self.modules, self.args = modules, args
+
48
+
49
+ class JitExecutor:
50
+ def __init__(
51
+ self,
52
+ dsl,
53
+ engine,
54
+ capi_func,
55
+ ir_module,
56
+ args_spec,
57
+ function_name,
58
+ cuda_modules: CudaModules = None,
59
+ jit_time_profiling=False,
60
+ ):
61
+ self.dsl = dsl
62
+ self.engine = engine
63
+ self.capi_func = capi_func
64
+ self.ir_module = ir_module
65
+ self.args_spec = args_spec
66
+ self.function_name = function_name
67
+ if args_spec is not None:
68
+ self.original_args_spec = args_spec
69
+ self.args_spec = self.filter_runtime_arg_spec(args_spec)
70
+ # cuda kernels
71
+ self.cuda_modules = cuda_modules
72
+ self.jit_time_profiling = jit_time_profiling
73
+
74
+ def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
75
+ runtime_args = []
76
+ runtime_annotations = {}
77
+ runtime_defaults = []
78
+
79
+ # Calculate the offset where defaults start in the original args
80
+ if arg_spec.defaults:
81
+ defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
82
+ else:
83
+ defaults_start_idx = len(arg_spec.args)
84
+
85
+ # Filter arguments and maintain their properties
86
+ for i, arg_name in enumerate(arg_spec.args):
87
+ arg_type = arg_spec.annotations.get(arg_name, None)
88
+
89
+ # Skip compile-time arguments
90
+ if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
91
+ continue
92
+
93
+ # Keep runtime arguments
94
+ runtime_args.append(arg_name)
95
+ if arg_name in arg_spec.annotations:
96
+ runtime_annotations[arg_name] = arg_type
97
+
98
+ # Keep corresponding default if it exists
99
+ if i >= defaults_start_idx:
100
+ default_idx = i - defaults_start_idx
101
+ runtime_defaults.append(arg_spec.defaults[default_idx])
102
+
103
+ # Filter kwonlyargs and their defaults
104
+ runtime_kwonlyargs = []
105
+ runtime_kwonlydefaults = {}
106
+
107
+ if arg_spec.kwonlyargs:
108
+ for kwarg in arg_spec.kwonlyargs:
109
+ arg_type = arg_spec.annotations.get(kwarg, None)
110
+
111
+ # Apply same filtering logic
112
+ if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
113
+ continue
114
+
115
+ runtime_kwonlyargs.append(kwarg)
116
+ if kwarg in arg_spec.annotations:
117
+ runtime_annotations[kwarg] = arg_type
118
+ if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
119
+ runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg]
120
+
121
+ # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec)
122
+ runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None
123
+
124
+ return inspect.FullArgSpec(
125
+ args=runtime_args,
126
+ varargs=arg_spec.varargs, # Keep original varargs
127
+ varkw=arg_spec.varkw, # Keep original varkw
128
+ defaults=runtime_defaults,
129
+ kwonlyargs=runtime_kwonlyargs,
130
+ kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None,
131
+ annotations=runtime_annotations,
132
+ )
133
+
134
+ def __del__(self):
135
+ if self.cuda_modules:
136
+ cuda_modules = [module.cuda_module for module in self.cuda_modules.modules]
137
+ for module in set(cuda_modules):
138
+ cuda_helpers.unload_cubin_module(module)
139
+
140
+ def get_constexpr_args(self) -> list[dict[str, int | str]]:
141
+ """
142
+ This function returns the constexpr args that have been pruned from the original function signature.
143
+ The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
144
+
145
+ :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
146
+ :rtype: list[dict[str, int | str]]
147
+ """
148
+ if self.original_args_spec is None:
149
+ return list()
150
+ constexpr_args = list()
151
+ for i, arg_name in enumerate(self.original_args_spec.args):
152
+ if arg_name not in self.args_spec.args:
153
+ constexpr_args.append({"argument_index": i, "argument_name": arg_name})
154
+
155
+ if self.original_args_spec.kwonlyargs:
156
+ for kwarg in self.original_args_spec.kwonlyargs:
157
+ if kwarg not in self.args_spec.kwonlyargs:
158
+ constexpr_args.append(
159
+ {"argument_index": None, "argument_name": kwarg}
160
+ )
161
+ return constexpr_args
162
+
163
+ def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
164
+ """
165
+ This function is the prune version of `generate_mlir_function_types` which only generates execution args
166
+ to get rid of mlir context.
167
+ """
168
+
169
+ # Process positional arguments with defaults
170
+ rectified_args = list(args)
171
+ if args_spec.defaults and len(args) < len(args_spec.args):
172
+ rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :])
173
+ for k, v in kwargs.items():
174
+ if k in args_spec.args:
175
+ idx = args_spec.args.index(k)
176
+ if idx < len(rectified_args):
177
+ rectified_args[idx] = v
178
+ else:
179
+ rectified_args.append(v)
180
+
181
+ # Process keyword arguments
182
+ rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args}
183
+ if args_spec.kwonlydefaults and len(rectified_kwargs) < len(
184
+ args_spec.kwonlyargs
185
+ ):
186
+ rectified_kwargs.update(args_spec.kwonlydefaults)
187
+
188
+ # args/kwargs must match arg_specs
189
+ if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len(
190
+ args_spec.kwonlyargs
191
+ ):
192
+ raise DSLRuntimeError(
193
+ "input args/kwargs length does not match runtime function signature!",
194
+ context={
195
+ "input args length": len(rectified_args),
196
+ "input kwargs length": len(rectified_kwargs),
197
+ "function signature args length": len(args_spec.args),
198
+ "function signature kwonlyargs length": len(args_spec.kwonlyargs),
199
+ },
200
+ )
201
+
202
+ exe_args = []
203
+ adapted_args = []
204
+ input_args = rectified_args + list(rectified_kwargs.values())
205
+ input_arg_names = args_spec.args + args_spec.kwonlyargs
206
+ for arg, arg_name in zip(input_args, input_arg_names):
207
+ # short-cut for args already converted
208
+ if hasattr(arg, "__c_pointers__"):
209
+ exe_args.extend(arg.__c_pointers__())
210
+ continue
211
+
212
+ arg_type = args_spec.annotations.get(arg_name, None)
213
+
214
+ # Implicit cast to NumericMeta
215
+ if isinstance(arg_type, t.NumericMeta):
216
+ arg = t.cast(arg, arg_type)
217
+ else:
218
+ # If not any known type, try registered adapter to do the conversion
219
+ adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
220
+ if adapter:
221
+ arg = adapter(arg)
222
+ adapted_args.append(arg)
223
+
224
+ exe_args.extend(get_c_pointers(arg))
225
+
226
+ return exe_args, adapted_args
227
+
228
+ def __call__(self, *args, **kwargs):
229
+ exe_args, adapted_args = self.generate_execution_args(
230
+ args, kwargs, self.args_spec
231
+ )
232
+
233
+ self.run_compiled_program(exe_args)
234
+
235
+ # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`.
236
+ def get_invoke_packed_args(self, exe_args):
237
+ if self.cuda_modules:
238
+ exe_args += self.cuda_modules.args
239
+ packed_args = (ctypes.c_void_p * len(exe_args))()
240
+ for argNum in range(len(exe_args)):
241
+ packed_args[argNum] = exe_args[argNum]
242
+ return packed_args
243
+
244
+ def run_compiled_program(self, exe_args):
245
+ if self.jit_time_profiling:
246
+ profiler = timer(enable=True)
247
+ try:
248
+ packed_args = profiler(self.get_invoke_packed_args)(exe_args)
249
+ profiler(self.capi_func)(packed_args)
250
+ except Exception as e:
251
+ raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)
252
+ else:
253
+ try:
254
+ packed_args = self.get_invoke_packed_args(exe_args)
255
+ self.capi_func(packed_args)
256
+ except Exception as e:
257
+ raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)
258
+
259
+ def update_jit_cuda_modules(self, kernel_symbols):
260
+ # preload cuda module from compiled cubin in ir and store to jit_executor.kernels.
261
+ if len(kernel_symbols) > 0:
262
+ extra_args = []
263
+ module = self.ir_module
264
+ cuda_kernel_cache = dict()
265
+ cuda_driver_version = cuda_helpers.get_driver_version()
266
+ for sym in kernel_symbols:
267
+ if sym not in cuda_kernel_cache:
268
+ log().debug(f"Loading CUDA module for symbol: {sym}")
269
+
270
+ # load cuda module/get function pointer from module and cache
271
+ def walk_callback(sym, func_sym, cubin_data):
272
+ cubin_module = cuda_helpers.load_cubin_module_data(cubin_data)
273
+ kernel_ptr = cuda_helpers.get_kernel_function(
274
+ cubin_module, func_sym
275
+ )
276
+ # Enable non-portable cluster size for CUDA version 11.8 or higher.
277
+ if cuda_driver_version >= 11080:
278
+ cuda_helpers.set_kernel_attribute(
279
+ kernel_ptr,
280
+ cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
281
+ 1,
282
+ )
283
+ cuda_kernel_cache[sym] = CudaSingleModule(
284
+ cubin_module, kernel_ptr
285
+ )
286
+
287
+ self.walk_module_and_get_cubin_data(module, sym, walk_callback)
288
+ else:
289
+ log().debug(f"Symbol {sym} already in cache")
290
+ # check if kernel is empty.
291
+ if sym in cuda_kernel_cache:
292
+ extra_args.append(
293
+ ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr())
294
+ )
295
+ # store to the jit result if jit result is cached.
296
+ self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args)
297
+
298
+ return self
299
+
300
+ def _get_escaped_cubin_bytes(self, cubin_data):
301
+ """This function escapes cubin data from mlir raw bytecode to executable binary bytes"""
302
+
303
+ def ishex(inp):
304
+ return (
305
+ inp in range(0x30, 0x3A)
306
+ or inp in range(0x61, 0x67)
307
+ or inp in range(0x41, 0x47)
308
+ )
309
+
310
+ converted = bytearray()
311
+ idx = 0
312
+ while idx < len(cubin_data):
313
+ # escape the original bytes
314
+ if cubin_data[idx] == 0x5C:
315
+ # if data of idx is b'\\'
316
+ if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]):
317
+ converted += bytearray.fromhex(
318
+ cubin_data[idx + 1 : idx + 3].decode()
319
+ )
320
+ idx += 3
321
+ elif cubin_data[idx + 1] == 0x5C:
322
+ converted.append(cubin_data[idx])
323
+ idx += 2
324
+ else:
325
+ # no escape, directly write
326
+ converted.append(cubin_data[idx])
327
+ idx += 1
328
+ return bytes(converted)
329
+
330
+ def walk_module_and_get_cubin_data(self, module, sym, callback):
331
+ """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback."""
332
+
333
+ def walk_gpu_binary_op(op):
334
+ if op.name != "gpu.binary":
335
+ return ir.WalkResult.ADVANCE
336
+ s = io.BytesIO()
337
+ op.write_bytecode(s)
338
+ cubin_data = s.getvalue()
339
+ if sym.encode() not in cubin_data:
340
+ return ir.WalkResult.ADVANCE
341
+
342
+ if (
343
+ "kernels" != op.opview.sym_name.value
344
+ and sym != op.opview.sym_name.value
345
+ ):
346
+ return ir.WalkResult.ADVANCE
347
+ # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir
348
+ func_sym = sym
349
+ if sym == op.opview.sym_name.value and not sym.endswith("_kernel"):
350
+ func_sym = sym.rsplit("_", 1)[0]
351
+
352
+ cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0]
353
+ cubin_data = self._get_escaped_cubin_bytes(cubin_data)
354
+ callback(sym, func_sym, cubin_data)
355
+ return ir.WalkResult.ADVANCE
356
+
357
+ module.operation.walk(walk_gpu_binary_op)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides a runtime utility functions that are needed for
14
+ the DSL.
15
+ """
16
+
17
+ from . import dlpack_types
18
+ from . import cuda
19
+ from . import jit_arg_adapters
20
+
21
+ __all__ = [
22
+ "dlpack_types",
23
+ "cuda",
24
+ "jit_arg_adapters",
25
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides CUDA Python helper functions
14
+ """
15
+
16
+
17
+ from functools import lru_cache
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional
20
+ import numpy as np
21
+ import os
22
+ import ctypes
23
+
24
+ import cuda.bindings.driver as cuda
25
+ import cuda.bindings.nvrtc as nvrtc
26
+
27
+ # MLIR imports
28
+ from ..._mlir import ir
29
+ from ..._mlir.dialects import gpu
30
+
31
+ # Local module imports
32
+ from ..utils.logger import log as _log
33
+ from ..common import *
34
+ from .jit_arg_adapters import JitArgAdapterRegistry
35
+
36
+
37
+ # =============================================================================
38
+ # Utils
39
+ # =============================================================================
40
+
41
+
42
def _cudaGetErrorEnum(error):
    """Map a CUDA driver or NVRTC result code to its symbolic error name."""
    if isinstance(error, cuda.CUresult):
        status, name = cuda.cuGetErrorName(error)
        if status == cuda.CUresult.CUDA_SUCCESS:
            return name
        return "<unknown>"
    if isinstance(error, nvrtc.nvrtcResult):
        return nvrtc.nvrtcGetErrorString(error)[1]
    raise DSLRuntimeError("Unknown error type: {}".format(error))
50
+
51
+
52
+ def _get_gpu_arch_info(major, minor):
53
+ """Get GPU architecture information and compatibility details."""
54
+ gpu_arch_map = {
55
+ (7, 0): ("Volta", "sm_70", ["sm_70"]), # V100
56
+ (7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX
57
+ (8, 0): ("Ampere", "sm_80", ["sm_80"]), # A100
58
+ (8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]), # RTX 30 Series
59
+ (8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]), # RTX 40 Series
60
+ (8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40
61
+ (9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100
62
+ (10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200
63
+ }
64
+ return gpu_arch_map.get(
65
+ (major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])
66
+ )
67
+
68
+
69
def get_compute_capability_major_minor(device_id: int = 0):
    """
    Returns the compute capability of the CUDA device as a tuple of (major, minor).
    For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell.
    Returns (None, None) on failure.
    """
    try:
        # cuInit is safe to call repeatedly; ensure the driver is initialized
        # before querying device attributes.
        checkCudaErrors(cuda.cuInit(0))
        device = checkCudaErrors(cuda.cuDeviceGet(device_id))
        major = checkCudaErrors(
            cuda.cuDeviceGetAttribute(
                cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
                device,
            )
        )
        minor = checkCudaErrors(
            cuda.cuDeviceGetAttribute(
                cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
                device,
            )
        )
        return major, minor
    except RuntimeError as e:
        # NOTE(review): only RuntimeError (and subclasses) are swallowed here —
        # confirm DSLCudaRuntimeError raised by checkCudaErrors derives from it.
        _log().info(f"Failed to get CUDA compute capability: {e}")
        return None, None
94
+
95
+
96
@dataclass
class DeviceInfo:
    """Data class to store CUDA device information.

    All fields are best-effort: probes that fail leave the field at its
    default rather than raising.
    """

    # Number of CUDA devices visible to the driver (0 on failure).
    device_count: int = 0
    # Ordinal of the device bound to the current context.
    current_device: int = 0
    # Marketing name reported by the driver, if queryable.
    device_name: Optional[str] = None
    # Compute capability major/minor, e.g. (9, 0) for Hopper.
    major_version: Optional[int] = None
    minor_version: Optional[int] = None
    # Human-readable architecture name ("Hopper", ...).
    arch_name: Optional[str] = None
    # Primary SM arch string ("sm_90a", ...).
    sm_arch: Optional[str] = None
    # SM archs compiled code may target for this device.
    compatible_archs: Optional[List[str]] = None
    # Total device memory in GB, if queryable.
    memory_gb: Optional[float] = None
    # Arch chosen as compilation target, if any.
    target_arch: Optional[str] = None
    # Error text captured while probing, if any.
    error_message: Optional[str] = None
    # True when cuInit itself failed.
    initialization_failed: bool = False

    def pretty_str(self) -> str:
        """
        Convert DeviceInfo to a formatted string for display.

        Error states short-circuit to a single line; otherwise the available
        fields are appended one per line with ANSI color markup.
        """
        info = ""

        if self.initialization_failed:
            return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}"

        if self.error_message:
            return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}"

        if self.device_count > 0:
            info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n"

            if self.major_version is not None and self.minor_version is not None:
                info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n"
                info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n"

                if self.memory_gb is not None:
                    info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n"

            else:
                info += f"- Compute capability: unknown\n"
                info += f"- SM arch: unknown{Colors.RESET}\n"
        else:
            info += f"- No devices available\n"

        return info
142
+
143
+
144
def get_device_info() -> DeviceInfo:
    """
    Get detailed information about CUDA devices.
    Returns a DeviceInfo dataclass with device information.

    Every probe is best-effort: a failing query leaves the corresponding
    DeviceInfo field at its default instead of raising.  Bare ``except:``
    clauses were narrowed to ``except Exception:`` so SystemExit and
    KeyboardInterrupt are no longer swallowed.
    """
    device_info = DeviceInfo()

    # Initialize CUDA if not already initialized
    try:
        result = cuda.cuInit(0)
        if result[0].value:  # Non-zero CUresult => initialization failed
            device_info.initialization_failed = True
            return device_info
    except Exception:
        # Bindings unavailable or raised; let the probes below record
        # whatever they can.
        pass

    try:
        # Get device count
        result = cuda.cuDeviceGetCount()
        device_info.device_count = result[1] if result[0].value == 0 else 0

        if device_info.device_count > 0:
            # Get current device
            try:
                result = cuda.cuCtxGetDevice()
                if result[0].value == 0:
                    device_info.current_device = result[1]
            except Exception:
                pass

            # Get device name
            try:
                name_result = cuda.cuDeviceGetName(100, device_info.current_device)
                if name_result[0].value == 0:
                    device_info.device_name = name_result[1]
            except Exception:
                pass

            # Get compute capability and architecture info
            try:
                major, minor = get_compute_capability_major_minor(
                    device_info.current_device
                )

                # Check if we successfully got the compute capability
                if major is not None and minor is not None:
                    device_info.major_version = major
                    device_info.minor_version = minor

                    arch_name, sm_arch, compatible_archs = _get_gpu_arch_info(
                        device_info.major_version, device_info.minor_version
                    )

                    device_info.arch_name = arch_name
                    device_info.sm_arch = sm_arch
                    device_info.compatible_archs = compatible_archs

                    # Get memory info
                    try:
                        # NOTE(review): queried via cuDeviceGetAttribute with
                        # CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY — confirm this
                        # attribute exists; cuDeviceTotalMem is the usual API.
                        total_mem = cuda.cuDeviceGetAttribute(
                            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY,
                            device_info.current_device,
                        )
                        if total_mem[0].value == 0:
                            device_info.memory_gb = total_mem[1] / (
                                1024 * 1024 * 1024
                            )  # Convert to GB
                    except Exception:
                        pass

            except Exception:
                pass  # Compute capability info will remain None

    except Exception as e:
        device_info.error_message = str(e)

    return device_info
221
+
222
+
223
def checkCudaErrors(result):
    """Validate a cuda-python call result tuple and unwrap its payload.

    Raises DSLCudaRuntimeError on a non-zero status; otherwise returns None,
    the single payload value, or the payload tuple.
    """
    status = result[0]
    if status.value:
        raise DSLCudaRuntimeError(status.value, _cudaGetErrorEnum(status))

    payload = result[1:]
    if not payload:
        return None
    if len(payload) == 1:
        return payload[0]
    return payload
237
+
238
+
239
+ # =============================================================================
240
+ # Driver Helpers
241
+ # =============================================================================
242
+
243
+
244
@lru_cache(maxsize=1)
def initialize_cuda_context(device_id: int = 0, flags: int = 0):
    """
    Initializes the CUDA context for a specified device.

    The driver is initialized, the device handle fetched, and a context
    created on it.  Cached so repeated calls return the same context.
    """
    _log().info(f"cuInit {flags}")
    checkCudaErrors(cuda.cuInit(flags))

    _log().info(f"cuDeviceGet {device_id}")
    cu_device = checkCudaErrors(cuda.cuDeviceGet(device_id))
    _log().info(f"{cu_device} <-- cuDeviceGet")

    _log().info(f"cuCtxCreate {0} {cu_device}")
    if cuda.CUDA_VERSION >= 13000:
        # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2
        # and v3 API has been removed from CTK 13.
        # See https://github.com/NVIDIA/cuda-python/pull/792
        context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cu_device))
    else:
        context = checkCudaErrors(cuda.cuCtxCreate(0, cu_device))
    _log().info(f"{context} <-- cuCtxCreate")

    return context
268
+
269
+
270
def load_cubin_module(cubin_file):
    """
    Loads a CUBIN file and returns the module.

    The file contents are staged into a single host buffer that stays alive
    across the cuModuleLoadData call; the previous implementation built two
    separate numpy buffers, logging the address of one while loading from the
    other.
    """
    # Load CUBIN file as binary data
    _log().info(f"read cubin {cubin_file}")
    with open(cubin_file, "rb") as f:
        cubin_data = f.read()
    # Stage once; log and load the same host address.
    host_array = np.char.array(cubin_data)
    _log().info(f"cuModuleLoadData {host_array.ctypes.data}")
    module = checkCudaErrors(cuda.cuModuleLoadData(host_array.ctypes.data))
    return module
284
+
285
+
286
def unload_cubin_module(module):
    """
    Unloads a CUBIN module.

    :param module: CUmodule handle previously returned by load_cubin_module
        or load_cubin_module_data.
    """
    _log().info(f"cuModuleUnload {module}")
    checkCudaErrors(cuda.cuModuleUnload(module))
292
+
293
+
294
def load_cubin_module_data(cubin_data):
    """
    Loads a CUBIN from data and returns the module.

    The bytes are staged into one host buffer kept alive across the
    cuModuleLoadData call; previously a second, distinct numpy buffer was
    created for the load while the logged address belonged to the first.
    """
    host_array = np.char.array(cubin_data)
    _log().info(f"cuModuleLoadData {host_array.ctypes.data}")
    module = checkCudaErrors(cuda.cuModuleLoadData(host_array.ctypes.data))
    return module
304
+
305
+
306
def get_kernel_function(module, kernel_name):
    """
    Retrieves the kernel function from the module.

    :param module: loaded CUmodule handle.
    :param kernel_name: kernel symbol name (str); encoded to UTF-8 for the driver.
    :return: CUfunction handle for the kernel.
    """
    _log().info(f"cuModuleGetFunction {module} {kernel_name}")
    kernel = checkCudaErrors(
        cuda.cuModuleGetFunction(module, bytes(kernel_name, "utf-8"))
    )
    _log().info(f"{kernel} <-- cuModuleGetFunction")
    return kernel
316
+
317
+
318
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
    """
    Launches the CUDA kernel.

    :param kernel: CUfunction handle obtained via get_kernel_function.
    :param grid_dims: (x, y, z) grid dimensions.
    :param block_dims: (x, y, z) thread-block dimensions.
    :param stream: CUstream to enqueue the launch on.
    :param smem_size: dynamic shared memory size in bytes.
    :param kernel_args: packed kernel parameter pointers, or None.
    """
    _log().info(
        f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}"
    )
    checkCudaErrors(
        cuda.cuLaunchKernel(
            kernel,
            grid_dims[0],
            grid_dims[1],
            grid_dims[2],
            block_dims[0],
            block_dims[1],
            block_dims[2],
            smem_size,  # Shared memory size
            stream,
            kernel_args,
            0,  # Extra parameters
        )
    )
340
+
341
+
342
def stream_sync(stream):
    """
    Synchronizes the CUDA stream.

    Blocks the host until all work previously enqueued on ``stream`` completes.
    """
    _log().info(f"cuStreamSynchronize {stream}")
    checkCudaErrors(cuda.cuStreamSynchronize(stream))
348
+
349
+
350
def stream_create(id=0):
    """
    Creates the CUDA stream.

    :param id: forwarded to cuStreamCreate; in the driver API this position is
        the stream creation *flags*.
    :return: new CUstream handle.
    """
    # NOTE(review): `id` shadows the builtin and is passed as cuStreamCreate's
    # flags parameter, not a stream ordinal — confirm callers pass flags.
    _log().info(f"cuStreamCreate {id}")
    stream = checkCudaErrors(cuda.cuStreamCreate(id))
    _log().info(f"{stream} <-- cuStreamCreate")
    return stream
358
+
359
+
360
def stream_destroy(stream):
    """
    Destroys the CUDA stream.

    :param stream: CUstream handle previously returned by stream_create.
    """
    _log().info(f"cuStreamDestroy {stream}")
    checkCudaErrors(cuda.cuStreamDestroy(stream))
366
+
367
+
368
def context_destroy(context):
    """
    Destroys the CUDA context.

    :param context: CUcontext handle, e.g. from initialize_cuda_context.
    """
    _log().info(f"cuCtxDestroy {context}")
    checkCudaErrors(cuda.cuCtxDestroy(context))
374
+
375
+
376
def allocate(size_in_bytes: int, stream=None):
    """
    Allocate device memory based on numpy host array size.

    :param size_in_bytes: number of bytes to allocate.
    :param stream: if given, uses the stream-ordered cuMemAllocAsync path.
    :return: device pointer to the allocation.
    """
    _log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream)
    if stream is None:
        device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
    else:
        device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream))
    _log().info("Allocated [%s]", device_memory)
    return device_memory
387
+
388
+
389
def deallocate(device_pointer, stream=None):
    """
    Deallocate the specified device memory pointer.

    :param device_pointer: device pointer previously returned by allocate.
    :param stream: if given, uses the stream-ordered cuMemFreeAsync path.
    """
    _log().info(
        "Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream
    )
    if stream is None:
        checkCudaErrors(cuda.cuMemFree(device_pointer))
    else:
        checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream))
400
+
401
+
402
def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None):
    """
    Copy data from host to device memory.

    :param host_pointer: source host address (int).
    :param device_pointer: destination device pointer.
    :param size_in_bytes: number of bytes to copy.
    :param stream: if given, copies asynchronously on that stream.
    """
    _log().info(
        "Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]",
        hex(host_pointer),
        hex(int(device_pointer)),
        size_in_bytes,
        stream,
    )
    if stream is None:
        checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes))
    else:
        checkCudaErrors(
            cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream)
        )
419
+
420
+
421
def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None):
    """
    Copy data from device to host memory.

    :param host_pointer: destination host address (int).
    :param device_pointer: source device pointer.
    :param size_in_bytes: number of bytes to copy.
    :param stream: if given, copies asynchronously on that stream.
    """
    # NOTE(review): the log label "device-host-to" looks garbled; likely
    # intended as "device-to-host".
    _log().info(
        "Copy device-host-to device_pointer=[%s] host_pointer[%s] size_in_bytes=[%s] stream=[%s]",
        hex(int(device_pointer)),
        hex(host_pointer),
        size_in_bytes,
        stream,
    )
    if stream is None:
        checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes))
    else:
        checkCudaErrors(
            cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream)
        )
438
+
439
+
440
def default_stream():
    """Return the default CUDA stream (handle 0)."""
    return cuda.CUstream(0)
442
+
443
+
444
def get_driver_version():
    """
    Returns the CUDA driver version.

    :return: integer driver version (e.g. 11080 for CUDA 11.8).
    """
    return checkCudaErrors(cuda.cuDriverGetVersion())
449
+
450
+
451
def set_kernel_attribute(kernel, attribute, value):
    """
    Sets a CUDA kernel attribute.

    :param kernel: CUfunction handle.
    :param attribute: a cuda.CUfunction_attribute enum value.
    :param value: integer attribute value to set.
    """
    return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value))
456
+
457
+
458
@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream)
class StreamAdapter:
    """
    Convert a CUDA stream to a stream representation for JIT arg generation.
    """

    def __init__(self, arg):
        # The wrapped cuda.CUstream.
        self._arg = arg
        # Raw pointer handed to the compiled program.
        self._c_pointer = self._arg.getPtr()

    def __new_from_mlir_values__(self, values):
        # Rebuild the Python-side value from MLIR values; streams map 1:1.
        assert len(values) == 1
        return values[0]

    def __c_pointers__(self):
        # C pointers this argument contributes to the packed argument list.
        return [self._c_pointer]

    def __get_mlir_types__(self):
        # A stream is modeled as a GPU async token in the IR.
        return [gpu.AsyncTokenType.get()]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import copy
13
+
14
+ from . import cuda as cuda_helpers
15
+ from .tensor_descriptor import *
16
+ from ..common import *
17
+
18
+
19
def allocate(tensor: TensorDescriptor, stream=None):
    """
    Allocates GPU memory
    """
    # Framework-managed (already on-GPU) tensors must not be re-allocated.
    if tensor._check_is_managed_by_framework():
        raise DSLRuntimeError(
            "GPU tensors are managed by the framework and cannot be modified."
        )
    if tensor.device_pointer is not None:
        raise DSLRuntimeError("Tensor is already allocated on the device.")

    device_ptr = cuda_helpers.allocate(tensor.size_in_bytes, stream)
    tensor.device_pointer = device_ptr

    log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
33
+
34
+
35
def deallocate(tensor: TensorDescriptor, stream=None):
    """
    Deallocates GPU memory
    """
    # Framework-managed (already on-GPU) tensors must not be freed here.
    if tensor._check_is_managed_by_framework():
        raise DSLRuntimeError(
            "GPU tensors are managed by the framework and cannot be modified."
        )
    if tensor.device_pointer is None:
        raise DSLRuntimeError("Tensor is not allocated on the device.")

    log().info(
        "Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer
    )

    device_ptr = tensor.device_pointer
    cuda_helpers.deallocate(device_ptr, stream)
    # Clear the pointer so the descriptor reads as host-resident again.
    tensor.device_pointer = None
52
+
53
+
54
def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None):
    """
    Copies data from host memory to the GPU memory.
    If do_allocate is True, it first calls allocate
    """
    log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)

    # Optionally allocate device memory first; allocate() validates state.
    if do_allocate:
        allocate(tensor, stream)

    source = tensor.data_ptr
    destination = tensor.device_pointer
    cuda_helpers.memcpy_h2d(source, destination, tensor.size_in_bytes, stream)

    log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
    return tensor
67
+
68
+
69
def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None):
    """
    Copies data from GPU memory back to the host.
    If do_deallocate is True, it calls deallocate
    """
    log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)

    # Framework-managed tensors and unallocated tensors cannot be copied out.
    if tensor._check_is_managed_by_framework():
        raise DSLRuntimeError(
            "GPU tensors are managed by the framework and cannot be modified."
        )
    if tensor.device_pointer is None:
        raise DSLRuntimeError("Tensor is not allocated on the device.")

    destination = tensor.data_ptr
    source = tensor.device_pointer
    cuda_helpers.memcpy_d2h(destination, source, tensor.size_in_bytes, stream)

    if do_deallocate:
        deallocate(tensor, stream)

    log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
88
+
89
+
90
def to_gpu(tensor, stream=None) -> TensorDescriptor:
    """
    Copies the tensor to the GPU memory from Host memory
    """
    # Already a descriptor: shallow-copy so the caller's object is untouched.
    if isinstance(tensor, TensorDescriptor):
        descriptor = copy.copy(tensor)
        copy_to_gpu(descriptor, stream=stream)
        return descriptor

    # Any DLPack-capable object: wrap it in a fresh descriptor first.
    if TensorDescriptor.can_transformed_to_dlpack(tensor):
        descriptor = TensorDescriptor(tensor)
        copy_to_gpu(descriptor, stream=stream)
        return descriptor

    raise DSLRuntimeError("Unsupported type")
105
+
106
+
107
def from_gpu(tensor, stream=None) -> TensorDescriptor:
    """
    Copies the tensor from GPU memory back to Host memory.

    Accepts either a TensorDescriptor or any object implementing the DLPack
    protocol; returns a TensorDescriptor whose host buffer holds the data
    copied back from the device.
    """
    if isinstance(tensor, TensorDescriptor):
        new_tensor = copy.copy(tensor)
        copy_from_gpu(new_tensor, stream=stream)
        return new_tensor

    if TensorDescriptor.can_transformed_to_dlpack(tensor):
        new_tensor = TensorDescriptor(tensor)
        copy_from_gpu(new_tensor, stream=stream)
        return new_tensor

    raise DSLRuntimeError("Unsupported type")
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides helper structs for dlpack.
14
+ DLPack is an open standard for in-memory tensor structures, enabling
15
+ seamless sharing of tensors across different frameworks.
16
+ Learn more at: https://github.com/dmlc/dlpack
17
+ """
18
+
19
+ import ctypes
20
+ import enum
21
+
22
+
23
class DLDeviceType(enum.IntEnum):
    """Enums for device types based on the DLPack specification.

    Values match dlpack.h (kDLCPU=1, kDLCUDA=2, kDLCUDAHost=3); this module
    uses the legacy GPU/CPUPinned member names for the same codes.
    """

    kDLCPU = 1  # Host (CPU) memory
    kDLGPU = 2  # CUDA device memory (kDLCUDA in current DLPack headers)
    kDLCPUPinned = 3  # Page-locked host memory accessible by the device
29
+
30
+
31
class DLDataTypeCode(enum.IntEnum):
    """Enums for data type codes based on the DLPack specification.

    Declared as IntEnum (matching DLDeviceType above) so members compare
    equal to the raw integer codes stored in a DLTensor's dtype field.

    see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
    """

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6
44
+
45
+
46
class DLDevice(ctypes.Structure):
    """Structure representing the device information in DLPack.

    Field order and types mirror the C struct DLDevice in dlpack.h and must
    not be changed (ABI).
    """

    _fields_ = [
        ("device_type", ctypes.c_int),  # kDLCPU, kDLGPU, etc.
        ("device_id", ctypes.c_int),  # Device ID (e.g., GPU ID)
    ]
53
+
54
+
55
class DLDataType(ctypes.Structure):
    """Structure representing the data type in DLPack.

    Field order and types mirror the C struct DLDataType in dlpack.h and must
    not be changed (ABI).
    """

    _fields_ = [
        ("code", ctypes.c_uint8),  # Data type code (e.g., kDLFloat)
        ("bits", ctypes.c_uint8),  # Number of bits per value
        ("lanes", ctypes.c_uint16),  # Number of lanes
    ]
63
+
64
+
65
class DLTensor(ctypes.Structure):
    """Structure representing the DLTensor in DLPack.

    Field order and types mirror the C struct DLTensor in dlpack.h and must
    not be changed (ABI). `strides` may be a NULL pointer for compact
    row-major tensors per the DLPack spec.
    """

    _fields_ = [
        ("data", ctypes.c_void_p),  # Pointer to tensor data
        ("device", DLDevice),  # Device info
        ("ndim", ctypes.c_int),  # Number of dimensions
        ("dtype", DLDataType),  # Data type
        ("shape", ctypes.POINTER(ctypes.c_int64)),  # Shape of tensor
        ("strides", ctypes.POINTER(ctypes.c_int64)),  # Strides of tensor
        ("byte_offset", ctypes.c_uint64),  # Byte offset to tensor data
    ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides runtime utilities for JIT argument conversion in DSL.
14
+ """
15
+
16
+ from functools import wraps
17
+ from typing import get_origin
18
+
19
+ # Local modules imports
20
+ from ..common import DSLRuntimeError
21
+ from ..typing import (
22
+ Constexpr,
23
+ Int32,
24
+ Float32,
25
+ Boolean,
26
+ )
27
+
28
+
29
def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
    """
    Check if the argument spec is a constexpr.

    True when the argument is a reserved first parameter (`self`/`cls`), is
    annotated with the Constexpr type, or is a Constexpr generic alias.
    """

    def _is_reserved_python_func_arg(arg_index, arg_name, func):
        """
        Check if the argument is a reserved python function argument.
        """
        if arg_index != 0:
            return False
        if arg_name == "self":
            return True
        underlying = getattr(func, "__func__", None)
        is_classmethod = isinstance(func, classmethod) or isinstance(
            underlying, classmethod
        )
        return arg_name == "cls" and is_classmethod

    if _is_reserved_python_func_arg(arg_index, arg_name, owning_func):
        return True
    if isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr):
        return True
    return get_origin(arg_spec) is Constexpr
55
+
56
+
57
def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func):
    """
    Check if the argument is a constexpr.

    An argument is treated as constexpr when its spec is constexpr, when the
    value itself is a type (e.g. passed as Type[X]), or when it is None.
    """

    def _is_type_argument(candidate, annotation):
        """
        Check if the argument is a type argument like Type[X]
        """
        if not isinstance(candidate, type):
            return False
        return annotation is None or get_origin(annotation) is type

    if is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
        return True
    if _is_type_argument(arg, arg_spec):
        return True
    return arg is None
76
+
77
+
78
class JitArgAdapterRegistry:
    """
    A registry to keep track of the JIT argument adapters.

    An adapter is a callable that converts a Python type to a type with following protocols supported:
    - JitArgument
    - DynamicExpression
    The converted type can then be further processed by DSL to generate arguments for JIT functions.
    """

    # A dictionary with key=type and value=callable
    jit_arg_adapter_registry = {}

    @classmethod
    def register_jit_arg_adapter(cls, *dargs, **dkwargs):
        """
        Register a JIT argument adapter callable

        This can be used as a decorator on any callable like:

            @register_jit_arg_adapter(my_py_type)
            def my_adapter_for_my_py_type(arg):
                ...

            @register_jit_arg_adapter(my_py_type)
            class MyAdapterForMyPythonType:
                ...

        The adapters are registered per type. If a type is already registered, an error will be raised.

        :raises DSLRuntimeError: if no type is given, the decorated object is
            not callable, or an adapter for the type is already registered.
        """
        # Fixed: the previous implementation wrapped the returned decorator
        # with functools.wraps(<the registered type>), copying a type's
        # metadata onto the decorator for no benefit, behind an extra layer
        # of indirection. The decorator is now built directly.
        if len(dargs) == 0:
            raise DSLRuntimeError(
                "a Python type must be provided for registering JIT argument adapter"
            )
        darg_python_ty = dargs[0]

        def wrapper(*args, **kwargs):
            # The decorated object (function or class) is the adapter itself.
            if len(args) != 1 or not callable(args[0]):
                raise DSLRuntimeError(
                    "a callable must be provided for registering JIT argument adapter"
                )
            adapter = args[0]

            if darg_python_ty in cls.jit_arg_adapter_registry:
                raise DSLRuntimeError(
                    f"JIT argument adapter for {darg_python_ty} is already registered!",
                    context={
                        "Registered adapter": cls.jit_arg_adapter_registry[
                            darg_python_ty
                        ],
                        "Adapter to be registered": adapter,
                    },
                )
            cls.jit_arg_adapter_registry[darg_python_ty] = adapter
            # Return the adapter unchanged so the decorated name is preserved.
            return adapter

        return wrapper

    @classmethod
    def get_registered_adapter(cls, ty):
        """
        Get the registered JIT argument adapter for the given type.

        :return: the adapter callable, or None when no adapter is registered.
        """
        return cls.jit_arg_adapter_registry.get(ty, None)
148
+
149
+
150
+ # =============================================================================
151
+ # JIT Argument Adapters
152
+ # =============================================================================
153
+
154
+
155
@JitArgAdapterRegistry.register_jit_arg_adapter(int)
@JitArgAdapterRegistry.register_jit_arg_adapter(float)
@JitArgAdapterRegistry.register_jit_arg_adapter(bool)
def _convert_python_scalar(arg):
    """
    Convert a Python scalar to a DSL type.
    """
    # Exact-type lookup (type(arg), not isinstance) so bool — a subclass of
    # int — maps to Boolean rather than Int32.
    scalar_to_dsl_type = {
        bool: Boolean,
        int: Int32,
        float: Float32,
    }
    dsl_type = scalar_to_dsl_type.get(type(arg))
    return dsl_type(arg)
168
+
169
+
170
@JitArgAdapterRegistry.register_jit_arg_adapter(tuple)
@JitArgAdapterRegistry.register_jit_arg_adapter(list)
def _convert_python_sequence(arg):
    """
    Go through each element in the sequence and convert it to a type that can be
    further processed by DSL to generate the corresponding JIT argument(s).
    """

    def _convert_one(element):
        # Elements without a registered adapter pass through unchanged.
        adapter = JitArgAdapterRegistry.get_registered_adapter(type(element))
        return element if adapter is None else adapter(element)

    adapted_elements = [_convert_one(element) for element in arg]
    assert len(adapted_elements) == len(arg)
    # Rebuild with the original container type (tuple stays tuple, etc.).
    return type(arg)(adapted_elements)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ # Helpers
13
+ import itertools, operator
14
+ import ctypes
15
+ from . import dlpack_types as _dpack
16
+ from .dlpack_runtime import (
17
+ dlpack_to_tensor_desc,
18
+ get_tensor_desc_data_ptr,
19
+ get_tensor_desc_is_in_device,
20
+ get_tensor_desc_element_type,
21
+ get_tensor_desc_shape,
22
+ get_tensor_desc_stride,
23
+ get_tensor_desc_element_size_in_bytes,
24
+ get_tensor_desc_ndim,
25
+ get_tensor_desc_dtype_code,
26
+ get_tensor_desc_dtype_bits,
27
+ get_tensor_desc_device_type,
28
+ get_tensor_desc_device_id,
29
+ )
30
+
31
+ from ..utils.logger import log
32
+ from ..common import *
33
+ from ..typing import (
34
+ Boolean,
35
+ Float8E5M2,
36
+ Int64,
37
+ Int32,
38
+ Int16,
39
+ Int8,
40
+ Uint64,
41
+ Uint32,
42
+ Uint16,
43
+ Uint8,
44
+ Float64,
45
+ Float32,
46
+ Float16,
47
+ BFloat16,
48
+ )
49
+
50
+
51
class TensorDescriptor:
    """Describes a tensor obtained through the DLPack protocol.

    Wraps any DLPack-capable object and exposes its metadata (data pointer,
    shape, strides, element type, device) via the DSL capsule helpers.
    """

    def __init__(self, tensor):
        """Initialize with a tensor that supports the DLPack protocol.

        Args:
            tensor: Any tensor object that implements __dlpack__ and __dlpack_device__
        """

        self.tensor = tensor
        self._capsule = dlpack_to_tensor_desc(tensor)

        self.data_ptr = get_tensor_desc_data_ptr(self._capsule)
        self.device_type = get_tensor_desc_device_type(self._capsule)
        self.device_type = _dpack.DLDeviceType(self.device_type)

        if self.device_type == _dpack.DLDeviceType.kDLGPU:
            # Tensor already lives on the device; its data pointer is the
            # device pointer.
            self.device_pointer = self.data_ptr
        elif self.device_type == _dpack.DLDeviceType.kDLCPU:
            self.device_pointer = None
        else:
            # Fixed: previously referenced the non-existent attribute
            # `self.dl_tensor`, which raised AttributeError here instead of
            # the intended DSLRuntimeError.
            raise DSLRuntimeError(
                f"DLPack device type is not supported {self.device_type}"
            )

        log().info("TensorDescriptor is created = [%s]", self)

    @staticmethod
    def can_transformed_to_dlpack(dl_tensor):
        """Return True if the object implements the DLPack protocol."""
        if not hasattr(dl_tensor, "__dlpack__") or not hasattr(
            dl_tensor, "__dlpack_device__"
        ):
            return False
        return True

    @property
    def is_in_device(self):
        """Check if the tensor is stored on a device."""
        return self.device_pointer is not None

    @property
    def device_id(self):
        """Return device id where tensor resides (-1 for host tensors)."""
        if self.is_in_device:
            return get_tensor_desc_device_id(self._capsule)
        return -1

    @property
    def element_type(self):
        """Return the corresponding Python type based on DLPack dtype metadata.

        :raises KeyError: if the DLPack element type has no DSL equivalent.
        """
        str_element_type = get_tensor_desc_element_type(self._capsule)
        dtype_map = {
            # bool is 8bit from numpy and torch
            "Bool": Boolean,
            "Int64": Int64,
            "Int32": Int32,
            "Int16": Int16,
            "Int8": Int8,
            "UInt64": Uint64,
            "UInt32": Uint32,
            "UInt16": Uint16,
            "UInt8": Uint8,
            "Float64": Float64,
            "Float32": Float32,
            "Float16": Float16,
            "BFloat16": BFloat16,
            "Float8E5M2": Float8E5M2,
        }

        if str_element_type not in dtype_map:
            raise KeyError(
                f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}"
            )

        return dtype_map[str_element_type]

    @property
    def shape(self):
        """Return the shape of the tensor."""
        return get_tensor_desc_shape(self._capsule)

    @property
    def rank(self):
        """Return the rank of the tensor."""
        return get_tensor_desc_ndim(self._capsule)

    @property
    def strides(self):
        """Return the strides of the tensor."""
        # Fixed docstring: this property returns strides, not the rank.
        return get_tensor_desc_stride(self._capsule)

    @property
    def element_size_in_bytes(self):
        """Calculate the element size in bytes of the DLPack tensor."""
        return get_tensor_desc_element_size_in_bytes(self._capsule)

    @property
    def size_in_bytes(self):
        """Calculate the total size in bytes of the DLPack tensor."""
        # Calculate the number of elements using the shape
        ndim = get_tensor_desc_ndim(self._capsule)
        shape = get_tensor_desc_shape(self._capsule)
        num_elements = 1
        for i in range(ndim):
            num_elements *= shape[i]

        # Total bytes
        total_bytes = self.element_size_in_bytes * num_elements
        return total_bytes

    def __str__(self):
        """Return a compact string representation of the device_tensor with a tensor prefix."""
        # Extract shape
        shape = "x".join(map(str, self.shape))

        # Extract dtype
        dtype_code = get_tensor_desc_dtype_code(self._capsule)
        dtype_bits = get_tensor_desc_dtype_bits(self._capsule)
        # NOTE(review): non-int codes (uint, bfloat, bool, ...) all render as
        # "f<bits>" here; acceptable for logging, not a full dtype spelling.
        dtype = (
            f"i{dtype_bits}"
            if dtype_code == _dpack.DLDataTypeCode.kDLInt
            else f"f{dtype_bits}"
        )

        # Extract device
        device_type = "cpu" if not self.is_in_device else "gpu"

        return f"tensor<{shape}x{dtype}>_{device_type}"

    def _check_is_managed_by_framework(self):
        """
        Return True if the tensor is managed by the framework (GPU tensor).

        Fixed docstring: this predicate returns a bool; it does not raise.
        Callers raise DSLRuntimeError based on its result.
        """
        return self.device_type == _dpack.DLDeviceType.kDLGPU

    @staticmethod
    def is_compatible(maybe_tensor_descriptor) -> bool:
        """Check if the object is a TensorDescriptor or can be converted to one."""
        return isinstance(
            maybe_tensor_descriptor, TensorDescriptor
        ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
192
+
193
+
194
def from_tensor(tensor) -> TensorDescriptor:
    """Create a TensorDescriptor from a tensor object."""
    descriptor = TensorDescriptor(tensor)
    return descriptor
197
+
198
+
199
def to_tensor(tensor_descriptor: TensorDescriptor):
    """Return tensor object from tensor descriptor."""
    underlying = tensor_descriptor.tensor
    return underlying
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py ADDED
@@ -0,0 +1,1962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import ctypes
13
+ import numpy as np
14
+ import operator
15
+ from typing_extensions import deprecated
16
+ from functools import reduce
17
+ from typing import (
18
+ Generic,
19
+ Protocol,
20
+ Union,
21
+ Any,
22
+ List,
23
+ Type,
24
+ TypeVar,
25
+ overload,
26
+ runtime_checkable,
27
+ get_origin,
28
+ )
29
+ from types import FunctionType
30
+ from dataclasses import dataclass
31
+ from abc import ABC, abstractmethod
32
+
33
+ from .common import *
34
+ from .ast_helpers import const_expr
35
+ from ._mlir_helpers import arith as arith_helper, lru_cache_ir
36
+ from ._mlir_helpers.arith import ArithValue
37
+
38
+ from .._mlir import ir
39
+ from .._mlir.extras import types as T
40
+ from .._mlir.dialects import arith, math
41
+
42
+ # =============================================================================
43
+ # Dynamic Expression Protocol
44
+ # =============================================================================
45
+
46
+
47
@runtime_checkable
class DynamicExpression(Protocol):
    """Protocol for objects that hold dynamic (MLIR-backed) values in the DSL.

    Implementers can be passed to and returned from JIT-compiled functions
    and participate in traced control flow (if-else, while-loop, etc.). The
    DSL uses the two methods below to decompose an object into its MLIR
    values and to rebuild an equivalent object from traced MLIR values, e.g.
    when crossing a JIT function-call boundary.

    Example implementation:

    .. code-block:: python

        class CustomData(metaclass=DslType):
            def __init__(self, int_value):
                self.int_value = int_value

            def __extract_mlir_values__(self):
                return [self.int_value]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0])

    When such an object is used inside a ``@jit`` function, the DSL extracts
    its MLIR values for the generated call and reconstructs the object from
    the callee's block arguments during tracing.
    """

    def __extract_mlir_values__(self):
        """Extract MLIR values from this object.

        :return: List of MLIR values representing this object's data
        :rtype: List[ir.Value]
        """
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """Create a new instance from MLIR values.

        :param values: List of MLIR values to construct the object from
        :type values: List[ir.Value]
        :return: New instance of the implementing class
        :rtype: Any
        """
        raise NotImplementedError
121
+
122
+
123
@runtime_checkable
class JitArgument(Protocol):
    """Protocol for objects that can be passed directly as JIT function arguments.

    The JIT compiler uses ``__get_mlir_types__`` to emit the MLIR function
    signature, ``__new_from_mlir_values__`` to rebuild the object from the
    function's block arguments while tracing the body (values defined in
    Python are not visible inside the generated IR), and ``__c_pointers__``
    to pass pointers to the underlying data when the compiled function is
    invoked from the Python runtime.

    Example:

    .. code-block:: python

        class CustomData:
            def __c_pointers__(self):
                return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]

            def __get_mlir_types__(self):
                return [ir.IntegerType.get(32), ...]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0], ...)

        @jit
        def foo(x: CustomData):
            a = x.int_value + 1

        foo(CustomData(1, ...))
    """

    def __c_pointers__(self):
        """
        Generate a list of ctypes pointers for the current object.

        :return: List of ctypes pointers
        :rtype: List[ctypes.c_void_p]
        """
        raise NotImplementedError

    def __get_mlir_types__(self):
        """
        Generate a list of MLIR types for the current object.

        :return: List of MLIR types
        :rtype: List[ir.Type]
        """
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """
        Create a new object from MLIR values.

        :param values: List of MLIR values
        :type values: List[ir.Value]
        :return: A new object that represents the given MLIR values
        :rtype: Any
        """
        raise NotImplementedError
223
+
224
+
225
def get_c_pointers(obj):
    """
    Given the `obj`, recursively go through it to extract all contained C pointers
    """
    if hasattr(obj, "__c_pointers__"):
        return obj.__c_pointers__()
    if isinstance(obj, set):
        # Sets have no stable iteration order, so argument order would be
        # unpredictable; reject them outright.
        raise DSLRuntimeError(
            "Sets are not supported in get_c_pointers to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    if isinstance(obj, (tuple, list)):
        pointers = []
        for item in obj:
            pointers.extend(get_c_pointers(item))
        return pointers
    # Objects without __c_pointers__ contribute no runtime arguments.
    return []
240
+
241
+
242
def get_mlir_types(obj):
    """
    Recursively walk `obj` and collect every MLIR type it contains.

    Resolution order: an explicit ``__get_mlir_types__`` hook wins; otherwise
    ``__extract_mlir_values__`` provides values whose ``.type`` is taken; a
    bare ``ir.Value`` contributes its own type; tuples/lists are flattened;
    sets are rejected to preserve argument ordering.
    """
    if hasattr(obj, "__get_mlir_types__"):
        return obj.__get_mlir_types__()
    if hasattr(obj, "__extract_mlir_values__"):
        return [value.type for value in obj.__extract_mlir_values__()]
    if isinstance(obj, ir.Value):
        return [obj.type]
    if isinstance(obj, (tuple, list)):
        types = []
        for item in obj:
            types.extend(get_mlir_types(item))
        return types
    if isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_mlir_types to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    return []
261
+
262
+
263
class DslType(type):
    """Metaclass shared by every type in the DSL.

    Records whether a type is abstract and exposes that flag through the
    ``is_abstract`` class property. Classes built with this metaclass are
    expected to implement the DSL value protocol:

    * ``__str__`` (classmethod): string representation of the type
    * ``__c_pointers__`` (optional): ctypes pointers used to invoke the JIT function
    * ``__get_mlir_types__``: MLIR types of the contained MLIR values
    * ``__extract_mlir_values__``: MLIR values contained in the instance
    * ``__new_from_mlir_values__``: rebuild an instance from MLIR values

    :param is_abstract: Whether this type is abstract, defaults to False
    :type is_abstract: bool, optional
    """

    _is_abstract: bool

    def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs):
        # Extra class keywords are consumed here; type.__init__ tolerates them.
        created = super().__new__(cls, name, bases, attrs)
        created._is_abstract = is_abstract
        return created

    @property
    def is_abstract(cls):
        return cls._is_abstract
310
+
311
+
312
class NumericMeta(DslType):
    """Metaclass for numeric types providing width and numpy dtype information.

    :param width: Bit width of the numeric type, defaults to 8
    :type width: int
    :param np_dtype: Corresponding NumPy dtype
    :type np_dtype: numpy.dtype, optional
    :param mlir_type: Corresponding MLIR type
    :type mlir_type: Any, optional
    :param is_abstract: Whether the type is abstract, defaults to False
    :type is_abstract: bool, optional

    :ivar width: Bit width of the numeric type
    :type width: int
    :ivar _np_dtype: Corresponding NumPy dtype
    :type _np_dtype: Union[numpy.dtype, None]

    :property numpy_dtype: Returns the corresponding NumPy dtype
    :rtype numpy_dtype: numpy.dtype
    """

    width: int

    # Placeholder type
    _mlir_type = Any
    _np_dtype: Union[np.dtype, None]

    def __new__(
        cls,
        name,
        bases,
        attrs,
        width=8,
        np_dtype=None,
        mlir_type=None,
        is_abstract=False,
        **kwargs,
    ):
        # Default MLIR-value protocol implementations; a numeric wraps exactly
        # one MLIR value.
        def _extract_mlir_values(self):
            return [self.ir_value()]

        def _new_from_mlir_values(self, values: list) -> "Numeric":
            res_ty = type(self)
            return res_ty(values[0])

        new_attrs = {
            "__extract_mlir_values__": _extract_mlir_values,
            "__new_from_mlir_values__": _new_from_mlir_values,
        }
        # `new_attrs | attrs`: explicit definitions in the class body override
        # the generated defaults (right operand wins).
        new_cls = super().__new__(
            cls,
            name,
            bases,
            new_attrs | attrs,
            is_abstract=is_abstract,
            **kwargs,
        )

        if mlir_type is not None:
            # staticmethod so the builder is not bound as an instance method.
            new_cls._mlir_type = staticmethod(mlir_type)

        new_cls.width = width
        new_cls._np_dtype = np_dtype
        return new_cls

    @property
    def numpy_dtype(cls):
        return cls._np_dtype

    @property
    def is_integer(cls) -> bool: ...

    @property
    def is_float(cls) -> bool: ...

    def is_same_kind(cls, other: Type) -> bool:
        return cls.is_integer == other.is_integer or cls.is_float == other.is_float

    @staticmethod
    def from_python(value: Any) -> Type["Numeric"]:
        """
        Deduce the DSL type from a Python value.
        """
        # bool must be tested before int: bool is a subclass of int, so the
        # previous int-first ordering mapped True/False to Int32 instead of
        # Boolean (matching Numeric._from_python_value, which checks bool first).
        if isinstance(value, bool):
            return Boolean
        elif isinstance(value, int):
            return Int32
        elif isinstance(value, float):
            return Float32
        raise DSLRuntimeError(
            f"Could not deduce Type[Numeric] from python value: {value} :{type(value)}"
        )

    @property
    def mlir_type(cls):
        return cls._mlir_type()  # type: ignore
408
+
409
+
410
# Generic placeholder for a backend value handle (e.g. an MLIR ir.Value)
# used in DSL function signatures.
Value = TypeVar("Value")
411
+
412
+
413
def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric":
    """Cast an object to the specified numeric type.

    A concrete target type performs the conversion via its constructor. An
    abstract target type cannot construct values, so the object is accepted
    unchanged when it is already an instance, and rejected otherwise.

    :param obj: Object to be cast
    :type obj: Union[bool, int, float, Value]
    :param type_: Target numeric type
    :type type_: Type[Numeric]
    :raises TypeError: If casting to an abstract type the object is not an
        instance of
    :return: Object cast to the target numeric type
    :rtype: Numeric

    Example::
        >>> x = cast(5, Int32)      # Cast integer to Int32
        >>> y = cast(3.14, Float32) # Cast float to Float32
    """
    if not type_.is_abstract:
        # Implicit cast based on using annotation type
        return type_(obj)
    if isinstance(obj, type_):
        # Abstract target but obj already satisfies it: pass through as is.
        return obj
    raise TypeError(
        f"can't cast {obj} to {type_}. Pass in concrete type instead, "
        "e.g. Int32, Float32, etc."
    )
440
+
441
+
442
# Option 1: use ir.Value as base
# class IntegerMeta(DslType, type(ir.Value)):
class IntegerMeta(NumericMeta):
    """Metaclass for integer types providing signedness information.

    :param width: Bit width of the integer type, defaults to 32
    :type width: int
    :param signed: Whether the integer type is signed, defaults to True
    :type signed: bool
    :param mlir_type: Corresponding MLIR type, defaults to None
    :type mlir_type: Any, optional

    :ivar signed: Whether the integer type is signed
    :vartype signed: bool
    """

    signed: bool

    def __new__(
        cls,
        name,
        bases,
        attrs,
        width=32,
        signed=True,
        mlir_type=None,
        is_abstract=False,
    ):
        # Select the numpy dtype matching width/signedness.
        # width 1 maps to numpy bool; 128-bit has no numpy counterpart.
        if width == 1:
            np_dtype = np.bool_
        elif width == 128:
            np_dtype = None
        elif signed:
            np_dtype = getattr(np, f"int{width}")
        else:
            np_dtype = getattr(np, f"uint{width}")

        def _c_pointers(self):
            # Box self.value in the matching ctypes scalar and hand back a
            # void* to it for invoking the JIT-compiled function.
            # NOTE(review): the ctypes scalar is a local; presumably the JIT
            # engine copies the pointed-to data during invoke — confirm.
            if width == 1:
                c_value = ctypes.c_bool(self.value)
            elif signed:
                c_value = getattr(ctypes, f"c_int{width}")(self.value)
            else:
                c_value = getattr(ctypes, f"c_uint{width}")(self.value)

            return [ctypes.cast(ctypes.pointer(c_value), ctypes.c_void_p)]

        # `attrs | new_attrs`: the generated __c_pointers__ wins over any
        # definition in the class body (right operand wins).
        new_attrs = {
            "__c_pointers__": _c_pointers,
        }
        new_cls = super().__new__(
            cls, name, bases, attrs | new_attrs, width, np_dtype, mlir_type, is_abstract
        )
        new_cls.signed = signed
        return new_cls

    def __str__(cls):
        return f"{cls.__name__}"

    @property
    def is_integer(cls) -> bool:
        return True

    @property
    def is_float(cls) -> bool:
        return False

    @property
    def zero(cls) -> int:
        return 0

    @property
    def min(cls) -> int:
        # Two's-complement lower bound for signed; 0 for unsigned.
        if cls.signed:
            return -(2 ** (cls.width - 1))
        else:
            return 0

    @property
    def max(cls) -> int:
        if cls.signed:
            return 2 ** (cls.width - 1) - 1
        else:
            return 2**cls.width - 1

    def recast_width(cls, width):
        # Map a bit width to the (signed) integer type of that width.
        type_map = {
            8: Int8,
            16: Int16,
            32: Int32,
            64: Int64,
            128: Int128,
        }
        if width not in type_map:
            raise TypeError(f"Unsupported width: {width}")
        return type_map[width]
540
+
541
+
542
class FloatMeta(NumericMeta):
    """Metaclass for floating-point types.

    This metaclass provides type system infrastructure for floating-point types in the DSL,
    handling MLIR type mappings and NumPy type conversions.

    :param width: Bit width of the float type, defaults to 32
    :type width: int
    :param mlir_type: Corresponding MLIR type, defaults to None
    :type mlir_type: Any, optional
    :param is_abstract: Whether this is an abstract base class, defaults to False
    :type is_abstract: bool, optional
    """

    _exponent_width: int
    _mantissa_width: int

    def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False):
        # Look up the numpy dtype by lowercased class name (Float32 -> np.float32);
        # narrow-precision names without a numpy counterpart resolve to None.
        np_dtype = getattr(np, name.lower(), None)
        new_cls = super().__new__(
            cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract
        )
        # Extract exponent and mantissa bits from class name if it follows Float<E><M> pattern
        # For example: Float8E4M3 -> exponent_width=4, mantissa_width=3
        import re

        if not is_abstract:
            match = re.match(r"Float(\d+)E(\d+)M(\d+)(?:.*)", name)
            if match:
                # group(1) is the total width, already passed in explicitly.
                exp_bits = int(match.group(2))
                mant_bits = int(match.group(3))

                # Store extracted values as class attributes
                new_cls._exponent_width = exp_bits
                new_cls._mantissa_width = mant_bits
        # Don't have 1-to-1 mapping of narrow precision types like bfloat16, tfloat32, etc.
        return new_cls

    def __str__(cls):
        return f"{cls.__name__}"

    @property
    def is_integer(cls) -> bool:
        return False

    @property
    def is_float(cls) -> bool:
        return True

    @property
    def zero(cls) -> float:
        return 0.0

    @property
    def inf(cls) -> float:
        return float("inf")

    @property
    def nan(cls) -> float:
        return float("nan")

    @property
    def exponent_width(cls) -> int:
        # Only set for non-abstract Float<W>E<E>M<M>-named types; accessing it
        # on other float types raises AttributeError.
        return cls._exponent_width

    @property
    def mantissa_width(cls) -> int:
        return cls._mantissa_width

    def recast_width(cls, width):
        # Map a bit width to the standard IEEE float type of that width.
        type_map = {
            16: Float16,
            32: Float32,
            64: Float64,
        }
        if width not in type_map:
            raise TypeError(f"Unsupported width: {width}")
        return type_map[width]
623
+
624
+
625
def _arith_signless_to_int(a, target_type):
    """Resize a signless MLIR integer value to the width of `target_type`.

    Widening sign-extends for signed targets (except from i1, which is always
    zero-extended), zero-extends otherwise; narrowing truncates; equal widths
    return the value unchanged.
    """
    src_width = a.type.width
    if target_type.width > src_width:
        # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL
        use_sign_extend = target_type.signed and src_width > 1
        if use_sign_extend:
            return arith.extsi(target_type.mlir_type, a)
        return arith.extui(target_type.mlir_type, a)
    if target_type.width < src_width:
        return arith.trunci(target_type.mlir_type, a)
    return a
637
+
638
+
639
def _binary_op_type_promote(a, b, promote_bool: bool = False):
    """Promote two numeric operands following type promotion rules.

    :param a: First numeric operand
    :type a: Numeric
    :param b: Second numeric operand
    :type b: Numeric
    :param promote_bool: Whether to promote boolean types to Int32 for arithmetic operations, defaults to False
    :type promote_bool: bool, optional
    :raises ValueError: If implicit float promotion is not supported between the given types
    :return: Tuple containing promoted operands and their resulting type
    :rtype: tuple[Numeric, Numeric, Type[Numeric]]

    Type promotion rules:
    1. If operands are same type and not bools needing promotion:
       - No promotion needed, return original types
    2. If either operand is float:
       a. If one is float and one is int:
          - Convert int to the float type
       b. If both are float:
          - Promote to higher precision float if width >= 16
          - For same width, promote to more general type (Float32 over TFloat32)
          - Otherwise raise ValueError for unsupported promotion
    3. Otherwise, both operands are integers. Integer promotion rules:
       a. If promote_bool is True and either operand is bool:
          - Promote bool to Int32 for arithmetic operations

    Exceptions for numpy dtype casting:
    - array(dtype=np.bool_) + array(dtype=np.bool_) -> array(dtype=np.bool_)

    What is not supported:
    - promotion with narrow precision float types which requires explicit cast by user
    """
    a_type = a.dtype
    b_type = b.dtype

    # Early return for same types (except when they're bools that need promotion)
    if a_type == b_type and not (promote_bool and a_type is Boolean):
        return a, b, a_type

    # Handle floating point promotions
    if a_type.is_float or b_type.is_float:
        # Capture widths BEFORE any int->float recast below; the width
        # comparisons that follow deliberately use these original values.
        a_width = getattr(a_type, "width", 0)
        b_width = getattr(b_type, "width", 0)

        # If one type is integer, convert it to the float type
        if a_type.is_float and not b_type.is_float:
            b_type = a_type.recast_width(max(a_width, b_width))
        elif b_type.is_float and not a_type.is_float:
            a_type = b_type.recast_width(max(a_width, b_width))

        # Both are float types - handle precision promotion
        if a_width > b_width and a_width >= 16:
            res_type = a_type
        elif b_width > a_width and b_width >= 16:
            res_type = b_type
        elif a_width == b_width:
            # Same bitwidth - handle special cases like TFloat32 -> Float32 and BFloat16 -> Float16
            if a_type is Float64 or b_type is Float64:
                res_type = Float64
            elif a_type is Float32 or b_type is Float32:
                res_type = Float32
            elif a_type is Float16 or b_type is Float16:
                res_type = Float16
            else:
                raise ValueError(
                    f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly"
                )
        else:
            # Wider operand is a sub-16-bit float: require an explicit cast.
            raise ValueError(
                f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly"
            )

        # Only convert if type is different
        new_a = a.to(res_type) if a.dtype != res_type else a
        new_b = b.to(res_type) if b.dtype != res_type else b
        return new_a, new_b, res_type

    # Handle bool promotion for arithmetic operations
    if promote_bool:
        if a_type is Boolean and b_type is Boolean:
            # Only promote to Int32 when both are bool
            a = a.to(Int32)
            b = b.to(Int32)
            a_type = b_type = a.dtype

        # If both were bools, they're now same type (Int32)
        if a_type == b_type:
            return a, b, a_type

    # Same type, no promotion needed
    if a_type == b_type:
        return a, b, a_type

    a_signed = a_type.signed
    b_signed = b_type.signed
    a_width = a_type.width
    b_width = b_type.width

    # Mixed signedness case
    if a_signed != b_signed:
        unsigned_type = a_type if not a_signed else b_type
        signed_type = a_type if a_signed else b_type
        unsigned_width = a_width if not a_signed else b_width

        if unsigned_width >= signed_type.width:
            # Promote both to unsigned of larger width
            res_type = unsigned_type
        else:
            # Promote both to signed of larger width
            res_type = signed_type

        new_a = a.to(res_type) if a.dtype != res_type else a
        new_b = b.to(res_type) if b.dtype != res_type else b
        return new_a, new_b, res_type

    # Same signedness, different width - promote to larger width
    if a_width >= b_width:
        return a, b.to(a.dtype), a.dtype
    else:
        return a.to(b.dtype), b, b.dtype
761
+
762
+
763
def _binary_op(op, promote_operand=True, promote_bool=False, flip=False):
    """Wrapper for binary operations on Numeric types.

    This wrapper handles type promotion, operation execution, and result type determination
    for binary operations between Numeric types.

    :param op: The binary operation to perform (e.g., operator.add, operator.sub)
    :type op: callable
    :param promote_operand: Whether to promote operands to the same type, defaults to True
    :type promote_operand: bool, optional
    :param promote_bool: Whether to promote boolean results to Boolean type, defaults to False
    :type promote_bool: bool, optional
    :param flip: Whether to flip the operands when calling the operation (used
        for reflected operators such as ``__rsub__``), defaults to False
    :type flip: bool, optional

    :raises TypeError: When an unsupported operation is attempted on specific numeric types

    .. note::
        Not all operations are supported for all numeric types. In particular:

        - Subtraction is not fully supported for Integer types
        - Multiplication, floor division, and modulo operations may have limited support
        - Division (truediv) with integer types is not fully supported and converts to Float32
    """

    def wrapper(lhs, rhs, *, loc=None, ip=None):
        orig_lhs_type = type(lhs)
        orig_rhs_type = type(rhs)

        # When called directly with self and other
        ty = type(lhs)
        # Canonicalize to Numeric type for promotion
        if not isinstance(rhs, Numeric):
            if not isinstance(rhs, (ArithValue, int, float, bool)):
                # This allows rhs class to implement __rmul__
                return NotImplemented

            if isinstance(rhs, ArithValue):
                if isinstance(rhs.type, ir.VectorType):
                    # Vector operands are handled by the vector side's
                    # reflected operator, not by scalar Numeric.
                    return NotImplemented

            rhs = as_numeric(rhs)

        # default result type to left-hand-side
        res_type = ty

        if promote_operand:
            lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool)
        else:
            rhs = ty(rhs)

        # Result-type overrides: comparisons always yield Boolean; integer
        # true-division yields Float32; bool op bool stays Boolean.
        if op in (
            operator.lt,
            operator.le,
            operator.gt,
            operator.ge,
            operator.eq,
            operator.ne,
        ):
            res_type = Boolean
        elif op == operator.truediv and isinstance(lhs, Integer):
            res_type = Float32
        elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean:
            res_type = Boolean

        # Tag dynamic integer MLIR values with their signedness so the arith
        # helpers emit the correct (signed/unsigned) op.
        if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer):
            lhs_val = lhs.value.with_signedness(lhs.signed)
        else:
            lhs_val = lhs.value

        if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer):
            rhs_val = rhs.value.with_signedness(rhs.signed)
        else:
            rhs_val = rhs.value

        if flip:
            lhs_val, rhs_val = rhs_val, lhs_val

        # Check if the operation is supported by the operands
        res_val = op(lhs_val, rhs_val)
        return res_type(res_val, loc=loc, ip=ip)

    return wrapper
848
+
849
+
850
+ class Numeric(metaclass=NumericMeta, is_abstract=True):
851
+ """Base class for all numeric types in the DSL.
852
+
853
+ This class provides the foundation for both Integer and Float types,
854
+ implementing basic arithmetic operations.
855
+
856
+ :param value: The value to store in the numeric type
857
+ :type value: Union[bool, int, float, Value]
858
+
859
+ :ivar value: The stored numeric value
860
+ :vartype value: Union[bool, int, float, Value]
861
+ """
862
+
863
+ def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None):
864
+ self.value = value
865
+
866
+ def __str__(self) -> str:
867
+ # Use member's pretty-str method if member object has method.
868
+ # This can be extended in future to have better support for IDE, jupyter notebook, etc.
869
+ pretty_str = getattr(self.value, "pretty_str", None)
870
+ if pretty_str is not None:
871
+ return pretty_str()
872
+ else:
873
+ return "?"
874
+
875
+ def __repr__(self) -> str:
876
+ return f"{self.__class__.__name__}({repr(self.value)})"
877
+
878
+ def __hash__(self):
879
+ return hash(type(self).__class__) ^ hash(self.value)
880
+
881
+ @property
882
+ def dtype(self) -> Type["Numeric"]:
883
+ return type(self)
884
+
885
+ @overload
886
+ def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ...
887
+
888
+ @overload
889
+ def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ...
890
+
891
+ @overload
892
+ def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ...
893
+
894
+ @overload
895
+ def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ...
896
+
897
+ @overload
898
+ def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ...
899
+
900
+ def to(self, dtype: Type, *, loc=None, ip=None):
901
+ """Convert this numeric value to another numeric type.
902
+
903
+ If the target type is the same as the current type, returns self.
904
+ Otherwise, creates a new instance of the target type with the same value.
905
+
906
+ :param dtype: The target numeric type to convert to
907
+ :type dtype: Union[Type["Numeric"], Type[int], Type[float], Type[bool]]
908
+ :return: A new instance of the target type, or self if types match
909
+ :rtype: Numeric
910
+ :raises TypeError: If trying to convert an MLIR value to a static Python type
911
+ :raises TypeError: If trying to convert to unsupported float types like Float8E4M3,
912
+ Float8E4M3B11FNUZ, Float4E2M1FN, Float6E3M2FN, or Float6E2M3FN
913
+
914
+ .. note::
915
+
916
+ Unsupported destination float types:
917
+ - Float8E4M3
918
+ - Float8E4M3B11FNUZ
919
+ - Float4E2M1FN
920
+ - Float6E3M2FN
921
+ - Float6E2M3FN
922
+
923
+ Example::
924
+
925
+ .. code-block:: python
926
+
927
+ # Convert between DSL numeric types
928
+ x = Int32(5)
929
+ y = x.to(Float32) # Converts to Float32(5.0)
930
+
931
+ # Convert to Python primitive types
932
+ # They are considered as static values at JIT time
933
+ z = x.to(int) # Returns Python int 5
934
+ w = y.to(float) # Returns Python float 5.0
935
+
936
+ # This will raise a ValueError
937
+ mlir_val = arith.constant(T.i32(), 42)
938
+ num = Int32(mlir_val)
939
+ num.to(int) # ValueError: unable to convert MLIR value to static type: <class 'int'>
940
+ """
941
+ if dtype in _unsupported_dst_float_types:
942
+ raise TypeError(f"Unsupported destination float type: {dtype}")
943
+
944
+ if isinstance(dtype, type(self)):
945
+ return self
946
+ elif isinstance(dtype, NumericMeta):
947
+ return dtype(self)
948
+ elif dtype is ir.Value:
949
+ if isinstance(self.value, (int, float, bool)):
950
+ res = arith_helper.const(
951
+ self.value, self.dtype.mlir_type, loc=loc, ip=ip
952
+ )
953
+ elif isinstance(self.value, ir.Value):
954
+ res = self.value
955
+ else:
956
+ raise ValueError(
957
+ f"cannot convert {type(self)} to {dtype}, "
958
+ f"self.value is {self.value.type}"
959
+ )
960
+
961
+ if not isinstance(res, ArithValue):
962
+ raise ValueError(f"Expected ArithValue, got {type(res)} as {res.type}")
963
+
964
+ return res.with_signedness(getattr(type(self), "signed", None))
965
+ elif dtype in (int, float, bool):
966
+ if isinstance(self.value, ir.Value):
967
+ raise ValueError(
968
+ f"unable to convert {self.value} to static type: {dtype}"
969
+ )
970
+ return dtype(self.value)
971
+ else:
972
+ raise ValueError(f"unable to convert {type(self)} to {dtype}")
973
+
974
+ def ir_value(self, *, loc=None, ip=None) -> ir.Value:
975
+ return self.to(ir.Value, loc=loc, ip=ip)
976
+
977
    @property
    def zero(self) -> "Numeric": ...
    # Stub: concrete types expose `zero` through their metaclass property.

    def __dsl_not__(self, *, loc=None, ip=None):
        """DSL implementation of Python's `not` operator.

        Returns True if the value is equal to zero, False otherwise.
        This matches Python's behavior where any non-zero number is considered True.

        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: The result of the logical not operation
        :rtype: Boolean
        """
        if isinstance(self.value, (int, float, bool)):
            # Static value: fold at trace time.
            return not self.value
        else:
            # Dynamic value: emit `self == 0` in IR.
            ty = type(self)
            zero_val = arith.constant(ty.mlir_type, ty.zero)
            return self.__eq__(ty(zero_val), loc=loc, ip=ip)
999
+
1000
    def __dsl_and__(self, other, *, loc=None, ip=None):
        """DSL implementation of Python's `and` operator.

        Returns the second operand if the first is truthy, otherwise returns the first operand.
        A numeric value is considered truthy if it is non-zero.

        :param other: The right-hand operand
        :type other: Numeric
        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: The result of the logical and operation
        :rtype: Boolean

        Example::

            5 and 3 -> 3
            0 and 3 -> 0
            3 and 0 and ... -> 0
        """
        # Evaluate truthiness of the left operand once, up front.
        is_true = self.__dsl_bool__(loc=loc, ip=ip)

        def and_op(lhs, rhs):
            # Static x static folds in Python; otherwise materialize the
            # static side as a constant and select(lhs_truthy, rhs, lhs),
            # matching Python's short-circuit result.
            if isinstance(lhs, (int, float, bool)):
                if isinstance(rhs, (int, float, bool)):
                    return lhs and rhs
                else:
                    lhs = arith.constant(rhs.type, lhs)
                    return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip)
            else:
                if isinstance(rhs, (int, float, bool)):
                    rhs = arith.constant(lhs.type, rhs)
                    return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip)
                else:
                    return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip)

        return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip)
1038
+
1039
    def __dsl_or__(self, other, *, loc=None, ip=None):
        """DSL implementation of Python's `or` operator.

        Returns the first operand if it is truthy, otherwise returns the second operand.
        A numeric value is considered truthy if it is non-zero.

        :param other: The right-hand operand
        :type other: Numeric
        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: The result of the logical or operation
        :rtype: Boolean

        Example::

            5 or 3 -> 5
            0 or 3 -> 3
            3 or 0 -> 3
        """
        # Evaluate truthiness of the left operand once, up front.
        is_true = self.__dsl_bool__(loc=loc, ip=ip)

        def or_op(lhs, rhs):
            # Static x static folds in Python; otherwise materialize the
            # static side as a constant and select(lhs_truthy, lhs, rhs) —
            # note the operand order is the mirror of `and`.
            if isinstance(lhs, (int, float, bool)):
                if isinstance(rhs, (int, float, bool)):
                    return lhs or rhs
                else:
                    lhs = arith.constant(rhs.type, lhs)
                    return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)
            else:
                if isinstance(rhs, (int, float, bool)):
                    rhs = arith.constant(lhs.type, rhs)
                    return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)
                else:
                    return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)

        return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip)
1077
+
1078
+ def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean":
1079
+ """DSL implementation of Python's __bool__ method.
1080
+
1081
+ Returns a Boolean indicating whether this value is considered truthy.
1082
+ For numeric types, returns True if the value is non-zero.
1083
+
1084
+ :param loc: The source location information, defaults to None
1085
+ :type loc: Optional[Location]
1086
+ :param ip: The insertion point for the operation, defaults to None
1087
+ :type ip: Optional[InsertionPoint]
1088
+ :return: True if this value is truthy (non-zero), False otherwise
1089
+ :rtype: Boolean
1090
+ """
1091
+ zero = type(self).zero
1092
+ return self.__ne__(zero, loc=loc, ip=ip)
1093
+
1094
+ def __bool__(self):
1095
+ if isinstance(self.value, (int, float, bool)):
1096
+ return bool(self.value)
1097
+ else:
1098
+ raise DSLRuntimeError(
1099
+ f"Unable to convert dynamic `{type(self).__name__}` value to bool at compile time.",
1100
+ suggestion=[
1101
+ "Decorate the parent function with `jit` decorator and with `preprocess` enabled.",
1102
+ "Ensure not using patterns that DSL does not support.",
1103
+ "Otherwise, please file a bug report.",
1104
+ ],
1105
+ )
1106
+
1107
+ def __index__(self):
1108
+ if isinstance(self.value, (int, float, bool)):
1109
+ return self.value
1110
+ else:
1111
+ raise DSLRuntimeError(
1112
+ f"'{type(self.value)}' object cannot be interpreted as an integer",
1113
+ suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator",
1114
+ )
1115
+
1116
+ def __neg__(self, *, loc=None, ip=None):
1117
+ if isinstance(self, (bool, int, float)):
1118
+ return type(self)(-self.value) # type: ignore
1119
+ else:
1120
+ return type(self)(-self.value, loc=loc, ip=ip) # type: ignore
1121
+
1122
+ @staticmethod
1123
+ def _from_python_value(value):
1124
+ if isinstance(value, Numeric):
1125
+ return value
1126
+
1127
+ if isinstance(value, bool):
1128
+ res_type = Boolean
1129
+ elif isinstance(value, int):
1130
+ res_type = Int32
1131
+ elif isinstance(value, float):
1132
+ res_type = Float32
1133
+ elif isinstance(value, ArithValue):
1134
+ res_type = Numeric.from_mlir_type(value.type)
1135
+ else:
1136
+ raise ValueError(
1137
+ f"unable to convert {value} in type {type(value)} to Numeric"
1138
+ )
1139
+ return res_type(value)
1140
+
1141
    # -- Arithmetic operators ------------------------------------------------
    # All arithmetic delegates to `_binary_op` (defined elsewhere), which
    # handles operand conversion/promotion.  `promote_bool=True` presumably
    # promotes Boolean operands to an integer type before arithmetic, and the
    # reflected variants pass `flip=True` so the operation is applied as
    # `other <op> self` — confirm against `_binary_op`'s definition.

    def __add__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __sub__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __mul__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.floordiv, promote_bool=True)(
            self, other, loc=loc, ip=ip
        )

    def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.truediv, promote_bool=True)(
            self, other, loc=loc, ip=ip
        )

    def __mod__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __radd__(self, other, *, loc=None, ip=None) -> "Numeric":
        # Addition is commutative: reuse __add__ directly.
        return self.__add__(other, loc=loc, ip=ip)

    def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.sub, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric":
        # Multiplication is commutative: reuse __mul__ directly.
        return self.__mul__(other, loc=loc, ip=ip)

    def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.floordiv, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.truediv, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.mod, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )
1188
+
1189
    # -- Comparison operators and power ---------------------------------------
    # Comparisons also go through `_binary_op` (no bool promotion); their
    # annotated result type is Boolean.

    def __eq__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.eq)(self, other, loc=loc, ip=ip)  # type: ignore

    def __ne__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.ne)(self, other, loc=loc, ip=ip)  # type: ignore

    def __lt__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.lt)(self, other, loc=loc, ip=ip)  # type: ignore

    def __le__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.le)(self, other, loc=loc, ip=ip)  # type: ignore

    def __gt__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.gt)(self, other, loc=loc, ip=ip)  # type: ignore

    def __ge__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.ge)(self, other, loc=loc, ip=ip)  # type: ignore

    def __pow__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.pow)(self, other, loc=loc, ip=ip)  # type: ignore
1209
+
1210
+ def __c_pointers__(self):
1211
+ raise ValueError(
1212
+ f"only support built-in types: bool, (u)int{8, 16, 32, 64}, float{32, 64}, but got {type(self)}"
1213
+ )
1214
+
1215
    def __get_mlir_types__(self):
        """Return the MLIR type(s) this value contributes as JIT arguments."""
        return [type(self).mlir_type]
1217
+
1218
    @staticmethod
    def from_mlir_type(mlir_type):
        """Map an MLIR scalar type to the corresponding Numeric subclass.

        Both signless (``T.i32``) and signed (``T.si32``) MLIR integers map to
        the signed DSL types; only explicitly-unsigned (``T.ui32``) map to the
        unsigned ones.  The table is rebuilt on every call — presumably
        because the ``T.*`` constructors need a live MLIR context; confirm
        before attempting to cache it.

        :raises DSLRuntimeError: if ``mlir_type`` has no DSL equivalent.
        """
        type_map = {
            T.bool(): Boolean,
            T.f64(): Float64,
            T.f32(): Float32,
            T.tf32(): TFloat32,
            T.f16(): Float16,
            T.bf16(): BFloat16,
            T.i(128): Int128,
            T.i64(): Int64,
            T.i32(): Int32,
            T.i16(): Int16,
            T.i8(): Int8,
            T.si(128): Int128,
            T.si64(): Int64,
            T.si32(): Int32,
            T.si16(): Int16,
            T.si8(): Int8,
            T.ui(128): Uint128,
            T.ui64(): Uint64,
            T.ui32(): Uint32,
            T.ui16(): Uint16,
            T.ui8(): Uint8,
            T.f8E5M2(): Float8E5M2,
            T.f8E4M3(): Float8E4M3,
            T.f8E4M3FN(): Float8E4M3FN,
            T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ,
            T.f4E2M1FN(): Float4E2M1FN,
            T.f6E2M3FN(): Float6E2M3FN,
            T.f6E3M2FN(): Float6E3M2FN,
            T.f8E8M0FNU(): Float8E8M0FNU,
        }
        if mlir_type not in type_map:
            raise DSLRuntimeError(f"Unsupported DSL type: {mlir_type}")
        return type_map[mlir_type]
1254
+
1255
+
1256
def as_numeric(obj: Union[bool, int, float, ir.Value, Numeric]) -> Numeric:
    """Coerce a Python primitive (or MLIR value) into a :class:`Numeric`.

    Inputs that are already ``Numeric`` are returned untouched; everything
    else is routed through ``Numeric._from_python_value``.

    :param obj: value to convert
    :type obj: Union[bool, int, float]
    :return: the converted Numeric object
    :rtype: Numeric

    Example::

    .. code-block:: python

        x = as_numeric(5)      # Converts to Int32
        y = as_numeric(3.14)   # Converts to Float32
        z = as_numeric(True)   # Converts to Boolean
    """
    return obj if isinstance(obj, Numeric) else Numeric._from_python_value(obj)
1275
+
1276
+
1277
class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True):
    """A class representing integer values with specific width and signedness.

    Supports conversion from Python scalars, MLIR Values, and other numeric
    types.

    :param x: The input value to convert to this integer type
    :type x: Union[bool, int, float, ir.Value, Integer, Float]
    :return: A new Integer instance with the converted value
    :rtype: Integer

    :raises AssertionError: If the type's numpy_dtype is None
    :raises ValueError: If the input type is not supported, or for float NaN
    :raises OverflowError: If converting float infinity to integer

    Type conversion behavior:

    * Python scalars (bool, int, float): converted through numpy dtype
      casting; NaN and infinity are rejected.  Overflow wraps C-style
      (e.g. ``Int8(256)``).
    * MLIR Value with IntegerType: width mismatches are handled by
      signless -> signed/unsigned conversion.
    * MLIR Value with FloatType: uses MLIR float-to-int conversion; NaN and
      infinity are undefined behavior.
    * Integer: MLIR int-to-int conversion (dynamic values) or numpy dtype
      casting (host values).
    * Float: converted via the wrapped value (MLIR float-to-int or Python).

    Example usage:

    .. code-block:: python

        x = Int32(5)       # From integer
        y = Int32(True)    # From boolean
        z = Int32(3.7)     # From float (truncates)
        w = Int32(x)       # From same Integer type
        c5 = arith.constant(5, T.i32())
        a = Int32(c5)      # Treat c5 as int32 bitwise
    """

    def __init__(self, x, *, loc=None, ip=None):
        ty = type(self)

        if isinstance(x, (bool, int, float)):
            # Reject float values with no integer counterpart before handing
            # off to numpy (whose handling of them is dtype-dependent).
            if isinstance(x, float):
                if np.isnan(x):
                    raise ValueError("Cannot convert float NaN to integer")
                elif np.isinf(x):
                    raise OverflowError("Cannot convert float infinity to integer")

            np_dtype = ty.numpy_dtype
            assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}"
            # numpy casting gives C-style wraparound on overflow.
            x_val = int(np.array(x).astype(np_dtype))
        elif type(x) == ty:
            # Same concrete type: reuse the wrapped value as-is.
            x_val = x.value
        elif isinstance(x, ir.Value):  # type: ignore
            x_val = x
            if isinstance(x.type, ir.IntegerType):  # type: ignore
                if x.type.width != ty.width:
                    # signless -> (u)int
                    x_val = _arith_signless_to_int(x, ty)
            elif isinstance(x.type, ir.FloatType):  # type: ignore
                # float -> (u)int
                x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip)
        elif isinstance(x, Integer):
            if isinstance(x.value, ir.Value):
                # Dynamic value: emit an MLIR int-to-int conversion.
                x_val = arith_helper.int_to_int(x.ir_value(), ty)
            else:
                # For non-MLIR (host) values, use numpy casting.
                src_val = np.array(x.value, dtype=type(x).numpy_dtype)
                x_val = int(src_val.astype(ty.numpy_dtype))
        elif isinstance(x, Float):
            # float -> int is handled by Integer.__init__ recursively on the
            # wrapped value (a Python float or an MLIR float Value).
            Integer.__init__(self, x.value)
            return
        else:
            raise DSLRuntimeError(f"{x} to integer conversion is not supported")

        super().__init__(x_val)

    def __invert__(self, *, loc=None, ip=None):
        """Bitwise NOT: materialize as IR and invert, preserving the type."""
        res_type = type(self)
        return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip))

    def __lshift__(self, other, *, loc=None, ip=None):
        return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip)

    def __rlshift__(self, other, *, loc=None, ip=None):
        # `other << self`: only integer left-hand operands make sense.
        other_ = as_numeric(other)
        if not isinstance(other_, Integer):
            raise ValueError(f"Cannot left shift {other_} with {self}")
        return other_.__lshift__(self, loc=loc, ip=ip)

    def __rshift__(self, other, *, loc=None, ip=None):
        return _binary_op(operator.rshift)(self, other, loc=loc, ip=ip)

    def __rrshift__(self, other, *, loc=None, ip=None):
        # `other >> self`: only integer left-hand operands make sense.
        other_ = as_numeric(other)
        if not isinstance(other_, Integer):
            raise ValueError(f"Cannot right shift {other_} with {self}")
        return other_.__rshift__(self, loc=loc, ip=ip)

    def __and__(self, other, *, loc=None, ip=None):
        return _binary_op(operator.and_)(self, other, loc=loc, ip=ip)

    def __rand__(self, other, *, loc=None, ip=None):
        # Bitwise AND is commutative.
        return self.__and__(other, loc=loc, ip=ip)

    def __or__(self, other, *, loc=None, ip=None):
        return _binary_op(operator.or_)(self, other, loc=loc, ip=ip)

    def __ror__(self, other, *, loc=None, ip=None):
        # Bitwise OR is commutative.
        return self.__or__(other, loc=loc, ip=ip)

    def __xor__(self, other, *, loc=None, ip=None):
        return _binary_op(operator.xor)(self, other, loc=loc, ip=ip)

    def __rxor__(self, other, *, loc=None, ip=None):
        # Bitwise XOR is commutative.
        return self.__xor__(other, loc=loc, ip=ip)
1412
+
1413
class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True):
    """A class representing floating-point values.

    :param x: The input value to convert to this float type.
    :type x: Union[bool, int, float, ir.Value, Integer, Float]

    Type conversion behavior:

    1. Python scalars (bool, int, float): stored as plain Python floats.
    2. MLIR Value with FloatType: converted between float types when they
       differ (e.g. f16 -> f32).
    3. MLIR Value with IntegerType: not supported, raises.
    4. Integer: MLIR int-to-float (dynamic values) or Python ``float()``
       (host values).
    5. Float: converted via the wrapped value.

    .. note::
        The narrow precision types (Float8E5M2, Float8E4M3, Float8E4M3FN,
        Float8E8M0FNU, Float8E4M3B11FNUZ, Float6E3M2FN, Float6E2M3FN,
        Float4E2M1FN) are only supported in device code.

    :raises DSLRuntimeError: If conversion from the input type is not supported
    """

    def __init__(self, x, *, loc=None, ip=None):
        ty = type(self)

        if isinstance(x, (bool, int, float)):  # type: ignore
            # Host scalars are stored as plain Python floats; no numpy
            # round-trip through the target dtype is performed here.
            super().__init__(float(x))
        elif isinstance(x, ir.Value):  # type: ignore
            if isinstance(x.type, ir.IntegerType):  # type: ignore
                raise DSLRuntimeError("signless to float conversion is not implemented")
            elif isinstance(x.type, ir.FloatType):  # type: ignore
                if x.type != ty.mlir_type:
                    # Convert between MLIR float types (extend/truncate).
                    x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip)
                super().__init__(x)
            # NOTE(review): an ir.Value of any other MLIR type falls through
            # without initializing the instance — confirm this is intended.
        elif isinstance(x, Integer):
            if isinstance(x.value, ir.Value):  # type: ignore
                x = arith_helper.itofp(
                    x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip
                )
            else:
                x = float(x.value)
            super().__init__(x)
        elif isinstance(x, Float):
            # Re-dispatch on the wrapped value (handles host and IR values).
            Float.__init__(self, x.value)
        else:
            raise DSLRuntimeError(f"{x} to Float conversion is not supported")
1491
+
1492
+
1493
class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.bool):
    """Boolean type representation in the DSL (a 1-bit integer).

    :param a: Value to convert to Boolean
    :type a: Union[bool, int, float, "Value", Numeric]
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location], optional
    :param ip: Insertion point for MLIR operations, defaults to None
    :type ip: Optional[InsertionPoint], optional
    :raises DSLRuntimeError: If the input value cannot be converted to Boolean

    Conversion rules:

    1. Python bool/int/float: converted with Python's ``bool()``.
    2. Numeric: unwrapped via ``.value`` and converted recursively.
    3. ArithValue of type i1: used directly.
    4. Any other ArithValue: compared against a zero constant of its type.
    """

    def __init__(
        self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None
    ):
        value = None
        if isinstance(a, (bool, int, float)):
            value = bool(a)
        elif isinstance(a, Numeric):
            # Recurse on the wrapped value; the recursive call completes the
            # full initialization, so return immediately.
            Boolean.__init__(self, a.value, loc=loc, ip=ip)
            return
        elif isinstance(a, ArithValue):
            if a.type == T.bool():
                value = a
            else:
                # Non-i1 value: truthiness is "compares unequal to zero".
                value = a != arith_helper.const(0, a.type, loc=loc, ip=ip)
        if value is None:
            # NOTE(review): a plain ir.Value that is not an ArithValue also
            # lands here, despite being admitted by the annotation — confirm.
            raise DSLRuntimeError(f"Cannot convert {a} to Boolean")
        super().__init__(value, loc=loc, ip=ip)
        # Lazily populated cache for the int8 representation (ir_value_int8).
        self._value_int8 = None

    def ir_value_int8(self, *, loc=None, ip=None):
        """
        Returns int8 ir value of Boolean.
        When we need to store Boolean tensor element, use ir_value_int8().

        :param loc: Source location information, defaults to None
        :type loc: Optional[Location], optional
        :param ip: Insertion point for MLIR operations, defaults to None
        :type ip: Optional[InsertionPoint], optional
        :return: The int8 value of this Boolean
        :rtype: ir.Value
        """
        # Cache: converting to int8 emits IR, so do it at most once.
        if self._value_int8 is not None:
            return self._value_int8
        self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value()
        return self._value_int8

    def __neg__(self, *, loc=None, ip=None):
        """Negation operator is not supported for boolean type.

        :raises TypeError: Always raises this error as negation is not supported
        """
        raise TypeError("Negation, the operator `-` is not supported for boolean type")
1571
+
1572
+
1573
# ---------------------------------------------------------------------------
# Concrete integer types.
# Unsigned variants share the signless MLIR types (T.i8/T.i16/...), so MLIR
# signedness is tracked on the DSL side via the `signed` keyword.
# The 128-bit types build their MLIR type lazily (lambda) since `T` exposes
# no i128 shorthand.
# ---------------------------------------------------------------------------


class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8): ...


class Int16(Integer, metaclass=IntegerMeta, width=16, signed=True, mlir_type=T.i16): ...


class Int32(Integer, metaclass=IntegerMeta, width=32, signed=True, mlir_type=T.i32): ...


class Int64(Integer, metaclass=IntegerMeta, width=64, signed=True, mlir_type=T.i64): ...


class Int128(
    Integer, metaclass=IntegerMeta, width=128, signed=True, mlir_type=lambda: T.i(128)
): ...


class Uint8(Integer, metaclass=IntegerMeta, width=8, signed=False, mlir_type=T.i8): ...


class Uint16(
    Integer, metaclass=IntegerMeta, width=16, signed=False, mlir_type=T.i16
): ...


class Uint32(
    Integer, metaclass=IntegerMeta, width=32, signed=False, mlir_type=T.i32
): ...


class Uint64(
    Integer, metaclass=IntegerMeta, width=64, signed=False, mlir_type=T.i64
): ...


class Uint128(
    Integer, metaclass=IntegerMeta, width=128, signed=False, mlir_type=lambda: T.i(128)
): ...
1611
+
1612
+
1613
class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64):
    """64-bit IEEE-754 float."""

    def __c_pointers__(self):
        """Box the host value in a C double and return it as a void pointer."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")

        boxed = ctypes.c_double(self.value)
        return [ctypes.cast(ctypes.pointer(boxed), ctypes.c_void_p)]
1621
+
1622
+
1623
class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32):
    """32-bit IEEE-754 float."""

    @staticmethod
    def _get_c_pointer(value: float):
        """Box *value* in a C float and return it as a void pointer."""
        boxed = ctypes.c_float(value)
        return ctypes.cast(ctypes.pointer(boxed), ctypes.c_void_p)

    def __c_pointers__(self):
        """Return the ctypes pointer list used to marshal this host value."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")

        return [Float32._get_c_pointer(self.value)]
1633
+
1634
+
1635
class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32):
    """TensorFloat-32; host-side scalars are carried as ordinary C floats."""

    def __c_pointers__(self):
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        # TF32 has no distinct host representation; reuse Float32's boxing.
        return [Float32._get_c_pointer(self.value)]
1640
+
1641
+
1642
class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16):
    """16-bit IEEE-754 half-precision float."""

    @staticmethod
    def _get_c_pointer(value: float):
        # Convert float to float16 binary representation.
        # First convert to numpy float16 to handle rounding/truncation.
        f16_val = np.float16(value)
        # Reinterpret the raw bits as a 16-bit integer.
        bits = f16_val.view(np.uint16)
        # Box the bits in a C short; ctypes integer types do no overflow
        # checking, so the unsigned bit pattern is stored verbatim.
        c_val = ctypes.c_short(bits)
        return ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p)

    def __c_pointers__(self):
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        return [Float16._get_c_pointer(self.value)]
1658
+
1659
+
1660
class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16):
    def __c_pointers__(self):
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")

        # NOTE(review): `Float` defines no `__c_pointers__` of its own, so this
        # resolves to `Numeric.__c_pointers__`, which always raises ValueError.
        # Presumably host-side passing of bf16 scalars is intentionally
        # unsupported — confirm.
        return Float.__c_pointers__(self)
1666
+
1667
+
1668
# Narrow-precision float types (device-code only; see Float's class docstring).


class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2): ...


class Float8E4M3FN(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3FN): ...


class Float8E4M3B11FNUZ(
    Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3B11FNUZ
): ...


class Float8E4M3(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3): ...


class Float8E8M0FNU(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E8M0FNU): ...


class Float4E2M1FN(Float, metaclass=FloatMeta, width=4, mlir_type=T.f4E2M1FN): ...


class Float6E3M2FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E3M2FN): ...


class Float6E2M3FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E2M3FN): ...


# Float formats that conversions may not target.
# NOTE(review): presumably consumed by conversion helpers elsewhere — confirm.
_unsupported_dst_float_types = [
    Float8E4M3,
    Float8E4M3B11FNUZ,
    Float4E2M1FN,
    Float6E3M2FN,
    Float6E2M3FN,
]
1703
+
1704
+
1705
# Registry of every concrete DSL scalar type, plus the name -> type lookup
# table used by `dtype()` below.
ALL_DTYPES = {
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    Uint8,
    Uint16,
    Uint32,
    Uint64,
    Uint128,
    BFloat16,
    Float16,
    Float32,
    TFloat32,
    Float64,
    Float8E5M2,
    Float8E4M3,
    Float8E4M3FN,
    Float8E8M0FNU,
    Float8E4M3B11FNUZ,
    Float4E2M1FN,
    Float6E2M3FN,
    Float6E3M2FN,
}
__STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES}
1731
+
1732
+
1733
def dtype(dtype_) -> Type[Numeric]:
    """Resolve a dtype specifier to a :class:`Numeric` subclass.

    :param dtype_: either a ``Numeric`` subclass (returned unchanged) or the
        name of a DSL dtype (e.g. ``"Float32"``, see ``ALL_DTYPES``).
    :return: the resolved dtype class
    :rtype: Type[Numeric]
    :raises TypeError: if ``dtype_`` cannot be interpreted as a data type.
    """
    # Generalization: a Numeric subclass already is a dtype — pass it through
    # (previously this raised). Backward compatible for all string inputs.
    if isinstance(dtype_, type) and issubclass(dtype_, Numeric):
        return dtype_
    if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__):
        return __STR_TO_DTYPE__[dtype_]
    raise TypeError(f"can't interpret {dtype_} as data type")
1741
+
1742
+
1743
+ ##############################################################
1744
+ # Tensor
1745
+ ##############################################################
1746
+
1747
+
1748
class TensorMeta(DslType):
    """Metaclass for tensor types parameterized by element type and shape.

    Examples:
        >>> Tensor[Int32, (3,)]
        >>> Tensor[Float32, (3, 4)]
        >>> T = TypeVar("T")
        >>> Tensor[T, (3, 4, 5)]
    """

    # Bug fix: this docstring previously appeared *after* the attribute
    # assignments, making it a dead string literal instead of the class
    # docstring; it is now the first statement in the class body.

    # Defaults for unparameterized tensor types.
    _element_type = Any
    _shape = Any

    def __new__(cls, name, bases, attrs, element_type=Any, shape=Any):
        new_cls = super().__new__(cls, name, bases, attrs)
        new_cls._element_type = element_type
        new_cls._shape = shape
        return new_cls
1765
+
1766
+
1767
# Generic type parameter used by the annotation helper classes below.
TY = TypeVar("TY")


class Constexpr(Generic[TY]):
    """Annotation marker: value is passed and computed by the Python interpreter."""

    pass
1775
+
1776
+
1777
class align:
    """Pointer-alignment annotation; the value must be a positive power of two."""

    def __init__(self, value: int):
        # A positive power of two has exactly one bit set.
        is_pow2 = value > 0 and (value & (value - 1)) == 0
        if not is_pow2:
            raise DSLRuntimeError("expects align be power of 2 as positive value")
        self._value = value

    def __str__(self):
        return f"align({self._value})"
1786
+
1787
class PointerMeta(DslType):
    """Metaclass that parameterizes `Pointer` by pointee type and alignment."""

    def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)):
        # The MLIR type is built lazily (lambda) from an unranked memref of
        # the pointee type — presumably so it is only constructed under a
        # live MLIR context; confirm against DslType.
        new_cls = super().__new__(
            cls,
            name,
            bases,
            attrs,
            mlir_type=lambda: getattr(ir, "UnrankedMemRefType").get(
                value_type.mlir_type, getattr(ir, "Attribute").parse("0")
            ),
        )
        new_cls._value_type = value_type
        new_cls._align = align_
        return new_cls

    def __eq__(cls, other):
        # Two pointer types are equal iff pointee type and alignment agree.
        if not isinstance(other, PointerMeta):
            return False
        return (
            cls._value_type == other._value_type
            and cls._align._value == other._align._value
        )  # Compare alignment values

    def __hash__(cls):
        # Must stay consistent with __eq__ above.
        return hash((cls._value_type, cls._align._value))  # Hash alignment value

    def __getitem__(cls, params) -> Type["Pointer"]:
        # Supports the `Pointer[value_type, align(n)]` subscription syntax.
        value_type, align_ = params

        if not isinstance(align_, align):
            raise DSLRuntimeError(f"expects align but got {align_}")

        # Create new class with proper name and parameters
        new_cls = type(
            f"Pointer[{value_type.__name__}, {align_}]",
            (Pointer,),
            {},
            value_type=value_type,
            align_=align_,  # Pass alignment to __new__
        )
        return new_cls

    def __str__(cls):
        return f"ptr<{cls._value_type}, {cls._align}>"
1831
+
1832
+
1833
class Pointer(metaclass=PointerMeta):
    """
    A pointer to a memory location, parameterized by pointee type and
    alignment via subscription.

    Examples:

        def foo(a : Pointer[Int32, align=8]):
            ...

    """

    def __init__(self, value):
        # The raw value being pointed at / wrapped.
        self.value = value

    def __str__(self):
        return "{} : {}".format(self.value, type(self))
1849
+
1850
+
1851
class IRConst(Generic[TY]):
    """Annotation marker: value is passed as an MLIR constant (arith.constant)."""

    def __init__(self, ty: TY):
        self.ty = ty


class IRValue(Generic[TY]):
    """Annotation marker: value is passed as an MLIR dynamic value."""

    def __init__(self, ty: TY):
        self.ty = ty
1863
+
1864
+
1865
class IRVariadic:
    """Helper carrying a variadic number of operands into a function."""

    def __init__(self, operands):
        """Wrap ``operands``; each entry must be a dynamic (MLIR) value."""
        self.operands = operands

    def block_arg_types(self):
        """Return the block-argument type of every wrapped operand, in order."""
        return [op.type for op in self.operands]

    def set_func_args(self, block_args):
        """Hook called after entering a function with the block arguments that
        correspond to the passed operands.

        Derived classes may override this to provide convenience getters for
        the block arguments; the base implementation intentionally does
        nothing.
        """
        pass

    def __len__(self):
        """Number of wrapped operands."""
        return len(self.operands)
1896
+
1897
+
1898
class FuncArgWithAttr(IRValue):
    """
    An IRValue for a func-op argument that additionally carries an MLIR
    attribute (name plus type and/or value).
    """

    def __init__(self, ty, attr_name, attr_ty, attr_value=None):
        """
        :param ty: the argument's MLIR type (forwarded to IRValue)
        :param attr_name: name of the attribute; must not be None
        :param attr_ty: attribute type; may be None only if attr_value is given
        :param attr_value: attribute value; may be None only if attr_ty is given
        :raises ValueError: if attr_name is None, or both attr_ty and
            attr_value are None
        """
        super().__init__(ty)
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would silently admit invalid arguments.
        if attr_name is None or (attr_ty is None and attr_value is None):
            raise ValueError(
                "Invalid attr_name and/or attr_ty and/or attr_value for FuncArgWithAttr"
            )
        self.attr_name = attr_name
        self.attr_ty = attr_ty
        self.attr_value = attr_value
1911
+
1912
+
1913
+
1914
def implicitDowncastNumericType(value):
    """Lower a ``Numeric`` wrapper to its underlying IR value; anything else
    is passed through unchanged."""
    return value.ir_value() if isinstance(value, Numeric) else value
1918
+
1919
+
1920
# Public API of this module.
# Bug fix: "Float8E4M3" was previously listed twice.
__all__ = [
    "DslType",
    "Numeric",
    "NumericMeta",
    "IntegerMeta",
    "FloatMeta",
    "Boolean",
    "Integer",
    "Int16",
    "Int32",
    "Int64",
    "Int128",
    "Int8",
    "Uint8",
    "Uint16",
    "Uint32",
    "Uint64",
    "Uint128",
    "Float",
    "Float16",
    "BFloat16",
    "TFloat32",
    "Float32",
    "Float64",
    "Float8E5M2",
    "Float8E4M3",
    "Float8E4M3FN",
    "Float8E4M3B11FNUZ",
    "Float8E8M0FNU",
    "Float4E2M1FN",
    "Float6E2M3FN",
    "Float6E3M2FN",
    "as_numeric",
    "align",
    "Pointer",
    "dtype",
    "Constexpr",
    "IRConst",
    "IRValue",
    "IRVariadic",
    "implicitDowncastNumericType",
]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
# Re-export the utility submodules (stacktrace, logger, timer) so callers can
# write e.g. `from base_dsl.utils import logger`.
from . import stacktrace
from . import logger
from . import timer

__all__ = [
    "logger",
    "timer",
    "stacktrace",
]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides logging helper functions
14
+ """
15
+
16
+ import logging
17
+
18
# Module-level logger; replaced by setup_log() and initialized to a quiet
# "generic" logger at the bottom of this module.
logger = None


def log():
    """Return the current module-level logger instance."""
    return logger
23
+
24
+
25
+ def setup_log(
26
+ name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1
27
+ ):
28
+ """Set up and configure a logger with console and/or file handlers.
29
+
30
+ :param name: Name of the logger to create
31
+ :type name: str
32
+ :param log_to_console: Whether to enable logging to console, defaults to False
33
+ :type log_to_console: bool, optional
34
+ :param log_to_file: Whether to enable logging to file, defaults to False
35
+ :type log_to_file: bool, optional
36
+ :param log_file_path: Path to the log file, required if log_to_file is True
37
+ :type log_file_path: str, optional
38
+ :param log_level: Logging level to set, defaults to 1
39
+ :type log_level: int, optional
40
+ :raises ValueError: If log_to_file is True but log_file_path is not provided
41
+ :return: Configured logger instance
42
+ :rtype: logging.Logger
43
+ """
44
+ # Create a custom logger
45
+ global logger
46
+ logger = logging.getLogger(name)
47
+ if log_to_console or log_to_file:
48
+ logger.setLevel(log_level)
49
+ else:
50
+ # Makes sure logging is OFF
51
+ logger.setLevel(logging.CRITICAL + 1)
52
+
53
+ # Clear existing handlers to prevent duplicate logs
54
+ if logger.hasHandlers():
55
+ logger.handlers.clear()
56
+
57
+ # Define formatter
58
+ formatter = logging.Formatter(
59
+ f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s"
60
+ )
61
+
62
+ # Add console handler if enabled
63
+ if log_to_console:
64
+ console_handler = logging.StreamHandler()
65
+ console_handler.setLevel(log_level)
66
+ console_handler.setFormatter(formatter)
67
+ logger.addHandler(console_handler)
68
+
69
+ # Add file handler if enabled
70
+ if log_to_file:
71
+ if not log_file_path:
72
+ raise ValueError("log_file_path must be provided when enable_file is True")
73
+ file_handler = logging.FileHandler(log_file_path)
74
+ file_handler.setLevel(log_level)
75
+ file_handler.setFormatter(formatter)
76
+ logger.addHandler(file_handler)
77
+
78
+ return logger
79
+
80
+
81
+ logger = setup_log("generic")
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides stacktrace helper functions
14
+ """
15
+
16
+ import os
17
+ import re
18
+
19
+
20
def walk_to_top_module(start_path):
    """
    Walk up from the start_path to find the top-level Python module.

    Ascends while the current directory contains an ``__init__.py`` and stops
    at the first directory that does not (or at the filesystem root).

    :param start_path: The path to start from.
    :return: The path of the top-level module, or None if the walk reached the
        filesystem root without leaving a package.
    """
    path = start_path
    while os.path.dirname(path) != path:  # not yet at the filesystem root
        if not os.path.isfile(os.path.join(path, "__init__.py")):
            # Not (or no longer) inside a package: stop here.
            break
        path = os.path.dirname(path)

    # Reached the root without finding a non-package directory: no module.
    if os.path.dirname(path) == path and not os.path.isfile(
        os.path.join(path, "__init__.py")
    ):
        return None

    return path
51
+
52
+
53
def _filter_internal_frames(traceback, internal_path):
    """
    Filter out stack frames from the traceback that belong to the specified module path.

    Frames whose absolute file path starts with ``internal_path`` are unlinked
    from the ``tb_next`` chain, hiding internal implementation details from the
    error traceback shown to users.

    :param traceback: head of a traceback chain (linked via ``tb_next``)
    :param internal_path: absolute path prefix identifying internal frames
    :return: the (possibly new) head of the filtered traceback
    """
    iter_prev = None  # last frame that was kept
    iter_tb = traceback
    while iter_tb is not None:
        if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith(
            internal_path
        ):
            # Internal frame: splice it out of the chain.  The final frame
            # (tb_next is None) is left in place so the raise site stays
            # visible.
            if iter_tb.tb_next:
                if iter_prev:
                    iter_prev.tb_next = iter_tb.tb_next
                else:
                    # Removing the head frame: advance the returned head.
                    traceback = iter_tb.tb_next
        else:
            iter_prev = iter_tb
        iter_tb = iter_tb.tb_next
    return traceback
76
+
77
+
78
# Names of functions synthesized by the DSL's AST pre-processor for
# control-flow regions (loop bodies, while/if/elif regions and their blocks).
# Frames bearing these names are treated as generated when de-duplicating
# tracebacks below.
_generated_function_names = re.compile(
    r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$"
)
81
+
82
+
83
def _filter_duplicated_frames(traceback):
    """
    Filter out duplicated stack frames from the traceback.
    The function filters out consecutive frames that are in the same file and have the same line number.
    In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame.

    :param traceback: The traceback object to filter (mutated in place).
    :return: The (possibly new) head of the de-duplicated traceback chain.
    """
    iter_prev = None
    iter_tb = traceback
    while iter_tb is not None:
        skip_current = False
        skip_next = False
        if iter_tb.tb_next:
            current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename)
            next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename)
            # if in the same file, check if the line number is the same
            if current_filename == next_filename:
                current_lineno = iter_tb.tb_lineno
                next_lineno = iter_tb.tb_next.tb_lineno
                if current_lineno == next_lineno:
                    # Same file and line number, check name, if current is generated, skip current, otherwise skip next
                    name = iter_tb.tb_frame.f_code.co_name
                    is_generated = bool(_generated_function_names.match(name))
                    if is_generated:
                        # Skip current
                        skip_current = True
                    else:
                        # Skip next if it's generated, otherwise keep both
                        next_name = iter_tb.tb_next.tb_frame.f_code.co_name
                        skip_next = bool(_generated_function_names.match(next_name))
        if skip_current:
            # Unlink the current frame; re-root the chain when it was the head.
            if iter_prev:
                iter_prev.tb_next = iter_tb.tb_next
            else:
                traceback = iter_tb.tb_next
        elif skip_next:
            # if next is last frame, don't skip
            if iter_tb.tb_next.tb_next:
                iter_tb.tb_next = iter_tb.tb_next.tb_next
            iter_prev = iter_tb
        else:
            iter_prev = iter_tb
        iter_tb = iter_tb.tb_next

    return traceback
127
+
128
+
129
def filter_stackframe(traceback, prefix_path):
    """
    Filter out stack frames from the traceback that belong to the specified module path.

    Frames whose file path starts with ``prefix_path`` are removed first,
    hiding internal implementation details from the error traceback shown to
    users. Consecutive frames pointing at the same file and line are then
    consolidated so generated wrapper frames do not clutter the output.

    :param traceback: The traceback object to filter.
    :param prefix_path: The path prefix to filter out from the traceback.
    :return: The filtered traceback with internal frames removed.
    """
    # Step 1: drop frames originating from the internal module path.
    without_internal = _filter_internal_frames(traceback, prefix_path)
    # Step 2: collapse consecutive duplicated frames left behind.
    return _filter_duplicated_frames(without_internal)
146
+
147
+
148
def filter_exception(value, module_dir):
    """
    Filter out internal implementation details from exception traceback.

    This function recursively processes an exception and its cause chain,
    removing stack frames that belong to the specified module directory.
    This helps to present cleaner error messages to users by hiding
    implementation details.

    :param value: The exception object to filter (modified in place).
    :param module_dir: The module directory path to filter out from tracebacks.
    :return: None. The exception's tracebacks are updated in place.
    """
    # Clean the cause chain first so chained exceptions are filtered as well.
    if hasattr(value, "__cause__") and value.__cause__:
        filter_exception(value.__cause__, module_dir)

    if hasattr(value, "__traceback__"):
        # filter_stackframe returns a (possibly new) head when leading frames
        # are dropped; re-attach it so internal frames at the head of the
        # chain are not silently kept (previously the return value was
        # discarded, losing head-of-chain filtering).
        value.__traceback__ = filter_stackframe(value.__traceback__, module_dir)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ """
13
+ This module provides a timing helper functions
14
+ """
15
+ from functools import wraps
16
+
17
+ from .logger import log
18
+
19
+
20
# TODO: revisit this part when mlir timing manager is ready for pybind.
def timer(*dargs, **kwargs):
    """
    Decorator that logs the execution time of the wrapped callable.

    Supports both usage forms::

        @timer
        def f(): ...

        @timer(enable=False)
        def g(): ...

    :param enable: Keyword flag; when False the wrapped callable is invoked
                   directly with no timing overhead (default True).
    :return: The decorated callable, or a decorator when called with arguments.
    """
    enable = kwargs.get("enable", True)

    def decorator(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            if not enable:
                return func(*args, **kwargs)
            # perf_counter is monotonic and high-resolution; time.time() is
            # wall-clock and can jump with system clock adjustments, which
            # would corrupt the measured interval.
            from time import perf_counter

            start = perf_counter()
            result = func(*args, **kwargs)
            end = perf_counter()

            # Convert time from seconds to us
            spend_us = (end - start) * 1e6

            # Determine the function type and format the log message
            if hasattr(func, "__name__"):
                func_name = func.__name__
                log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs"
            elif "CFunctionType" in str(type(func)):
                log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs"
            else:
                log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs"

            log().info(log_message)

            return result

        return func_wrapper

    # Bare @timer form: the function itself arrives as the single positional
    # argument and is decorated directly.
    if len(dargs) == 1 and callable(dargs[0]):
        return decorator(dargs[0])
    else:
        return decorator
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
# NOTE: `all_` and `any_` were previously listed twice in this import; the
# duplicates were redundant and have been removed.
from .cutlass_dsl import (
    Constexpr,
    as_numeric,
    min,
    max,
    and_,
    or_,
    all_,
    any_,
    not_,
    select_,
    # Control-flow without AST pre-processor
    if_generate,
    for_generate,
    LoopUnroll,
    while_generate,
    yield_out,
    # Control-flow with AST pre-processor
    range_constexpr,
    range_dynamic,
    const_expr,
    dynamic_expr,
    # Data types
    dtype,  # Provides conversions to types inheriting from NumericType
    DSLRuntimeError,
    JitArgAdapterRegistry,
    # Construction utilities for user-defined classes
    extract_mlir_values,
    new_from_mlir_values,
)
44
+
45
+ from .cute.typing import *
46
+
47
+ # Utilities not belonging to CuTe
48
+ from . import utils as utils
49
+
50
+ # Used as internal symbol
51
+ from . import cutlass_dsl as _dsl
52
+
53
+ # Aliases
54
+ LaunchConfig = _dsl.BaseDSL.LaunchConfig
55
+ register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
56
+ gpu = _dsl.cutlass_gpu
57
+ cuda = _dsl.cuda_helpers
58
+
59
+ CACHE_FILE = "compiled_cache.db"
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ # Use the auto-generated enum AddressSpace
13
+ from cutlass._mlir.dialects.cute import AddressSpace
14
+
15
+ # Explicitly import types that might be directly used by other modules.
16
+ # This is a fix for using Sphinx to generate documentation
17
+ # Because Sphinx processes each module in isolation, it won't be able to rely
18
+ # on re-exported symbols via wildcard imports (from .typing import *) in the
19
+ # same way that Python does at runtime.
20
+ from .typing import (
21
+ Shape,
22
+ Stride,
23
+ IntTuple,
24
+ Coord,
25
+ Tile,
26
+ XTuple,
27
+ Tiler,
28
+ Layout,
29
+ Pointer,
30
+ Tensor,
31
+ )
32
+
33
+ # Import everything else
34
+ from .typing import *
35
+
36
+ from .core import (
37
+ assume,
38
+ is_integer,
39
+ is_int_tuple,
40
+ is_static,
41
+ size,
42
+ has_underscore,
43
+ slice_,
44
+ make_ptr,
45
+ make_layout,
46
+ recast_layout,
47
+ make_fragment_like,
48
+ depth,
49
+ rank,
50
+ flatten_to_tuple,
51
+ flatten,
52
+ unflatten,
53
+ product,
54
+ product_like,
55
+ shape,
56
+ size_in_bytes,
57
+ make_identity_layout,
58
+ make_ordered_layout,
59
+ make_composed_layout,
60
+ make_layout_tv,
61
+ make_swizzle,
62
+ recast_ptr,
63
+ make_tensor,
64
+ make_identity_tensor,
65
+ make_fragment,
66
+ recast_tensor,
67
+ get,
68
+ select,
69
+ front,
70
+ is_major,
71
+ leading_dim,
72
+ find,
73
+ find_if,
74
+ coalesce,
75
+ group_modes,
76
+ cosize,
77
+ dice,
78
+ product_each,
79
+ prepend,
80
+ append,
81
+ prepend_ones,
82
+ append_ones,
83
+ ceil_div,
84
+ slice_and_offset,
85
+ crd2idx,
86
+ domain_offset,
87
+ elem_less,
88
+ transform_leaf,
89
+ filter_zeros,
90
+ filter,
91
+ tile_to_shape,
92
+ shape_div,
93
+ composition,
94
+ complement,
95
+ right_inverse,
96
+ left_inverse,
97
+ max_common_layout,
98
+ max_common_vector,
99
+ logical_product,
100
+ zipped_product,
101
+ tiled_product,
102
+ flat_product,
103
+ raked_product,
104
+ blocked_product,
105
+ flat_divide,
106
+ logical_divide,
107
+ zipped_divide,
108
+ tiled_divide,
109
+ local_partition,
110
+ local_tile,
111
+ printf,
112
+ print_tensor,
113
+ # tiled mma/tiled copy
114
+ make_mma_atom,
115
+ make_tiled_mma,
116
+ make_copy_atom,
117
+ make_tiled_copy_tv,
118
+ make_tiled_copy,
119
+ make_tiled_copy_S,
120
+ make_tiled_copy_D,
121
+ make_tiled_copy_A,
122
+ make_tiled_copy_B,
123
+ make_tiled_copy_C,
124
+ make_tiled_copy_C_atom,
125
+ basic_copy,
126
+ basic_copy_if,
127
+ autovec_copy,
128
+ copy,
129
+ copy_atom_call,
130
+ gemm,
131
+ # Wrapper classes
132
+ ComposedLayout,
133
+ Swizzle,
134
+ E,
135
+ Atom,
136
+ MmaAtom,
137
+ CopyAtom,
138
+ TiledCopy,
139
+ TiledMma,
140
+ TensorSSA,
141
+ ReductionOp,
142
+ full,
143
+ full_like,
144
+ empty_like,
145
+ ones_like,
146
+ zeros_like,
147
+ where,
148
+ any_,
149
+ all_,
150
+ # User defined struct
151
+ struct,
152
+ pretty_str,
153
+ make_layout_image_mask,
154
+ repeat_like,
155
+ round_up,
156
+ is_congruent,
157
+ is_weakly_congruent,
158
+ ScaledBasis,
159
+ get_divisibility,
160
+ Ratio,
161
+ )
162
+
163
+ from . import arch
164
+ from . import nvgpu
165
+ from . import testing
166
+ from . import runtime
167
+
168
+ # Export all math ops without "math."
169
+ from .math import *
170
+
171
+ # Used as internal symbol
172
+ from .. import cutlass_dsl as _dsl
173
+
174
+ # Aliases
175
+ jit = _dsl.CuTeDSL.jit
176
+ kernel = _dsl.CuTeDSL.kernel
177
+ register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
178
+ compile = _dsl.compile
179
+
180
+ # Explicitly export all symbols for documentation generation
181
+ __all__ = [
182
+ # Core types
183
+ "AddressSpace",
184
+ "Tensor",
185
+ "Layout",
186
+ "ComposedLayout",
187
+ "Swizzle",
188
+ "E",
189
+ "Atom",
190
+ "MmaAtom",
191
+ "CopyAtom",
192
+ "TiledCopy",
193
+ "TiledMma",
194
+ "TensorSSA",
195
+ # Basic utility functions
196
+ "assume",
197
+ "is_integer",
198
+ "is_int_tuple",
199
+ "is_static",
200
+ "size",
201
+ "has_underscore",
202
+ "slice_",
203
+ "depth",
204
+ "rank",
205
+ "shape",
206
+ "printf",
207
+ "print_tensor",
208
+ "pretty_str",
209
+ # Layout functions
210
+ "make_layout",
211
+ "recast_layout",
212
+ "make_identity_layout",
213
+ "make_ordered_layout",
214
+ "make_composed_layout",
215
+ "make_layout_tv",
216
+ "make_layout_image_mask",
217
+ # Tensor functions
218
+ "make_ptr",
219
+ "make_tensor",
220
+ "make_identity_tensor",
221
+ "make_fragment",
222
+ "make_fragment_like",
223
+ "recast_ptr",
224
+ "recast_tensor",
225
+ # Tensor manipulation
226
+ "get",
227
+ "select",
228
+ "front",
229
+ "is_major",
230
+ "leading_dim",
231
+ "find",
232
+ "find_if",
233
+ "coalesce",
234
+ "group_modes",
235
+ "cosize",
236
+ "size_in_bytes",
237
+ # Tuple operations
238
+ "flatten_to_tuple",
239
+ "flatten",
240
+ "product",
241
+ "product_like",
242
+ "product_each",
243
+ "prepend",
244
+ "append",
245
+ "prepend_ones",
246
+ "append_ones",
247
+ # Math operations
248
+ "ceil_div",
249
+ "round_up",
250
+ # Layout operations
251
+ "slice_and_offset",
252
+ "crd2idx",
253
+ "domain_offset",
254
+ "elem_less",
255
+ "filter_zeros",
256
+ "filter",
257
+ "tile_to_shape",
258
+ "shape_div",
259
+ "dice",
260
+ # Layout algebra
261
+ "composition",
262
+ "complement",
263
+ "right_inverse",
264
+ "left_inverse",
265
+ "max_common_layout",
266
+ "max_common_vector",
267
+ "is_congruent",
268
+ "is_weakly_congruent",
269
+ # Product operations
270
+ "logical_product",
271
+ "zipped_product",
272
+ "tiled_product",
273
+ "flat_product",
274
+ "raked_product",
275
+ "blocked_product",
276
+ # Division operations
277
+ "flat_divide",
278
+ "logical_divide",
279
+ "zipped_divide",
280
+ "tiled_divide",
281
+ "local_partition",
282
+ "local_tile",
283
+ # MMA and Copy operations
284
+ "make_mma_atom",
285
+ "make_tiled_mma",
286
+ "make_copy_atom",
287
+ "make_tiled_copy_tv",
288
+ "make_tiled_copy",
289
+ "make_tiled_copy_C_atom",
290
+ "basic_copy",
291
+ "basic_copy_if",
292
+ "autovec_copy",
293
+ "copy",
294
+ "copy_atom_call",
295
+ "gemm",
296
+ # Tensor creation
297
+ "full",
298
+ "full_like",
299
+ "empty_like",
300
+ "ones_like",
301
+ "zeros_like",
302
+ "where",
303
+ "any_",
304
+ "all_",
305
+ "repeat_like",
306
+ "ScaledBasis",
307
+ # User defined struct
308
+ "struct",
309
+ # Modules
310
+ "arch",
311
+ "nvgpu",
312
+ "testing",
313
+ "runtime",
314
+ # Decorators and code generation
315
+ "jit",
316
+ "kernel",
317
+ "register_jit_arg_adapter",
318
+ "compile",
319
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from .elect import *
13
+ from .mbar import *
14
+ from .nvvm_wrappers import *
15
+ from .smem import *
16
+ from .tmem import *
17
+
18
+ # __all__ is required here for documentation generation
19
+ __all__ = [
20
+ #
21
+ # elect.py
22
+ #
23
+ "make_warp_uniform",
24
+ "elect_one",
25
+ #
26
+ # mbar.py
27
+ #
28
+ "mbarrier_init",
29
+ "mbarrier_init_fence",
30
+ "mbarrier_arrive_and_expect_tx",
31
+ "mbarrier_expect_tx",
32
+ "mbarrier_wait",
33
+ "mbarrier_try_wait",
34
+ "mbarrier_conditional_try_wait",
35
+ "mbarrier_arrive",
36
+ #
37
+ # nvvm_wrappers.py
38
+ #
39
+ "lane_idx",
40
+ "warp_idx",
41
+ "thread_idx",
42
+ "block_dim",
43
+ "block_idx",
44
+ "grid_dim",
45
+ "cluster_idx",
46
+ "cluster_dim",
47
+ "block_in_cluster_idx",
48
+ "block_in_cluster_dim",
49
+ "block_idx_in_cluster",
50
+ "shuffle_sync",
51
+ "shuffle_sync_up",
52
+ "shuffle_sync_down",
53
+ "shuffle_sync_bfly",
54
+ "barrier",
55
+ "barrier_arrive",
56
+ "sync_threads",
57
+ "sync_warp",
58
+ "fence_acq_rel_cta",
59
+ "fence_acq_rel_cluster",
60
+ "fence_acq_rel_gpu",
61
+ "fence_acq_rel_sys",
62
+ "cp_async_commit_group",
63
+ "cp_async_wait_group",
64
+ "cp_async_bulk_commit_group",
65
+ "cp_async_bulk_wait_group",
66
+ "cluster_wait",
67
+ "cluster_arrive",
68
+ "cluster_arrive_relaxed",
69
+ "fence_proxy",
70
+ "vote_ballot_sync",
71
+ "popc",
72
+ "fence_view_async_tmem_load",
73
+ "fence_view_async_tmem_store",
74
+ "warpgroup_reg_alloc",
75
+ "warpgroup_reg_dealloc",
76
+ "fma_packed_f32x2",
77
+ "mul_packed_f32x2",
78
+ "add_packed_f32x2",
79
+ "fmax",
80
+ "rcp_approx",
81
+ "exp2",
82
+ # Constants
83
+ "WARP_SIZE",
84
+ # Forward from auto-generated nvvm python
85
+ "ProxyKind",
86
+ "SharedSpace",
87
+ "RoundingModeKind",
88
+ #
89
+ # smem.py
90
+ #
91
+ "alloc_smem",
92
+ "get_dyn_smem",
93
+ "get_dyn_smem_size",
94
+ #
95
+ # tmem.py
96
+ #
97
+ "retrieve_tmem_ptr",
98
+ "alloc_tmem",
99
+ "relinquish_tmem_alloc_permit",
100
+ "dealloc_tmem",
101
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
13
+
14
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
15
+ from cutlass._mlir.dialects import nvvm, scf
16
+ from cutlass._mlir import ir
17
+
18
+ from ..typing import Int, Int32
19
+ from ...impl_utils import check_value_in
20
+
21
+
22
@dsl_user_op
def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32:
    """
    Creates a warp-uniform value from the given integer input.

    :param value: The integer to make warp uniform.
    :type value: Int
    :return: The warp-uniform value equal to the input.
    :rtype: Int32
    """
    # Normalize the input to an Int32 IR value before handing it to the op,
    # then wrap the resulting IR value back into the DSL's Int32 type.
    ir_input = Int32(value).ir_value(loc=loc, ip=ip)
    uniform = _cute_nvgpu_ir.arch_make_warp_uniform(ir_input, loc=loc, ip=ip)
    return Int32(uniform)
37
+
38
+
39
class IfOpRegion:
    """
    A context manager for if Op.
    Automatically inserts `scf.yield([])` when exiting the context.
    """

    def __init__(self, block, *, loc=None, ip=None):
        # The MLIR block forming the body of the region being populated.
        self.block = block
        # Insertion point placed at the block so ops created inside the
        # `with` body are emitted into the region.
        self.insert_point = ir.InsertionPoint(self.block)
        self.loc = loc
        self.ip = ip

    def __enter__(self):
        self.insert_point.__enter__()
        # Expose the block arguments for `with ... as args` bindings.
        return self.block.arguments

    def __exit__(self, exc_type, exc_value, traceback):
        # Terminate the region with an empty yield before restoring the
        # previous insertion point.
        scf.yield_([], loc=self.loc, ip=self.ip)
        self.insert_point.__exit__(exc_type, exc_value, traceback)
58
+
59
+
60
@dsl_user_op
def elect_one(*, loc=None, ip=None) -> IfOpRegion:
    """
    Elects one thread within a warp.

    .. code-block:: python

        with elect_one():
            # Only one thread in the warp executes the code in this context
            pass
    """
    # This intrinsic is only available on Hopper/Blackwell targets.
    arch = CuTeDSL._get_dsl().envar.arch
    supported_archs = ["sm_90", "sm_90a", "sm_100a", "sm_100f"]
    check_value_in(arch, supported_archs, "arch")

    # elect.sync yields true on exactly one active lane; guard the region
    # body with an scf.if predicated on that leader flag.
    is_thread_leader = nvvm.elect_sync(T.bool())
    if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
    return IfOpRegion(if_op.then_block, loc=loc, ip=ip)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+ from typing import Optional
12
+
13
+ from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
14
+
15
+ from cutlass._mlir.dialects import nvvm
16
+ from cutlass._mlir import ir
17
+
18
+ from ..typing import Pointer, Int, Boolean, Int32
19
+ from ...impl_utils import check_value_in
20
+
21
+
22
+ ####################################################################################################
23
+ #
24
+ # Mbarrier management utilities
25
+ #
26
+ ####################################################################################################
27
+
28
+
29
@dsl_user_op
def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
    """
    Initializes a mbarrier with the specified thread arrival count.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param cnt: The arrival count of the mbarrier
    :type cnt: Int
    """
    # Lower directly to the NVVM shared-memory mbarrier init op; the count
    # is normalized to an Int32 IR value first.
    nvvm.mbarrier_init_shared(
        mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
    )
42
+
43
+
44
@dsl_user_op
def mbarrier_init_fence(*, loc=None, ip=None) -> None:
    """
    A fence operation that applies to the mbarrier initializations.
    """
    # mbarrier fencing is only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )
    nvvm.fence_mbarrier_init(loc=loc, ip=ip)
61
+
62
+
63
@dsl_user_op
def mbarrier_arrive_and_expect_tx(
    mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
    """
    Arrives on a mbarrier and expects a specified number of transaction bytes.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param bytes: The number of transaction bytes
    :type bytes: Int
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer CTA's
                                     SMEM.
    """
    # mbarrier transactions are only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is not None:
        # Map the local SMEM address to the peer CTA's address space so the
        # transaction lands on the remote barrier; scope becomes CLUSTER.
        mbar_llvm_ptr = nvvm.mapa_shared_cluster(
            mbar_llvm_ptr.type,
            mbar_llvm_ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER
    else:
        space = nvvm.MBarrierSpaceKind.CTA

    nvvm.mbarrier_txn(
        mbar_llvm_ptr,
        Int32(bytes).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX,
        space=space,
        loc=loc,
        ip=ip,
    )
111
+
112
+
113
@dsl_user_op
def mbarrier_expect_tx(
    mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
    """
    Expects a specified number of transaction bytes without an arrive.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param bytes: The number of transaction bytes
    :type bytes: Int
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer CTA's
                                     SMEM.
    """
    # mbarrier transactions are only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is not None:
        # Map the local SMEM address to the peer CTA's address space, matching
        # the sibling functions mbarrier_arrive_and_expect_tx / mbarrier_arrive
        # which use mapa_shared_cluster for this identical mapping (this
        # previously used plain nvvm.mapa despite setting CLUSTER space).
        mbar_llvm_ptr = nvvm.mapa_shared_cluster(
            mbar_llvm_ptr.type,
            mbar_llvm_ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER
    else:
        space = nvvm.MBarrierSpaceKind.CTA

    nvvm.mbarrier_txn(
        mbar_llvm_ptr,
        Int32(bytes).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.EXPECT_TX,
        space=space,
        loc=loc,
        ip=ip,
    )
161
+
162
+
163
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
    """
    Waits on a mbarrier with a specified phase.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param phase: The phase to wait for (either 0 or 1)
    :type phase: Int
    """
    # mbarrier waits are only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    timeout_ns = 10000000
    # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
    # The timeout in ns only applies to the latter and this call is truly blocking
    nvvm.mbarrier_try_wait_parity_shared(
        mbar_ptr.llvm_ptr,
        Int32(phase).ir_value(loc=loc, ip=ip),
        Int32(timeout_ns).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
195
+
196
+
197
@dsl_user_op
def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean:
    """
    Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param phase: The phase to wait for (either 0 or 1)
    :type phase: Int
    :return: A boolean value indicating whether the wait operation was successful
    :rtype: Boolean
    """
    # mbarrier waits are only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    # TRY kind returns immediately with a success flag instead of spinning.
    return Boolean(
        nvvm.mbarrier_wait_parity(
            T.bool(),
            mbar_ptr.llvm_ptr,
            Int32(phase).ir_value(loc=loc, ip=ip),
            nvvm.MBarrierWaitKind.TRY,
            loc=loc,
            ip=ip,
        )
    )
231
+
232
+
233
@dsl_user_op
def mbarrier_conditional_try_wait(
    cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None
) -> Boolean:
    """
    Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion.

    :param cond: A boolean predicate
    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param phase: The phase to wait for (either 0 or 1)
    :type phase: Int
    :return: A boolean value indicating whether the wait operation was successful
    :rtype: Boolean
    """
    # mbarrier waits are only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )
    # Emit an IR-level conditional: when cond is false the wait is skipped
    # and the result defaults to True (treated as "already satisfied").
    return if_generate(
        cond,
        lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
        lambda: Boolean(True).ir_value(loc=loc, ip=ip),
        None,
        [Boolean],
    )
266
+
267
+
268
@dsl_user_op
def mbarrier_arrive(
    mbar_ptr: Pointer,
    peer_cta_rank_in_cluster: Optional[Int] = None,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Arrives on an mbarrier.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer CTA's
                                     SMEM.
    """
    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is not None:
        # Cluster-scoped arrives are only supported on these architectures;
        # the arch check is done lazily, only when a peer rank is given.
        arch = CuTeDSL._get_dsl().envar.arch
        check_value_in(
            arch,
            [
                "sm_90",
                "sm_90a",
                "sm_100a",
                "sm_100f",
            ],
            "arch",
        )

        # Map the local SMEM address to the peer CTA's address space so the
        # arrive lands on the remote barrier; scope becomes CLUSTER.
        mbar_llvm_ptr = nvvm.mapa_shared_cluster(
            mbar_llvm_ptr.type,
            mbar_llvm_ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER
    else:
        space = nvvm.MBarrierSpaceKind.CTA

    # A plain arrive is a transaction with a count contribution of 1.
    nvvm.mbarrier_txn(
        mbar_llvm_ptr,
        Int32(1).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.ARRIVE,
        space=space,
        loc=loc,
        ip=ip,
    )
318
+
319
+
320
@dsl_user_op
def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None:
    """
    Arrives on an mbarrier for async load **without incrementing** the arrival count
    (`cp.async.mbarrier.arrive.shared ..., noinc=1`).
    Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same
    as the math/epilogue warp(consumer).

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    """
    # cp.async mbarrier arrives are only supported on these architectures.
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        arch,
        [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ],
        "arch",
    )

    mbar_llvm_ptr = mbar_ptr.llvm_ptr
    # noinc=True: signal completion of outstanding cp.async ops without
    # contributing to the barrier's arrival count.
    nvvm.cp_async_mbarrier_arrive_shared(
        mbar_llvm_ptr,
        noinc=True,
        loc=loc,
        ip=ip,
    )
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from functools import partial
13
+ from typing import Optional, Tuple, Union, Callable
14
+ from typing_extensions import deprecated
15
+
16
+ from cutlass.cutlass_dsl import T, dsl_user_op
17
+
18
+ from cutlass._mlir import ir
19
+ from cutlass._mlir.dialects import llvm, nvvm, vector
20
+
21
+ # Forward nvvm enums
22
+ from cutlass._mlir.dialects.nvvm import (
23
+ ProxyKind,
24
+ SharedSpace,
25
+ Tcgen05WaitKind,
26
+ SetMaxRegisterAction,
27
+ RoundingModeKind,
28
+ )
29
+
30
+ from ..typing import (
31
+ Int,
32
+ Boolean,
33
+ Int16,
34
+ Uint16,
35
+ Int32,
36
+ Uint32,
37
+ Int64,
38
+ Float32,
39
+ BFloat16,
40
+ Numeric,
41
+ as_numeric,
42
+ )
43
+
44
# Number of threads in a hardware warp.
WARP_SIZE = 32
# All-lanes membership mask for warp-synchronous operations (32 set bits).
FULL_MASK = 0xFFFFFFFF
46
+
47
+
48
@dsl_user_op
def lane_idx(*, loc=None, ip=None) -> Int32:
    """Return the lane index of the calling thread within its warp (``laneid`` register)."""
    raw = nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip)
    return Int32(raw)
54
+
55
+
56
@dsl_user_op
def warp_idx(*, loc=None, ip=None) -> Int32:
    """
    Returns the warp index within a CTA.

    The 3-D thread index is linearized (x fastest-varying) and divided by the
    warp size.
    """
    tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip))
    tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip))
    tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip))
    ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip))
    ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip))
    tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y
    # Use the module-level WARP_SIZE constant instead of a local magic number,
    # for consistency with the rest of this file.
    return tid // WARP_SIZE
69
+
70
+
71
@dsl_user_op
def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the (x, y, z) thread index of the calling thread within its CTA."""
    readers = (
        nvvm.read_ptx_sreg_tid_x,
        nvvm.read_ptx_sreg_tid_y,
        nvvm.read_ptx_sreg_tid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
81
+
82
+
83
@dsl_user_op
def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the number of threads in each (x, y, z) dimension of the CTA."""
    readers = (
        nvvm.read_ptx_sreg_ntid_x,
        nvvm.read_ptx_sreg_ntid_y,
        nvvm.read_ptx_sreg_ntid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
93
+
94
+
95
@dsl_user_op
def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the (x, y, z) CTA identifier within the grid."""
    readers = (
        nvvm.read_ptx_sreg_ctaid_x,
        nvvm.read_ptx_sreg_ctaid_y,
        nvvm.read_ptx_sreg_ctaid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
105
+
106
+
107
@dsl_user_op
def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the number of CTAs in each (x, y, z) dimension of the grid."""
    readers = (
        nvvm.read_ptx_sreg_nctaid_x,
        nvvm.read_ptx_sreg_nctaid_y,
        nvvm.read_ptx_sreg_nctaid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
117
+
118
+
119
@dsl_user_op
def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the (x, y, z) cluster identifier within the grid."""
    readers = (
        nvvm.read_ptx_sreg_clusterid_x,
        nvvm.read_ptx_sreg_clusterid_y,
        nvvm.read_ptx_sreg_clusterid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
129
+
130
+
131
@dsl_user_op
def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the number of clusters in each (x, y, z) dimension of the grid."""
    readers = (
        nvvm.read_ptx_sreg_nclusterid_x,
        nvvm.read_ptx_sreg_nclusterid_y,
        nvvm.read_ptx_sreg_nclusterid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
141
+
142
+
143
@dsl_user_op
def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the (x, y, z) CTA index within its cluster."""
    readers = (
        nvvm.read_ptx_sreg_cluster_ctaid_x,
        nvvm.read_ptx_sreg_cluster_ctaid_y,
        nvvm.read_ptx_sreg_cluster_ctaid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
153
+
154
+
155
@dsl_user_op
def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """Return the (x, y, z) dimensions of the cluster in CTAs."""
    readers = (
        nvvm.read_ptx_sreg_cluster_nctaid_x,
        nvvm.read_ptx_sreg_cluster_nctaid_y,
        nvvm.read_ptx_sreg_cluster_nctaid_z,
    )
    x, y, z = (Int32(read(T.i32(), loc=loc, ip=ip)) for read in readers)
    return x, y, z
165
+
166
+
167
@dsl_user_op
def block_idx_in_cluster(*, loc=None, ip=None) -> Int32:
    """Return the linearized identifier (rank) of the CTA within its cluster."""
    rank = nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip)
    return Int32(rank)
173
+
174
+
175
@dsl_user_op
def shuffle_sync_op(
    value: Numeric,
    offset: Int,
    mask: Int = FULL_MASK,
    mask_and_clamp: Int = WARP_SIZE - 1,
    kind: nvvm.ShflKind = nvvm.ShflKind.idx,
    *,
    loc=None,
    ip=None,
) -> Numeric:
    """
    Shuffles a value within the threads of a warp.

    :param value: The value to shuffle
    :type value: Numeric
    :param mask: A mask describing the threads participating in this operation
    :type mask: Int
    :param offset: A source lane or a source lane offset depending on kind
    :type offset: Int
    :param mask_and_clamp: An integer containing two packed values specifying a mask for logically
                           splitting warps into sub-segments and an upper bound for clamping the
                           source lane index.
    :type mask_and_clamp: Int
    :param kind: The kind of shuffle, can be idx, up, down, or bfly
    :type kind: ShflKind
    :return: The shuffled value
    :rtype: Numeric
    """
    if not isinstance(value, Numeric):
        value = as_numeric(value)
    if value.width > 64:
        raise ValueError("shuffle_sync only supports values up to 64 bits")

    orig_type = type(value)
    if value.width <= 32:
        # Sub-32-bit values are widened to a 32-bit type of the same kind so a
        # single shfl.sync handles them; the result is narrowed back to the
        # original type below. (Previously the <32 and ==32 paths duplicated
        # the identical shfl_sync call; they are merged here.)
        if value.width < 32:
            if value.dtype.is_float:
                value = value.to(Float32)
            elif value.signed:
                value = value.to(Int32)
            else:
                value = value.to(Uint32)
        return orig_type(
            nvvm.shfl_sync(
                type(value).mlir_type,
                Int32(mask).ir_value(loc=loc, ip=ip),
                value.ir_value(loc=loc, ip=ip),
                Int32(offset).ir_value(loc=loc, ip=ip),
                Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
                kind,
                loc=loc,
                ip=ip,
            )
        )

    if value.width != 64:
        raise ValueError(
            "shuffle_sync only supports 64 bits values when the bit width is larger than 32"
        )
    # 64-bit values are shuffled as two independent 32-bit halves.
    value = llvm.bitcast(
        T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip
    )
    # extract low 32 bits
    low_32_bits = llvm.trunc(
        T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
    )
    # extract high 32 bits
    high_32_bits = llvm.lshr(
        value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
    )
    high_32_bits = llvm.trunc(
        T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
    )

    low_32_bits_shfl = nvvm.shfl_sync(
        T.i32(),
        Int32(mask).ir_value(loc=loc, ip=ip),
        low_32_bits,
        Int32(offset).ir_value(loc=loc, ip=ip),
        Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
        kind,
        loc=loc,
        ip=ip,
    )
    high_32_bits_shfl = nvvm.shfl_sync(
        T.i32(),
        Int32(mask).ir_value(loc=loc, ip=ip),
        high_32_bits,
        Int32(offset).ir_value(loc=loc, ip=ip),
        Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
        kind,
        loc=loc,
        ip=ip,
    )

    # combine low and high 32 bits
    low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip)
    high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip)
    shlf_res = llvm.shl(
        high_64_bit,
        Int64(32).ir_value(loc=loc, ip=ip),
        llvm.IntegerOverflowFlags.none,
        loc=loc,
        ip=ip,
    )
    shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip)
    shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip)
    return orig_type(shlf_res)
297
+
298
# Convenience wrappers fixing the shuffle kind (see nvvm.ShflKind).
shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx)
shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up)
shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down)
shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly)
302
+
303
+
304
@dsl_user_op
def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None:
    """
    Creates a barrier, optionally named.

    :param barrier_id: An optional barrier identifier (named barrier)
    :param number_of_threads: An optional number of participating threads
    """

    def _as_i32_value(v):
        # Optional Python/Numeric operand -> IR value (None stays None).
        return None if v is None else Int32(v).ir_value(loc=loc, ip=ip)

    nvvm.barrier(
        barrier_id=_as_i32_value(barrier_id),
        number_of_threads=_as_i32_value(number_of_threads),
        loc=loc,
        ip=ip,
    )
318
+
319
+
320
@dsl_user_op
def barrier_arrive(
    *, barrier_id=None, number_of_threads=None, loc=None, ip=None
) -> None:
    """
    Arrives on a (optionally named) CTA barrier without waiting.

    :param barrier_id: An optional barrier identifier (named barrier)
    :param number_of_threads: The number of threads participating in the barrier; required
    :raises ValueError: If ``number_of_threads`` is not provided
    """
    # Validate up front so the error is raised before any IR is emitted.
    if number_of_threads is None:
        raise ValueError(
            "barrier_arrive needs pass number_of_threads to arrive the barrier",
        )

    if barrier_id is not None:
        barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip)
    number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip)

    nvvm.barrier_arrive(
        barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
    )
336
+
337
+
338
@dsl_user_op
def sync_threads(*, loc=None, ip=None) -> None:
    """
    Synchronizes all threads within a CTA.
    """
    # nvvm.barrier with no id/thread-count arguments: whole-CTA barrier.
    nvvm.barrier(loc=loc, ip=ip)
344
+
345
+
346
@dsl_user_op
def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None:
    """
    Performs a warp-wide sync with an optional mask.

    :param mask: Bitmask of participating lanes; defaults to the full warp.
    :type mask: Int
    """
    nvvm.bar_warp_sync(Int32(mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
352
+
353
+
354
@dsl_user_op
def fence_acq_rel_cta(*, loc=None, ip=None) -> None:
    """
    Memory fence with acquire-release semantics at CTA scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_cta(loc=loc, ip=ip)
362
+
363
+
364
@dsl_user_op
def fence_acq_rel_cluster(*, loc=None, ip=None) -> None:
    """
    Memory fence with acquire-release semantics at cluster scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_cluster(loc=loc, ip=ip)
372
+
373
+
374
@dsl_user_op
def fence_acq_rel_gpu(*, loc=None, ip=None) -> None:
    """
    Memory fence with acquire-release semantics at GPU (device) scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_gpu(loc=loc, ip=ip)
382
+
383
+
384
@dsl_user_op
def fence_acq_rel_sys(*, loc=None, ip=None) -> None:
    """
    Memory fence with acquire-release semantics at system scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_sys(loc=loc, ip=ip)
392
+
393
+
394
@dsl_user_op
def cp_async_commit_group(*, loc=None, ip=None) -> None:
    """
    Commits all prior initiated but uncommitted cp.async instructions into one group.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-commit-group>`__.
    """
    nvvm.cp_async_commit_group(loc=loc, ip=ip)
402
+
403
+
404
@dsl_user_op
def cp_async_wait_group(n, *, loc=None, ip=None) -> None:
    """
    Waits till only a specified numbers of cp.async groups are pending.

    :param n: The maximum number of most recently committed groups allowed to
              remain pending.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-wait-group-cp-async-wait-all>`__.
    """
    nvvm.cp_async_wait_group(n, loc=loc, ip=ip)
412
+
413
+
414
@dsl_user_op
def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None:
    """
    Commits all prior initiated but uncommitted cp.async.bulk instructions into one group.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-commit-group>`__.
    """
    nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip)
422
+
423
+
424
@dsl_user_op
def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None:
    """
    Waits till only a specified numbers of cp.async.bulk groups are pending.

    :param group: The maximum number of most recently committed bulk groups
                  allowed to remain pending.
    :param read: Optional flag forwarded to the nvvm op (``.read`` variant).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-wait-group>`__.
    """
    nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip)
432
+
433
+
434
@dsl_user_op
def cluster_wait(*, loc=None, ip=None) -> None:
    """
    A cluster-wide wait operation (waits for the corresponding cluster arrive).
    """
    nvvm.cluster_wait(loc=loc, ip=ip)
440
+
441
+
442
@dsl_user_op
def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None:
    """
    A cluster-wide arrive operation.

    :param aligned: Optional flag forwarded to the nvvm op (``.aligned`` variant).
    """
    nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip)
448
+
449
+
450
@dsl_user_op
def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None:
    """
    A cluster-wide arrive operation with relaxed semantics.

    :param aligned: Optional flag forwarded to the nvvm op (``.aligned`` variant).
    """
    nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip)
456
+
457
+
458
@dsl_user_op
def fence_proxy(
    kind: ProxyKind,
    *,
    space: Optional[SharedSpace] = None,
    use_intrinsic=None,
    loc=None,
    ip=None,
) -> None:
    """
    Emits a proxy memory fence of the given kind.

    :param kind: The proxy kind to fence
    :type kind: ProxyKind
    :param space: An optional shared-memory space the fence applies to
    :type space: Optional[SharedSpace]
    :param use_intrinsic: Optional flag forwarded to the nvvm op
    """
    nvvm.fence_proxy(
        kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip
    )
470
+
471
+
472
@dsl_user_op
def vote_ballot_sync(
    pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None
) -> Int32:
    """
    Performs a ballot operation across the warp.

    :param pred: The per-thread predicate contributed to the ballot
    :param mask: Bitmask of participating lanes; defaults to the full warp
    :return: The 32-bit ballot result
    """
    mask_val = Int32(mask).ir_value(loc=loc, ip=ip)
    pred_val = Boolean(pred).ir_value(loc=loc, ip=ip)
    ballot = nvvm.vote_ballot_sync(T.i32(), mask_val, pred_val, loc=loc, ip=ip)
    return Int32(ballot)
488
+
489
+
490
@dsl_user_op
def popc(value: Numeric, *, loc=None, ip=None) -> Numeric:
    """
    Performs a population count operation (counts set bits).

    :param value: The value whose set bits are counted
    :return: The popcount, wrapped in the same numeric type as the input
    """
    if not isinstance(value, Numeric):
        value = as_numeric(value)
    # Propagate loc/ip into ir_value as well, matching every other wrapper in
    # this file so source locations stay consistent.
    return type(value)(
        llvm.intr_ctpop(value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
    )
498
+
499
+
500
@dsl_user_op
def fence_view_async_tmem_op(
    kind: Tcgen05WaitKind,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Perform a fence operation on the async TMEM load or store.

    .. note::
        This function is only available on sm_100a and above.
        The fence is required to synchronize the TMEM load/store
        and let the pipeline release or commit the buffer.

    Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM.
    ```
    # Start to copy ACC from TMEM to register
    cute.copy(tmem_load, tACC, rACC)
    fence_view_async_tmem_load()
    # After fence, we can ensure the TMEM buffer is consumed totally.
    # Release the buffer to let the MMA know it can overwrite the buffer.
    mma2accum_pipeline.consumer_release(curr_consumer_state)
    ```
    Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM.
    ```
    # Start to copy A from register to TMEM
    cute.copy(tmem_store, rA, tA)
    fence_view_async_tmem_store()
    # After fence, we can ensure the TMEM buffer is ready.
    # Commit the buffer to let the MMA know it can start to load A.
    tmem_mma_pipeline.producer_commit(curr_producer_state)
    ```


    :param kind: The kind of fence operation to perform including LOAD and STORE.
    :type kind: Tcgen05WaitKind
    """
    # Delegates to the nvvm tcgen05.wait op with the requested kind.
    nvvm.tcgen05_wait(kind, loc=loc, ip=ip)
539
+
540
+
541
+ fence_view_async_tmem_load = partial(
542
+ fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD
543
+ )
544
+ fence_view_async_tmem_store = partial(
545
+ fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE
546
+ )
547
+
548
+
549
@dsl_user_op
def warpgroup_reg_realloc_op(
    reg_count: int,
    kind: SetMaxRegisterAction,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Adjusts the maximum per-thread register count (``setmaxregister``).

    :param reg_count: The target register count per thread
    :type reg_count: int
    :param kind: Whether to increase or decrease the allocation
    :type kind: SetMaxRegisterAction
    """
    nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip)
558
+
559
+
560
+ warpgroup_reg_alloc = partial(
561
+ warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase
562
+ )
563
+ warpgroup_reg_dealloc = partial(
564
+ warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease
565
+ )
566
+
567
+
568
@dsl_user_op
def calc_packed_f32x2_op(
    src_a: Tuple[Float32, Float32],
    src_b: Tuple[Float32, Float32],
    src_c: Tuple[Float32, Float32] | None,
    calc_func: Callable,
    *,
    rnd=RoundingModeKind.RZ,
    ftz=True,
    loc=None,
    ip=None,
) -> Tuple[Float32, Float32]:
    """
    Applies a packed f32x2 nvvm op to pairs of f32 operands.

    :param src_a: First pair of f32 operands
    :param src_b: Second pair of f32 operands
    :param src_c: Optional third pair of f32 operands (e.g. for FMA)
    :param calc_func: The nvvm packed-f32x2 op to invoke
    :param rnd: The rounding mode
    :param ftz: Whether to flush denormals to zero
    :return: The two result lanes as a pair of Float32
    """
    vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc)

    def _to_vec(pair):
        # Pack a (lo, hi) pair of scalars into a 2-element f32 vector.
        return vector.from_elements(
            vec_type, tuple(as_numeric(e).ir_value() for e in pair), loc=loc, ip=ip
        )

    operands = [_to_vec(src_a), _to_vec(src_b)]
    if src_c is not None:
        operands.append(_to_vec(src_c))
    vec_res = calc_func(vec_type, *operands, rnd=rnd, ftz=ftz, loc=loc, ip=ip)

    def _lane(i):
        # Extract a single result lane back to a scalar Float32.
        return Float32(
            vector.extract(
                vec_res, dynamic_position=[], static_position=[i], loc=loc, ip=ip
            )
        )

    return _lane(0), _lane(1)
610
+
611
+
612
# Convenience wrappers binding the packed-f32x2 nvvm op; mul/add take no src_c.
fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2)
mul_packed_f32x2 = partial(
    calc_packed_f32x2_op, src_c=None, calc_func=nvvm.mul_packed_f32x2
)
add_packed_f32x2 = partial(
    calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
)
619
+
620
+
621
@dsl_user_op
def fmax(
    a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
) -> Float32:
    """Return the maximum of ``a`` and ``b`` as a single-precision float."""
    lhs = Float32(a).ir_value(loc=loc, ip=ip)
    rhs = Float32(b).ir_value(loc=loc, ip=ip)
    return Float32(nvvm.fmax(T.f32(), lhs, rhs, loc=loc, ip=ip))
634
+
635
+
636
@dsl_user_op
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
    """Approximate single-precision reciprocal (``rcp.approx.ftz.f``)."""
    operand = Float32(a).ir_value(loc=loc, ip=ip)
    return Float32(nvvm.rcp_approx_ftz_f(T.f32(), operand, loc=loc, ip=ip))
643
+
644
+
645
@dsl_user_op
@deprecated(
    "cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead"
)
def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
    """Compute 2**a in f32 via the ``ex2.approx.ftz.f32`` PTX instruction."""
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(a).ir_value(loc=loc, ip=ip)],
            "ex2.approx.ftz.f32 $0, $1;",
            "=f,f",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
661
+
662
+
663
@dsl_user_op
@deprecated(
    "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead"
)
def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
    """Compute e**a by rewriting it as 2**(a * log2(e))."""
    # log2(e): exp(a) == exp2(a * log2(e))
    LOG2_E = 1.4426950408889634
    return exp2(a * LOG2_E, loc=loc, ip=ip)
670
+
671
+
672
@dsl_user_op
@deprecated(
    "cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead"
)
def exp_packed_f32x2(
    a: Tuple[Float32, Float32], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Elementwise e**x of a pair of f32 values."""
    # Scale both lanes by log2(e) with one packed multiply, then exp2 per lane.
    LOG2_E = Float32(1.4426950408889634)
    b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip)
    return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from typing import Optional, Type
13
+
14
+ from cutlass.cutlass_dsl import T, dsl_user_op
15
+
16
+ import cutlass._mlir.dialects.cute as _cute_ir
17
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
18
+ from cutlass._mlir import ir
19
+
20
+ from ..typing import Pointer, Numeric, NumericMeta
21
+
22
+
23
@dsl_user_op
def alloc_smem(
    element_type: Type[Numeric],
    size_in_elems: int,
    alignment: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Pointer:
    """
    Statically allocates SMEM.

    :param element_type: The pointee type of the pointer.
    :type element_type: Type[Numeric]
    :param size_in_elems: The size of the allocation in terms of number of elements of the
                          pointee type
    :type size_in_elems: int
    :param alignment: An optional pointer alignment for the allocation
    :type alignment: int
    :return: A pointer to the start of the allocation
    :rtype: Pointer
    :raises TypeError: If ``element_type`` is not a Numeric type
    """
    if not isinstance(element_type, NumericMeta):
        raise TypeError(
            f"element_type must be a type of Numeric, but got {element_type}"
        )

    if alignment is None:
        # Default alignment based on the element type's width; clamp to at
        # least 1 byte so sub-byte element types do not yield alignment 0.
        alignment = max(1, element_type.width // 8)
    ptr_ty = _cute_ir.PtrType.get(
        element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment
    )
    return _cute_nvgpu_ir.arch_alloc_smem(
        ptr=ptr_ty,
        input=ir.IntegerAttr.get(T.i32(), size_in_elems),
        loc=loc,
        ip=ip,
    )
62
+
63
+
64
@dsl_user_op
def get_dyn_smem(
    element_type: Type[Numeric],
    alignment: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Pointer:
    """
    Retrieves a pointer to a dynamic SMEM allocation.

    :param element_type: The pointee type of the pointer.
    :type element_type: Type[Numeric]
    :param alignment: An optional pointer alignment, the result pointer is offset appropriately
    :type alignment: int
    :return: A pointer to the start of the dynamic SMEM allocation with a correct
             alignement
    :rtype: Pointer
    :raises TypeError: If ``element_type`` is not a Numeric type
    """
    if not isinstance(element_type, NumericMeta):
        raise TypeError(
            f"element_type must be a type of Numeric, but got {element_type}"
        )

    if alignment is None:
        # Default alignment based on the element type's width; clamp to at
        # least 1 byte so sub-byte element types do not yield alignment 0.
        alignment = max(1, element_type.width // 8)
    ptr_ty = _cute_ir.PtrType.get(
        element_type.mlir_type,
        _cute_ir.AddressSpace.smem,
        alignment,
    )
    return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip)
97
+
98
+
99
@dsl_user_op
def get_dyn_smem_size(*, loc=None, ip=None) -> int:
    """
    Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time.
    This can be used for bounds checking during shared memory allocation.

    :return: The size of dynamic shared memory in bytes
    :rtype: int
    """
    return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from typing import Type
13
+
14
+ from cutlass.cutlass_dsl import dsl_user_op
15
+
16
+ import cutlass._mlir.dialects.cute as _cute_ir
17
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
18
+
19
+ from ..typing import Pointer, Int, Int32, Numeric, NumericMeta
20
+
21
+
22
# Maximum number of TMEM columns that may be allocated on SM100.
SM100_TMEM_CAPACITY_COLUMNS = 512
# Smallest legal TMEM allocation, in columns (allocations must be a power of two).
SM100_TMEM_MIN_ALLOC_COLUMNS = 32
24
+
25
+
26
@dsl_user_op
def retrieve_tmem_ptr(
    element_type: Type[Numeric],
    alignment: int,
    ptr_to_buffer_holding_addr: Pointer,
    *,
    loc=None,
    ip=None,
) -> Pointer:
    """
    Retrieves a pointer to TMEM with the provided element type and alignment.

    :param element_type: The pointee type of the pointer.
    :type element_type: Type[Numeric]
    :param alignment: The alignment of the result pointer
    :type alignment: int
    :param ptr_to_buffer_holding_addr: A pointer to a SMEM buffer holding the TMEM address of
                                       the start of the allocation
    :type ptr_to_buffer_holding_addr: Pointer
    :return: A pointer to TMEM
    :rtype: Pointer
    :raises TypeError: If ``element_type`` is not a Numeric type
    """
    if not isinstance(element_type, NumericMeta):
        raise TypeError(
            f"element_type must be a type of Numeric, but got {element_type}"
        )

    tmem_ptr_type = _cute_ir.PtrType.get(
        element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment
    )
    return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr(
        tmem_ptr_type, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip
    )
59
+
60
+
61
@dsl_user_op
def alloc_tmem(
    num_columns: Int,
    smem_ptr_to_write_address: Pointer,
    is_two_cta=None,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Allocates TMEM.

    :param num_columns: The number of TMEM columns to allocate
    :type num_columns: Int
    :param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is
                                      written to
    :type smem_ptr_to_write_address: Pointer
    :param is_two_cta: Optional boolean parameter for 2-CTA MMAs
    :raises ValueError: If a statically-known ``num_columns`` is out of range or not a
                        power of two
    """
    if isinstance(num_columns, int):
        # Only statically-known column counts can be validated here. Bounds are
        # derived from the module constants so the message stays accurate.
        out_of_range = (
            num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
            or num_columns > SM100_TMEM_CAPACITY_COLUMNS
        )
        not_pow2 = num_columns & (num_columns - 1) != 0
        if out_of_range or not_pow2:
            raise ValueError(
                f"num_columns must be between {SM100_TMEM_MIN_ALLOC_COLUMNS} and "
                f"{SM100_TMEM_CAPACITY_COLUMNS}, and must be pow of 2, but got {num_columns}"
            )
    _cute_nvgpu_ir.arch_sm100_alloc_tmem(
        Int32(num_columns).ir_value(loc=loc, ip=ip),
        smem_ptr_to_write_address.value,
        is_two_cta=is_two_cta,
        loc=loc,
        ip=ip,
    )
96
+
97
+
98
@dsl_user_op
def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None:
    """
    Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can
    allocate.

    :param is_two_cta: Optional boolean parameter for 2-CTA MMAs
    """
    _cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit(
        is_two_cta=is_two_cta, loc=loc, ip=ip
    )
107
+
108
+
109
@dsl_user_op
def dealloc_tmem(
    tmem_ptr: Pointer,
    num_columns: Int,
    is_two_cta=None,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Deallocates TMEM using the provided pointer and number of columns.

    :param tmem_ptr: A pointer to the TMEM allocation to de-allocate
    :type tmem_ptr: Pointer
    :param num_columns: The number of columns in the TMEM allocation
    :type num_columns: Int
    :param is_two_cta: Optional boolean parameter for 2-CTA MMAs
    :raises ValueError: If a statically-known ``num_columns`` is out of range or not a
                        power of two
    """
    if isinstance(num_columns, int):
        # Only statically-known column counts can be validated here. Bounds are
        # derived from the module constants so the message stays accurate.
        out_of_range = (
            num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
            or num_columns > SM100_TMEM_CAPACITY_COLUMNS
        )
        not_pow2 = num_columns & (num_columns - 1) != 0
        if out_of_range or not_pow2:
            raise ValueError(
                f"num_columns must be between {SM100_TMEM_MIN_ALLOC_COLUMNS} and "
                f"{SM100_TMEM_CAPACITY_COLUMNS}, and must be pow of 2, but got {num_columns}"
            )
    _cute_nvgpu_ir.arch_sm100_dealloc_tmem(
        tmem_ptr.value,
        Int32(num_columns).ir_value(loc=loc, ip=ip),
        is_two_cta=is_two_cta,
        loc=loc,
        ip=ip,
    )
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from .core import TensorSSA
13
+ from .typing import Numeric
14
+ from cutlass._mlir.dialects import math, arith
15
+
16
+ from typing import Callable, Union
17
+
18
+
19
def _math_op(func: Callable, fastmath: bool, *args, **kwargs):
    """Apply an MLIR math-dialect builder to TensorSSA or Float-Numeric operands.

    All positional operands must share one type: either every operand is a
    TensorSSA, or every operand is a floating-point Numeric scalar.

    :param func: The math-dialect builder to invoke
    :param fastmath: Whether to emit the op with fast-math flags
    :param args: The input tensors or scalars
    :param kwargs: Accepted for signature compatibility; currently unused
    """
    lead_type = type(args[0])
    for operand in args:
        is_tensor = isinstance(operand, TensorSSA)
        is_float_scalar = isinstance(operand, Numeric) and type(operand).is_float
        if not (is_tensor or is_float_scalar):
            raise TypeError(
                f"Expected a TensorSSA or Numeric(Float), but got {type(operand)}"
            )
        if not isinstance(operand, lead_type):
            raise TypeError(
                f"Expected all inputs to be of type {lead_type}, but got {type(operand)}"
            )

    flags = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none
    if isinstance(args[0], TensorSSA):
        # Tensor path: wrap the raw IR result back into a TensorSSA with the
        # shape/dtype of the first operand.
        raw = func(*args, fastmath=flags)
        return TensorSSA(raw, args[0].shape, args[0].dtype)
    # Scalar path: lower each Numeric to its IR value first.
    ir_operands = [operand.ir_value() for operand in args]
    return func(*ir_operands, fastmath=flags)
47
+
48
+
49
def acos(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise arc cosine.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The arc cosine of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.acos, fastmath, a)
70
+
71
+
72
def asin(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise arc sine.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The arc sine of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.asin, fastmath, a)
93
+
94
+
95
def atan(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Compute element-wise arc tangent of the input tensor.

    .. note::
        Not currently supported: calling this function always raises
        ``NotImplementedError``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :raises NotImplementedError: always
    """
    # The unconditional raise made the original trailing
    # `return _math_op(math.atan, fastmath, a)` unreachable; the dead code
    # has been removed while preserving behavior.
    raise NotImplementedError("atan is not implemented")
117
+
118
+
119
def atan2(
    a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise two-argument arc tangent.

    ``atan2(a, b)`` is the angle in radians between the positive x-axis and the
    point with coordinates ``(b, a)``.

    :param a: First input (y-coordinates)
    :type a: Union[TensorSSA, Numeric]
    :param b: Second input (x-coordinates)
    :type b: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The element-wise arc tangent of ``a / b``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.atan2, fastmath, a, b)
145
+
146
+
147
def cos(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise cosine.

    :param a: Input tensor or floating-point scalar, in radians
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The cosine of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.cos, fastmath, a)
168
+
169
+
170
def erf(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise Gauss error function.

    Defined as ``erf(x) = 2/sqrt(pi) * integral from 0 to x of exp(-t^2) dt``.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The error function of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.erf, fastmath, a)
194
+
195
+
196
def exp(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise natural exponential.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: ``e`` raised to every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.exp, fastmath, a)
217
+
218
+
219
def exp2(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise base-2 exponential.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: ``2`` raised to every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.exp2, fastmath, a)
240
+
241
+
242
def log(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise natural logarithm.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The natural logarithm of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.log, fastmath, a)
263
+
264
+
265
def log2(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise base-2 logarithm.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The base-2 logarithm of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.log2, fastmath, a)
286
+
287
+
288
def log10(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise base-10 logarithm.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The base-10 logarithm of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.log10, fastmath, a)
309
+
310
+
311
def rsqrt(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise reciprocal square root (``1/sqrt(x)``).

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The reciprocal square root of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.rsqrt, fastmath, a)
334
+
335
+
336
def sin(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise sine.

    :param a: Input tensor or floating-point scalar, in radians
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The sine of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.sin, fastmath, a)
357
+
358
+
359
def sqrt(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise square root.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The square root of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.sqrt, fastmath, a)
380
+
381
+
382
def tan(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise tangent.

    :param a: Input tensor or floating-point scalar, in radians
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The tangent of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.tan, fastmath, a)
403
+
404
+
405
def tanh(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise hyperbolic tangent.

    :param a: Input tensor or floating-point scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Emit the op with fast-math flags, defaults to False
    :type fastmath: bool, optional
    :return: The hyperbolic tangent of every element of ``a``
    :rtype: Union[TensorSSA, Numeric]
    """
    return _math_op(math.tanh, fastmath, a)
426
+
427
+
428
+ __all__ = [
429
+ "acos",
430
+ "asin",
431
+ "atan",
432
+ "atan2",
433
+ "cos",
434
+ "erf",
435
+ "exp",
436
+ "exp2",
437
+ "log",
438
+ "log10",
439
+ "log2",
440
+ "rsqrt",
441
+ "sin",
442
+ "sqrt",
443
+ "tan",
444
+ "tanh",
445
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from . import warp
13
+ from . import cpasync
14
+ from . import warpgroup
15
+ from . import tcgen05
16
+
17
+ from .common import *
18
+ from .helpers import *
19
+
20
+
21
+ # __all__ is required here for documentation generation
22
+ __all__ = [
23
+ "OpError",
24
+ "MmaUniversalOp",
25
+ "CopyUniversalOp",
26
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+ import enum
12
+ from dataclasses import dataclass
13
+ from typing import Type, Optional
14
+
15
+ from cutlass.cutlass_dsl import DSLBaseError
16
+
17
+ import cutlass._mlir.dialects.cute as _cute_ir
18
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
19
+ from cutlass._mlir import ir
20
+
21
+ from .. import core
22
+ from ..typing import Float16, Float32, Float64, Numeric
23
+
24
+
25
class OpError(DSLBaseError):
    """
    An exception class for Op construction errors.
    """

    def __init__(
        self, op: core.Op, message: str, suggestion: Optional[str] = None
    ) -> None:
        """
        :param op: The Op whose construction failed; its class name becomes the error code
        :param message: Human-readable description of the failure
        :param suggestion: Optional remediation hint; a generic one is used when omitted
        """
        if suggestion is None:
            # Default suggestion
            suggestion = "Check your Op construction code"
        super().__init__(
            message,
            error_code=f"{op.__class__.__name__} error",
            suggestion=suggestion,
        )
41
+
42
+
43
+ ####################################################################################################
44
+ #
45
+ # MMA Ops and Traits
46
+ #
47
+ ####################################################################################################
48
+
49
+
50
@dataclass(frozen=True)
class MmaUniversalOp(core.MmaOp):
    """
    The universal MMA Operation.

    This Operation currently expects the A/B operands as well as the accumulator to share the same
    data types.

    :param abacc_dtype: The data type for the A/B operands and the accumulator
    :type abacc_dtype: Type[Numeric]
    """

    abacc_dtype: Type[Numeric]

    def __post_init__(self) -> None:
        # Validate the dtype eagerly so a misconfigured Op fails at construction time.
        if self.abacc_dtype not in [Float16, Float32, Float64]:
            raise OpError(
                self,
                f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64",
            )

    def __str__(self) -> str:
        return (
            "universal MMA Operation using FMA"
            f"\n A/B/Accumulator data type = {self.abacc_dtype}"
        )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait":
        # Fixed 1x1x1 MMA shape: one FMA per Atom execution; A, B, and the
        # accumulator all use abacc_dtype.
        shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">')
        atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get(
            shape_mnk_attr,
            self.abacc_dtype.mlir_type,
            self.abacc_dtype.mlir_type,
            self.abacc_dtype.mlir_type,
        )
        return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))

    def _verify_fragment_A(self, input, *, loc=None, ip=None):
        # No operand-specific constraints for the universal FMA Atom.
        pass

    def _verify_fragment_B(self, input, *, loc=None, ip=None):
        # No operand-specific constraints for the universal FMA Atom.
        pass


class MmaUniversalTrait(core.Trait):
    """Trait wrapping the IR Atom created by :class:`MmaUniversalOp`."""

    pass
95
+
96
+
97
+ ####################################################################################################
98
+ #
99
+ # Copy Ops and Traits
100
+ #
101
+ ####################################################################################################
102
+
103
+
104
class MemoryOrder(enum.Enum):
    """
    Memory-ordering semantics for copy operations.

    Each member wraps the corresponding IR-level ``MemOrderKind`` value.
    """

    WEAK = _cute_ir.MemOrderKind.WEAK
    RELAXED = _cute_ir.MemOrderKind.RELAXED
    ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE
    RELEASE = _cute_ir.MemOrderKind.RELEASE
    ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL
    SC = _cute_ir.MemOrderKind.SC
    MMIO = _cute_ir.MemOrderKind.MMIO
    CONSTANT = _cute_ir.MemOrderKind.CONSTANT
    VOLATILE = _cute_ir.MemOrderKind.VOLATILE

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MemOrderKind:
        # The member value already is the IR enum.
        return self.value
123
+
124
+
125
class MemoryScope(enum.Enum):
    """
    Memory-scope levels for copy operations.

    Each member wraps the corresponding IR-level ``MemScopeKind`` value.
    """

    CTA = _cute_ir.MemScopeKind.CTA
    CLUSTER = _cute_ir.MemScopeKind.CLUSTER
    GPU = _cute_ir.MemScopeKind.GPU
    SYS = _cute_ir.MemScopeKind.SYS

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MemScopeKind:
        # The member value already is the IR enum.
        return self.value
139
+
140
@dataclass(frozen=True)
class CopyUniversalOp(core.CopyOp):
    """
    The universal Copy Operation.

    When creating a Copy Atom out of this operation, the expected usage pattern is

    .. code-block:: python

        op = cute.nvgpu.CopyUniversalOp()
        atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64)

    - ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \
      or the destination TV Layout) in unit of tensor elements and is used for partitioning by \
      ``TiledCopy`` for example
    - ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \
      execution. This can be larger than the width of the above data type. When not provided, \
      the compiler will do a best effort at auto-vectorizing.
    """

    def __str__(self) -> str:
        return "universal Copy Operation"

    def _make_trait(
        self,
        copy_internal_type: Type[Numeric],
        *,
        loc=None,
        ip=None,
        **kwargs,
    ) -> "CopyUniversalTrait":
        """Build the IR trait for this Op.

        Recognized kw arguments: ``num_bits_per_copy`` (non-negative int; 0
        leaves vectorization to the compiler), ``memory_order``, and
        ``memory_scope``.

        :raises ValueError: If ``num_bits_per_copy`` is not a non-negative int
        """
        num_bits_per_copy = kwargs.get("num_bits_per_copy", 0)
        memory_order = kwargs.get("memory_order", MemoryOrder.WEAK)
        memory_scope = kwargs.get("memory_scope", MemoryScope.CTA)
        if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0):
            raise ValueError(
                "expects a 'num_bits_per_copy' kw argument of type int that is non-negative "
                f"when creating a copy Atom for {self.__class__.__name__}"
            )
        ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
            copy_internal_type.mlir_type,
            num_bits_per_copy,
            memory_order._to_ir(),
            memory_scope._to_ir(),
        )
        return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip))


class CopyUniversalTrait(core.Trait):
    """Trait wrapping the IR Atom created by :class:`CopyUniversalOp`."""

    pass
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from .copy import *
13
+ from .helpers import *
14
+
15
+
16
+ # __all__ is required here for documentation generation
17
+ __all__ = [
18
+ #
19
+ # copy.py
20
+ #
21
+ "LoadCacheMode",
22
+ "CopyG2SOp",
23
+ "CopyBulkTensorTileG2SOp",
24
+ "CopyBulkTensorTileG2SMulticastOp",
25
+ "CopyBulkTensorTileS2GOp",
26
+ "CopyReduceBulkTensorTileS2GOp",
27
+ #
28
+ # helpers.py
29
+ #
30
+ "make_tiled_tma_atom",
31
+ "tma_partition",
32
+ "create_tma_multicast_mask",
33
+ "prefetch_descriptor",
34
+ "copy_tensormap",
35
+ "update_tma_descriptor",
36
+ "fence_tma_desc_acquire",
37
+ "cp_fence_tma_desc_release",
38
+ "fence_tma_desc_release",
39
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import enum
13
+ from dataclasses import dataclass
14
+ from typing import Optional, Type
15
+
16
+ from cutlass.cutlass_dsl import CuTeDSL, t
17
+
18
+ import cutlass._mlir.dialects.cute as _cute_ir
19
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
20
+ from cutlass._mlir import ir
21
+
22
+ from ...core import CopyOp, Trait, ReductionOp
23
+ from ...typing import Int16, Pointer, Integer, Numeric
24
+ from ..common import OpError
25
+ from ..tcgen05.mma import CtaGroup
26
+
27
+
28
+ ####################################################################################################
29
+ #
30
+ # Asynchronous copies
31
+ #
32
+ ####################################################################################################
33
+
34
+
35
class LoadCacheMode(enum.Enum):
    """
    An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators>`__.

    Each member wraps the corresponding IR-level ``LoadCacheMode`` value.
    """

    ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always
    GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_
    STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming
    LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use
    NONE = _cute_nvgpu_ir.LoadCacheMode.none

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode:
        # The member value already is the IR enum.
        return self.value
56
+
57
+
58
@dataclass(frozen=True)
class CopyG2SOp(CopyOp):
    """
    Non-bulk asynchronous GMEM to SMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-non-bulk-copy>`__.

    :param cache_mode: The cache mode to use for the load, defaults to ``LoadCacheMode.ALWAYS``
    :type cache_mode: LoadCacheMode
    """

    cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS

    def __str__(self) -> str:
        res = "cp.async GMEM -> SMEM copy Operation"
        if self.cache_mode != LoadCacheMode.ALWAYS:
            res += f"\n with cache mode = {self.cache_mode}"
        return res

    def _make_trait(
        self,
        copy_internal_type: Type[t.Numeric],
        *,
        loc=None,
        ip=None,
        **kwargs,
    ) -> "CopyG2STrait":
        """Build the IR trait for this Op.

        :param copy_internal_type: Element type used to interpret the copy
        :raises OpError: If ``cache_mode`` is not a ``LoadCacheMode`` instance
        :raises ValueError: If the ``num_bits_per_copy`` kw argument is missing or not a positive int
        """
        num_bits_per_copy = kwargs.get("num_bits_per_copy", None)
        # Verify that the user provided enum values. The original code
        # performed this exact isinstance check twice; once is sufficient.
        if not isinstance(self.cache_mode, LoadCacheMode):
            raise OpError(
                self,
                "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance",
            )
        if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0):
            raise ValueError(
                "expects a 'num_bits_per_copy' kw argument of type int that is positive "
                f"when creating a copy Atom for {self.__class__.__name__}"
            )
        ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get(
            copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy
        )
        return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip))


class CopyG2STrait(Trait):
    """Trait wrapping the IR Atom created by :class:`CopyG2SOp`."""

    pass
108
+
109
+
110
+ ####################################################################################################
111
+ #
112
+ # Bulk tensor copies a.k.a TMA copies
113
+ #
114
+ ####################################################################################################
115
+
116
+ TMA_MBAR_PTR_FIELD_NAME = "tma_bar"
117
+ TMA_MASK_FIELD_NAME = "mcast_mask"
118
+ TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr"
119
+
120
+ #
121
+ # TMA GMEM -> SMEM copies
122
+ #
123
+
124
+
125
@dataclass(frozen=True)
class CopyBulkTensorTileG2SOp(CopyOp):
    """
    Bulk tensor asynchronous GMEM to SMEM Copy Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
    This Operation uses TMA in the ``.tile`` mode.

    :param cta_group: Number of CTAs participating, defaults to ``CtaGroup.ONE``
    :type cta_group: CtaGroup
    """

    cta_group: CtaGroup = CtaGroup.ONE

    # Architectures on which this TMA load is available.
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self, "expects the 'cta_group' parameter to be a CtaGroup instance"
            )
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
            # Fixed the original duplicated wording ("is not and is not compatible").
            raise OpError(
                self,
                f"CTA group of 2 is tcgen05-specific and is not compatible with {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        res = "cp.async GMEM -> SMEM bulk tensor copy Operation"
        if self.cta_group == CtaGroup.TWO:
            res += f"\n CTA group = 2"
        return res

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileG2SNonExecTrait":
        # TMA atoms require a descriptor built from the GMEM tensor, so the
        # generic make_copy_atom path does not apply.
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
        if self.cta_group == CtaGroup.ONE:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90
        elif self.cta_group == CtaGroup.TWO:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm
        else:
            assert False, "unrecognized self.cta_group"
+
184
+
185
class CopyBulkTensorTileG2SNonExecTrait(Trait):
    # We allow kw args to be dropped so that the user can write common code for non-multicast
    # and multicast loads.
    def unpack(
        self,
        *,
        loc=None,
        ip=None,
        tma_bar_ptr: Optional[Pointer] = None,
        tma_desc_ptr: Optional[Pointer] = None,
        **kwargs,
    ):
        """
        Custom implementation of unpack for non-executable TMAs.

        The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when
        using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error.

        :param tma_bar_ptr: pointer to the mbarrier tracking completion of the TMA load
        :param tma_desc_ptr: optional pointer to a TMA descriptor overriding the atom's own
        :return: the executable atom IR value with the mbarrier (and optional descriptor) set
        """
        if not isinstance(tma_bar_ptr, Pointer):
            raise ValueError(
                "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
            )
        # Convert the non-executable atom into an executable one, then fill in its fields.
        exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>"
        attr = ir.Attribute.parse(attr_str)
        exec_value = _cute_nvgpu_ir.atom_set_value(
            exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
        )
        # The descriptor pointer is optional; when absent, the atom's embedded descriptor is used.
        if isinstance(tma_desc_ptr, Pointer):
            attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
            attr = ir.Attribute.parse(attr_str)
            exec_value = _cute_nvgpu_ir.atom_set_value(
                exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_value
220
+
221
+
222
+ #
223
+ # TMA GMEM -> SMEM multicast copies
224
+ #
225
+
226
+
227
@dataclass(frozen=True)
class CopyBulkTensorTileG2SMulticastOp(CopyOp):
    """
    Bulk tensor asynchronous multicast GMEM to SMEM Copy Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
    This Operation uses TMA in the ``.tile`` mode.
    """

    # Number of CTAs cooperating in the load; TWO selects the tcgen05 2-SM variant.
    cta_group: CtaGroup = CtaGroup.ONE

    # Architectures for which this Operation can be constructed.
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        """Validate the Op's parameters and the target architecture.

        :raises OpError: for an invalid ``cta_group``, an inadmissible arch, or
            CTA group 2 on an sm_90* target.
        """
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self, "expects the 'cta_group' parameter to be a CtaGroup instance"
            )
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # CTA group 2 is a tcgen05 (sm_100*) feature; reject it on Hopper.
        # Fixed: the error message previously repeated "is not and is not".
        if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
            raise OpError(
                self,
                f"CTA group of 2 is tcgen05-specific and is not compatible with {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        """Return a human-readable description of this copy Operation."""
        res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation"
        if self.cta_group == CtaGroup.TWO:
            res += f"\n CTA group = 2"
        return res

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait":
        """Unsupported: TMA Atoms must be created via ``cpasync.make_tiled_tma_atom``."""
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
        """Lower this Op's CTA group setting to the corresponding MLIR enum value."""
        if self.cta_group == CtaGroup.ONE:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast
        elif self.cta_group == CtaGroup.TWO:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast
        else:
            assert False, "unrecognized self.cta_group"
285
+
286
+
287
class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait):
    def unpack(
        self,
        *,
        loc=None,
        ip=None,
        tma_bar_ptr: Optional[Pointer] = None,
        mcast_mask=None,
        tma_desc_ptr=None,
    ):
        """
        Custom implementation of unpack for non-executable TMAs.

        The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be
        provided when using `cute.copy`.

        :param tma_bar_ptr: pointer to the mbarrier tracking completion of the TMA load
        :param mcast_mask: 16-bit CTA mask selecting the multicast destinations
        :param tma_desc_ptr: optional pointer to a TMA descriptor overriding the atom's own
        :return: the executable atom IR value with mbarrier, mask (and optional descriptor) set
        """
        if not isinstance(tma_bar_ptr, Pointer):
            raise ValueError(
                "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
            )
        if not isinstance(mcast_mask, Integer):
            raise ValueError(
                "expects a multicast mask to be provided via the mcast_mask kw argument"
            )
        # Convert the non-executable atom into an executable one, then fill in its fields.
        # NOTE(review): literal field names ("tma_bar", "mcast_mask") are used here while the
        # non-multicast trait uses the TMA_*_FIELD_NAME constants — presumably equal; confirm.
        exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<tma_bar>"
        attr = ir.Attribute.parse(attr_str)
        exec_value = _cute_nvgpu_ir.atom_set_value(
            exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
        )
        attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<mcast_mask>"
        attr = ir.Attribute.parse(attr_str)
        # The mask is always materialized as an Int16 IR value.
        exec_value = _cute_nvgpu_ir.atom_set_value(
            exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
        )
        if isinstance(tma_desc_ptr, Pointer):
            attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
            attr = ir.Attribute.parse(attr_str)
            exec_value = _cute_nvgpu_ir.atom_set_value(
                exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_value
329
+
330
+
331
+ #
332
+ # TMA SMEM -> GMEM copies
333
+ #
334
+
335
+
336
@dataclass(frozen=True)
class CopyBulkTensorTileS2GOp(CopyOp):
    """
    Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
    This Operation uses TMA in the ``.tile`` mode.
    """

    # Architectures for which this Operation can be constructed.
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self):
        # Reject construction when the active architecture lacks TMA support.
        arch = CuTeDSL._get_dsl().envar.arch
        if arch in self.admissible_archs:
            return
        raise OpError(
            self,
            f"expects arch to be one of {self.admissible_archs}, but got {arch}",
            suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
        )

    def __str__(self) -> str:
        """Return a human-readable description of this copy Operation."""
        return "cp.async SMEM -> GMEM bulk tensor copy Operation"

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileS2GTrait":
        # TMA Atoms are always built through the dedicated helper, never directly.
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )
371
+
372
+
373
class CopyBulkTensorTileS2GTrait(Trait):
    def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
        """
        Custom implementation of unpack for non-executable TMAs.

        :param tma_desc_ptr: optional pointer to a TMA descriptor overriding the atom's own
        :return: the executable atom IR value
        """
        # Convert the non-executable atom into an executable one.
        exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        # The descriptor pointer is optional; when absent, the atom's embedded descriptor is used.
        if isinstance(tma_desc_ptr, Pointer):
            attr_str = (
                f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>"
            )
            attr = ir.Attribute.parse(attr_str)
            exec_value = _cute_nvgpu_ir.atom_set_value(
                exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_value
388
+
389
@dataclass(frozen=True)
class CopyReduceBulkTensorTileS2GOp(CopyOp):
    """
    Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk>`__.
    This Operation uses TMA in the ``.tile`` mode.
    """

    # Reduction applied when committing SMEM data to GMEM.
    reduction_kind: ReductionOp = ReductionOp.ADD

    # Architectures for which this Operation can be constructed.
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    # Fixed: this hook was misspelled "__post__init__" (and called the mangled
    # "CuTeDSL.__get_dsl"), so the dataclass never ran it and the arch check was
    # silently skipped. Now matches the sibling Op classes.
    def __post_init__(self) -> None:
        """Validate the target architecture.

        :raises OpError: if the active architecture is not admissible
        """
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        """Return a human-readable description of this reduction Operation."""
        return "cp.async SMEM -> GMEM bulk tensor reduction Operation"

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyReduceBulkTensorTileS2GTrait":
        """Unsupported: TMA Atoms must be created via ``cpasync.make_tiled_tma_atom``."""
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind:
        """Lower the reduction kind to the corresponding MLIR enum value."""
        lowering = {
            ReductionOp.ADD: _cute_nvgpu_ir.ReductionKind.ADD,
            ReductionOp.MIN: _cute_nvgpu_ir.ReductionKind.MIN,
            ReductionOp.MAX: _cute_nvgpu_ir.ReductionKind.MAX,
            ReductionOp.INC: _cute_nvgpu_ir.ReductionKind.INC,
            ReductionOp.DEC: _cute_nvgpu_ir.ReductionKind.DEC,
            ReductionOp.AND: _cute_nvgpu_ir.ReductionKind.AND,
            ReductionOp.OR: _cute_nvgpu_ir.ReductionKind.OR,
            ReductionOp.XOR: _cute_nvgpu_ir.ReductionKind.XOR,
        }
        assert self.reduction_kind in lowering, "unrecognized self.reduction_kind"
        return lowering[self.reduction_kind]
446
+
447
+
448
class CopyReduceBulkTensorTileS2GTrait(Trait):
    def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
        """
        Custom implementation of unpack for non-executable TMAs.

        :param tma_desc_ptr: optional pointer to a TMA descriptor overriding the atom's own
        :return: the executable atom IR value
        """
        # Convert the non-executable atom into an executable one.
        exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        # The descriptor pointer is optional; when absent, the atom's embedded descriptor is used.
        if isinstance(tma_desc_ptr, Pointer):
            attr_str = (
                f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>"
            )
            attr = ir.Attribute.parse(attr_str)
            exec_value = _cute_nvgpu_ir.atom_set_value(
                exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_value
463
+
464
# Public Op types exported by this module; the Trait classes remain internal.
__all__ = [
    "LoadCacheMode",
    "CopyG2SOp",
    "CopyBulkTensorTileG2SOp",
    "CopyBulkTensorTileG2SMulticastOp",
    "CopyBulkTensorTileS2GOp",
    "CopyReduceBulkTensorTileS2GOp",
]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from typing import Optional, Tuple, Type, Union
13
+
14
+ from cutlass.cutlass_dsl import dsl_user_op
15
+
16
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
17
+ from cutlass._mlir.dialects import llvm
18
+
19
+ from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta
20
+ from ... import core
21
+ from .copy import (
22
+ CopyBulkTensorTileG2SOp,
23
+ CopyBulkTensorTileG2SMulticastOp,
24
+ CopyBulkTensorTileS2GOp,
25
+ CopyReduceBulkTensorTileS2GOp,
26
+ CopyBulkTensorTileG2SNonExecTrait,
27
+ CopyBulkTensorTileG2SMulticastNonExecTrait,
28
+ CopyBulkTensorTileS2GTrait,
29
+ CopyReduceBulkTensorTileS2GTrait,
30
+ )
31
+
32
+
33
@dsl_user_op
def make_tiled_tma_atom(
    op: Union[
        CopyBulkTensorTileG2SOp,
        CopyBulkTensorTileG2SMulticastOp,
        CopyBulkTensorTileS2GOp,
        CopyReduceBulkTensorTileS2GOp,
    ],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    cta_tiler: Tiler,
    num_multicast: int = 1,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM
    buffer with the given Layout.

    Given

    - a GMEM tensor
    - a SMEM layout
    - a CTA-level Tiler

    this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
    "TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided
    layout and consistent with the provided Tiler.

    This function returns two results:

    1. the Copy Atom
    2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates \
       that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the \
       associated layout can output coordinates. Otherwise, TMA tensors can be partitioned \
       similarly to any other CuTe tensors using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp]
    :param gmem_tensor: The GMEM tensor involved in the Copy
    :type gmem_tensor: Tensor
    :param smem_layout: The SMEM layout to construct the Copy Atom for
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param cta_tiler: The CTA Tiler to use
    :type cta_tiler: Tiler
    :param num_multicast: The multicast factor
    :type num_multicast: int
    :param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    :raises TypeError: if ``internal_type`` is provided but is not a Numeric type
    :raises ValueError: if ``num_multicast`` is inconsistent with ``op``, or ``op`` is not a TMA Op
    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type

    # Restrict the identity layout of the GMEM tensor to the CTA Tiler.
    cta_v_map = core.composition(
        core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip),
        cta_tiler,
        loc=loc,
        ip=ip,
    )

    if isinstance(op, (CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp)):
        # Both G2S variants lower through the same IR builder; only the multicast
        # validation and the resulting Trait class differ.
        if isinstance(op, CopyBulkTensorTileG2SOp):
            if num_multicast != 1:
                raise ValueError(
                    f"expects num_multicast to be 1 for non multicast G2S copies, "
                    f"but got {num_multicast}"
                )
            trait_cls = CopyBulkTensorTileG2SNonExecTrait
        else:
            if num_multicast < 1:
                raise ValueError(
                    f"expects num_multicast to be >= 1 for multicast G2S copies, "
                    f"but got {num_multicast}"
                )
            trait_cls = CopyBulkTensorTileG2SMulticastNonExecTrait
        # res[0] = non-executable atom IR value, res[1] = associated TMA tensor
        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            num_multicast=num_multicast,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, trait_cls(res[0])), res[1]
    elif isinstance(op, CopyBulkTensorTileS2GOp):
        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1]
    elif isinstance(op, CopyReduceBulkTensorTileS2GOp):
        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1]
    else:
        raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")
160
+
161
+
162
@dsl_user_op
def tma_partition(
    atom: core.CopyAtom,
    cta_coord: Coord,
    cta_layout: Layout,
    smem_tensor: Tensor,
    gmem_tensor: Tensor,
    *,
    loc=None,
    ip=None,
) -> Tuple[Tensor, Tensor]:
    """
    Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom.

    :return: the partitioned (SMEM, GMEM) tensor pair
    """
    # The coordinate must be packed into a single IR value before handing it to the dialect.
    packed_coord = core._pack_coord(cta_coord, loc=loc, ip=ip)
    src, dst = _cute_nvgpu_ir.atom_tma_partition(
        atom._trait.value,
        cta_coord=packed_coord,
        cta_layout=cta_layout,
        smem_tensor=smem_tensor.value,
        gmem_tensor=gmem_tensor.value,
        loc=loc,
        ip=ip,
    )
    return src, dst
187
+
188
+
189
@dsl_user_op
def create_tma_multicast_mask(
    cta_layout_vmnk: Layout,
    cta_coord_vmnk: Coord,
    mcast_mode: int,
    *,
    loc=None,
    ip=None,
) -> Int16:
    """
    Computes a multicast mask for a TMA load Copy.

    :param cta_layout_vmnk: The VMNK layout of the cluster
    :type cta_layout_vmnk: Layout
    :param cta_coord_vmnk: The VMNK coordinate of the current CTA
    :type cta_coord_vmnk: Coord
    :param mcast_mode: The tensor mode in which to multicast
    :type mcast_mode: int
    :return: The resulting mask
    :rtype: Int16
    :raises ValueError: if the layout or coordinate is not rank 4
    """
    # Both the cluster layout and the CTA coordinate must be given in VMNK form.
    layout_rank = core.rank(cta_layout_vmnk)
    if layout_rank != 4:
        raise ValueError(
            f"cta_layout_vmnk must be rank 4, but got {core.pretty_str(cta_layout_vmnk)}"
        )
    coord_rank = core.rank(cta_coord_vmnk)
    if coord_rank != 4:
        raise ValueError(
            f"cta_coord_vmnk must be rank 4, but got {core.pretty_str(cta_coord_vmnk)}"
        )
    mask = core.make_layout_image_mask(
        cta_layout_vmnk, cta_coord_vmnk, mcast_mode, loc=loc, ip=ip
    )
    return mask
221
+
222
+
223
@dsl_user_op
def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None:
    """
    Prefetches the TMA descriptor associated with the TMA Atom.

    :param tma_atom: The TMA Copy Atom whose descriptor should be prefetched
    :type tma_atom: core.CopyAtom
    """
    _cute_nvgpu_ir.prefetch_tma_desc(tma_atom._trait.value, loc=loc, ip=ip)
229
+
230
+
231
@dsl_user_op
def copy_tensormap(
    tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None
) -> None:
    """
    Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided
    pointer.

    :param tma_atom: The TMA Copy Atom
    :type tma_atom: CopyAtom
    :param tensormap_ptr: The pointer to the memory location to copy the tensormap to
    :type tensormap_ptr: Pointer
    """
    _cute_nvgpu_ir.copy_tma_desc(
        tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip
    )
247
+
248
+
249
@dsl_user_op
def update_tma_descriptor(
    tma_atom: core.CopyAtom,
    gmem_tensor: Tensor,
    tma_desc_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Updates the TMA descriptor in the memory location pointed to by the provided pointer using
    information from a TMA Copy Atom and the provided GMEM tensor.

    Specifically, the following fields of the TMA descriptor will be updated:

    1. the GMEM tensor base address
    2. the GMEM tensor shape
    3. the GMEM tensor stride

    Other fields of the TMA descriptor are left unchanged.

    :param tma_atom: The TMA Copy Atom
    :type tma_atom: CopyAtom
    :param gmem_tensor: The GMEM tensor
    :type gmem_tensor: Tensor
    :param tma_desc_ptr: The pointer to the memory location of the descriptor to update
    :type tma_desc_ptr: Pointer
    """
    _cute_nvgpu_ir.update_tma_desc(
        tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip
    )
280
+
281
+
282
@dsl_user_op
def fence_tma_desc_acquire(
    tma_desc_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Emits an acquire fence for the tensormap proxy on the descriptor's address.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.

    :param tma_desc_ptr: Pointer to the TMA descriptor being acquired
    :type tma_desc_ptr: Pointer
    """
    # The pointer is materialized as a 64-bit integer operand for the inline PTX
    # ("l" constraint — 64-bit register).
    tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [tma_desc_ptr_i64],
        "fence.proxy.tensormap::generic.acquire.gpu [$0], 128;",
        "l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
302
+
303
+
304
@dsl_user_op
def cp_fence_tma_desc_release(
    tma_desc_global_ptr: Pointer,
    tma_desc_shared_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Copies a tensormap from shared to global memory and fences the tensormap proxy with
    release semantics.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy>`__.

    :param tma_desc_global_ptr: Destination pointer to the descriptor in global memory
    :type tma_desc_global_ptr: Pointer
    :param tma_desc_shared_ptr: Source pointer to the descriptor in shared memory
    :type tma_desc_shared_ptr: Pointer
    """
    # Operands: 64-bit global address ("l") and 32-bit shared address ("r").
    tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value()
    tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [tma_desc_global_ptr_i64, tma_desc_shared_ptr_i32],
        "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [$0], [$1], 128;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
326
+
327
+
328
@dsl_user_op
def fence_tma_desc_release(*, loc=None, ip=None) -> None:
    """
    Emits a release fence for the tensormap proxy (no address operand).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    llvm.inline_asm(
        None,
        [],
        "fence.proxy.tensormap::generic.release.gpu;",
        "",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from typing import Optional, Tuple, Type, Union
13
+
14
+ from cutlass.cutlass_dsl import dsl_user_op
15
+
16
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
17
+
18
+ from .. import core
19
+ from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta
20
+ from ...impl_utils import check_type_in
21
+ from .cpasync.copy import (
22
+ CopyBulkTensorTileG2SOp,
23
+ CopyBulkTensorTileG2SNonExecTrait,
24
+ CopyBulkTensorTileG2SMulticastOp,
25
+ CopyBulkTensorTileG2SMulticastNonExecTrait,
26
+ )
27
+
28
+
29
+ ####################################################################################################
30
+ #
31
+ # TMA creation helpers for tcgen05 MMAs
32
+ #
33
+ ####################################################################################################
34
+
35
+
36
@dsl_user_op
def make_tiled_tma_atom_A(
    op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    mma_tiler_mnk: Shape,
    tiled_mma: core.TiledMma,
    cluster_shape_vmnk: Shape,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy Atom in ``.tile`` mode (``cp.async.bulk.tensor``) for loading the A
    operand of a tcgen05 MMA.

    The MK modes of the MMA Tiler combined with the TiledMMA's A-operand partitioning
    determine the GMEM tile each CTA loads. For multicast Ops, the N extent of
    ``cluster_shape_vmnk`` gives the multicast factor, since CTAs along N share the same
    tile of A.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
    :param gmem_tensor: The GMEM tensor to be loaded by this copy atom
    :type gmem_tensor: Tensor
    :param smem_layout: Shared memory layout to load the tensor into
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
    :type mma_tiler_mnk: Shape
    :param tiled_mma: The TiledMMA that will consume the load as operands
    :type tiled_mma: core.TiledMma
    :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
    :type cluster_shape_vmnk: Shape
    :param internal_type: Optional internal data type used when the element type does not
        match the copy type
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA coordinate tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type
    check_type_in(
        op,
        [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
        "op",
        "make_tiled_tma_atom_A",
    )

    # Restrict the GMEM identity layout to the MK modes of the Tiler, then project it
    # through the TiledMMA's A-operand thread/value partitioning.
    full_ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
    tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:])
    gmem_tile = core.composition(full_ident, tiler_mk, loc=loc, ip=ip)
    v_map = tiled_mma._thrfrg_A(gmem_tile)
    v_map = core.get(v_map, mode=[1])
    v_map = core.dice(v_map, (1, (1,) * core.rank(gmem_tile)))

    if isinstance(op, CopyBulkTensorTileG2SOp):
        num_multicast = 1
        trait_cls = CopyBulkTensorTileG2SNonExecTrait
    else:
        assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
        # multicast across the N-mode since those CTAs share the same tile of A
        num_multicast = core.size(cluster_shape_vmnk, mode=[2])
        trait_cls = CopyBulkTensorTileG2SMulticastNonExecTrait

    # res[0] = the IR Value for the non-executable atom instance
    # res[1] = the IR Value for the associated TMA tensor
    res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
        gmem_tensor.value,
        smem_layout,
        v_map,
        op._to_ir(),
        num_multicast=num_multicast,
        internal_type=internal_type,
        loc=loc,
        ip=ip,
    )
    return core.CopyAtom(op, trait_cls(res[0])), res[1]
139
+
140
+
141
@dsl_user_op
def make_tiled_tma_atom_B(
    op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    mma_tiler_mnk: Shape,
    tiled_mma: core.TiledMma,
    cluster_shape_vmnk: Shape,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy Atom in ``.tile`` mode (``cp.async.bulk.tensor``) for loading the B
    operand of a tcgen05 MMA.

    The NK modes of the MMA Tiler combined with the TiledMMA's B-operand partitioning
    determine the GMEM tile each CTA loads. For multicast Ops, the M extent of
    ``cluster_shape_vmnk`` gives the multicast factor, since CTAs along M share the same
    tile of B.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
    :param gmem_tensor: The GMEM tensor to be loaded by this copy atom
    :type gmem_tensor: Tensor
    :param smem_layout: Shared memory layout to load the tensor into
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
    :type mma_tiler_mnk: Shape
    :param tiled_mma: The TiledMMA that will consume the load as operands
    :type tiled_mma: core.TiledMma
    :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
    :type cluster_shape_vmnk: Shape
    :param internal_type: Optional internal data type used when the element type does not
        match the copy type
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA coordinate tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type
    check_type_in(
        op,
        [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
        "op",
        "make_tiled_tma_atom_B",
    )

    # Restrict the GMEM identity layout to the NK modes of the Tiler, then project it
    # through the TiledMMA's B-operand thread/value partitioning.
    full_ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
    tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:])
    gmem_tile = core.composition(full_ident, tiler_nk, loc=loc, ip=ip)
    v_map = tiled_mma._thrfrg_B(gmem_tile)
    v_map = core.get(v_map, mode=[1])
    v_map = core.dice(v_map, (1, (1,) * core.rank(gmem_tile)))

    if isinstance(op, CopyBulkTensorTileG2SOp):
        num_multicast = 1
        trait_cls = CopyBulkTensorTileG2SNonExecTrait
    else:
        assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
        # multicast across the M-mode since those CTAs share the same tile of B
        num_multicast = core.size(cluster_shape_vmnk, mode=[1])
        trait_cls = CopyBulkTensorTileG2SMulticastNonExecTrait

    # res[0] = the IR Value for the non-executable atom instance
    # res[1] = the IR Value for the associated TMA tensor
    res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
        gmem_tensor.value,
        smem_layout,
        v_map,
        op._to_ir(),
        num_multicast=num_multicast,
        internal_type=internal_type,
        loc=loc,
        ip=ip,
    )
    return core.CopyAtom(op, trait_cls(res[0])), res[1]
244
+
245
+
246
# Public API of this module: the tcgen05 A/B-operand TMA atom builders.
__all__ = [
    "make_tiled_tma_atom_A",
    "make_tiled_tma_atom_B",
]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from .copy import *
13
+ from .mma import *
14
+ from .helpers import *
15
+
16
+ # __all__ is required here for documentation generation
17
+ __all__ = [
18
+ #
19
+ # copy.py
20
+ #
21
+ "Repetition",
22
+ "Pack",
23
+ "Unpack",
24
+ "Ld16x64bOp",
25
+ "Ld16x128bOp",
26
+ "Ld16x256bOp",
27
+ "Ld16x32bx2Op",
28
+ "Ld32x32bOp",
29
+ "St16x64bOp",
30
+ "St16x128bOp",
31
+ "St16x256bOp",
32
+ "St16x32bx2Op",
33
+ "St32x32bOp",
34
+ #
35
+ # mma.py
36
+ #
37
+ "OperandMajorMode",
38
+ "OperandSource",
39
+ "CtaGroup",
40
+ "Field",
41
+ "MmaTF32Op",
42
+ "MmaF16BF16Op",
43
+ "MmaI8Op",
44
+ "MmaFP8Op",
45
+ "MmaMXF8Op",
46
+ "MmaMXF4Op",
47
+ "MmaMXF4NVF4Op",
48
+ "SmemLayoutAtomKind",
49
+ #
50
+ # helpers.py
51
+ #
52
+ "make_smem_layout_atom",
53
+ "tile_to_mma_shape",
54
+ "commit",
55
+ "is_tmem_load",
56
+ "is_tmem_store",
57
+ "get_tmem_copy_properties",
58
+ "find_tmem_tensor_col_offset",
59
+ "make_tmem_copy",
60
+ "make_s2t_copy",
61
+ "get_s2t_smem_desc_tensor",
62
+ ]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import enum
13
+ from dataclasses import dataclass
14
+ from typing import Type
15
+
16
+ from cutlass.cutlass_dsl import CuTeDSL
17
+
18
+ import cutlass._mlir.dialects.cute as _cute_ir
19
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
20
+ from cutlass._mlir import ir
21
+
22
+ from ..common import OpError
23
+ from ...core import CopyOp, Trait
24
+ from ...typing import Numeric
25
+
26
+ from .mma import CtaGroup
27
+
28
+
29
+ class Repetition(enum.Enum):
30
+ """
31
+ An enumeration for the number of repetitions of a given TMEM copy within the instruction.
32
+ """
33
+
34
+ x1 = 1
35
+ x2 = 2
36
+ x4 = 4
37
+ x8 = 8
38
+ x16 = 16
39
+ x32 = 32
40
+ x64 = 64
41
+ x128 = 128
42
+
43
+ def __str__(self) -> str:
44
+ return f"{self.__class__.__name__}.{self.name}"
45
+
46
+ def __repr__(self) -> str:
47
+ return f"<{self.__class__.__name__}.{self.name}>"
48
+
49
+ @classmethod
50
+ def _missing_(cls, value):
51
+ if isinstance(value, int):
52
+ if value == 1:
53
+ return Repetition.x1
54
+ elif value == 2:
55
+ return Repetition.x2
56
+ elif value == 8:
57
+ return Repetition.x8
58
+ elif value == 16:
59
+ return Repetition.x16
60
+ elif value == 32:
61
+ return Repetition.x32
62
+ elif value == 64:
63
+ return Repetition.x64
64
+ elif value == 128:
65
+ return Repetition.x128
66
+
67
+
68
+ class Pack(enum.Enum):
69
+ """
70
+ An enumeration for the possible packing patterns for TMEM to RMEM copies.
71
+ """
72
+
73
+ NONE = enum.auto()
74
+ PACK_16b_IN_32b = enum.auto()
75
+
76
+ def __str__(self) -> str:
77
+ return f"{self.__class__.__name__}.{self.name}"
78
+
79
+ def __repr__(self) -> str:
80
+ return f"<{self.__class__.__name__}.{self.name}>"
81
+
82
+
83
+ class Unpack(enum.Enum):
84
+ """
85
+ An enumeration for the possible unpacking patterns for RMEM to TMEM copies.
86
+ """
87
+
88
+ NONE = enum.auto()
89
+ UNPACK_32b_IN_16b = enum.auto()
90
+
91
+ def __str__(self) -> str:
92
+ return f"{self.__class__.__name__}.{self.name}"
93
+
94
+ def __repr__(self) -> str:
95
+ return f"<{self.__class__.__name__}.{self.name}>"
96
+
97
+
98
+ @dataclass(frozen=True)
99
+ class _LdBase(CopyOp):
100
+ repeat: Repetition = Repetition.x1
101
+ pack: Pack = Pack.NONE
102
+
103
+ admissible_archs = [
104
+ "sm_100a",
105
+ "sm_100f",
106
+ ]
107
+
108
+ def __post_init__(self) -> None:
109
+ # Arch verification
110
+ arch = CuTeDSL._get_dsl().envar.arch
111
+ if arch not in self.admissible_archs:
112
+ raise OpError(
113
+ self,
114
+ f"expects arch to be one of {self.admissible_archs}, but got {arch}",
115
+ suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
116
+ )
117
+
118
+ if not isinstance(self.repeat, Repetition):
119
+ raise OpError(
120
+ self,
121
+ "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
122
+ )
123
+ if not isinstance(self.pack, Pack):
124
+ raise OpError(
125
+ self,
126
+ "expects the 'pack' Op parameter to be a tcgen05.Pack instance",
127
+ )
128
+
129
+ def __str__(self) -> str:
130
+ res = (
131
+ f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
132
+ + f"\n number of repetitions = {self.repeat.value}"
133
+ )
134
+ if self.pack == Pack.PACK_16b_IN_32b:
135
+ res += f"\n with 2x 16-bit to 32b packing"
136
+ return res
137
+
138
+
139
+ @dataclass(frozen=True)
140
+ class Ld16x64bOp(_LdBase):
141
+ """
142
+ 16x64b TMEM load Operation.
143
+
144
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
145
+ This Operation corresponds to the ``.16x64b`` qualifier.
146
+ """
147
+
148
+ def _make_trait(
149
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
150
+ ) -> "Ld16x64bTrait":
151
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
152
+ copy_internal_type.mlir_type,
153
+ 16,
154
+ 64,
155
+ self.repeat.value,
156
+ ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
157
+ )
158
+ return Ld16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
159
+
160
+
161
+ class Ld16x64bTrait(Trait):
162
+ pass
163
+
164
+
165
+ @dataclass(frozen=True)
166
+ class Ld16x128bOp(_LdBase):
167
+ """
168
+ 16x128b TMEM load Operation.
169
+
170
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
171
+ This Operation corresponds to the ``.16x128b`` qualifier.
172
+ """
173
+
174
+ def __post_init__(self) -> None:
175
+ super().__post_init__()
176
+ if self.repeat == Repetition.x128:
177
+ raise OpError(
178
+ self,
179
+ "x128 repetition is not supported",
180
+ suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
181
+ )
182
+
183
+ def _make_trait(
184
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
185
+ ) -> "Ld16x128bTrait":
186
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
187
+ copy_internal_type.mlir_type,
188
+ 16,
189
+ 128,
190
+ self.repeat.value,
191
+ ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
192
+ )
193
+ return Ld16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
194
+
195
+
196
+ class Ld16x128bTrait(Trait):
197
+ pass
198
+
199
+
200
+ @dataclass(frozen=True)
201
+ class Ld16x256bOp(_LdBase):
202
+ """
203
+ 16x256b TMEM load Operation.
204
+
205
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
206
+ This Operation corresponds to the ``.16x256b`` qualifier.
207
+ """
208
+
209
+ def __post_init__(self) -> None:
210
+ super().__post_init__()
211
+ if self.repeat in (Repetition.x128, Repetition.x64):
212
+ raise OpError(
213
+ self,
214
+ "x64 and x128 repetition is not supported",
215
+ suggestion="choose one of x1, x2, x4, x8, x16, x32",
216
+ )
217
+
218
+ def _make_trait(
219
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
220
+ ) -> "Ld16x256bTrait":
221
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
222
+ copy_internal_type.mlir_type,
223
+ 16,
224
+ 256,
225
+ self.repeat.value,
226
+ ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
227
+ )
228
+ return Ld16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
229
+
230
+
231
+ class Ld16x256bTrait(Trait):
232
+ pass
233
+
234
+
235
+ @dataclass(frozen=True)
236
+ class Ld16x32bx2Op(_LdBase):
237
+ """
238
+ 16x32bx2 TMEM load Operation.
239
+
240
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
241
+ This Operation corresponds to the ``.16x32bx2`` qualifier.
242
+ """
243
+
244
+ def _make_trait(
245
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
246
+ ) -> "Ld16x32bx2Trait":
247
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
248
+ copy_internal_type.mlir_type,
249
+ 16,
250
+ 32,
251
+ self.repeat.value,
252
+ ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
253
+ )
254
+ return Ld16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
255
+
256
+
257
+ class Ld16x32bx2Trait(Trait):
258
+ pass
259
+
260
+
261
+ @dataclass(frozen=True)
262
+ class Ld32x32bOp(_LdBase):
263
+ """
264
+ 32x32b TMEM load Operation.
265
+
266
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
267
+ This Operation corresponds to the ``.32x32`` qualifier.
268
+ """
269
+
270
+ def _make_trait(
271
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
272
+ ) -> "Ld32x32bTrait":
273
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
274
+ copy_internal_type.mlir_type,
275
+ 32,
276
+ 32,
277
+ self.repeat.value,
278
+ ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None,
279
+ )
280
+ return Ld32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
281
+
282
+
283
+ class Ld32x32bTrait(Trait):
284
+ pass
285
+
286
+
287
+ @dataclass(frozen=True)
288
+ class _StBase(CopyOp):
289
+ repeat: Repetition
290
+ unpack: Unpack = Unpack.NONE
291
+
292
+ admissible_archs = [
293
+ "sm_100a",
294
+ "sm_100f",
295
+ ]
296
+
297
+ def __post_init__(self) -> None:
298
+ # Arch verification
299
+ arch = CuTeDSL._get_dsl().envar.arch
300
+ if arch not in self.admissible_archs:
301
+ raise OpError(
302
+ self,
303
+ f"expects arch to be one of {self.admissible_archs}, but got {arch}",
304
+ suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
305
+ )
306
+
307
+ if not isinstance(self.repeat, Repetition):
308
+ raise OpError(
309
+ self,
310
+ "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
311
+ )
312
+ if not isinstance(self.unpack, Unpack):
313
+ raise OpError(
314
+ self,
315
+ "expects the 'pack' Op parameter to be a tcgen05.Unpack instance",
316
+ )
317
+
318
+ def __str__(self) -> str:
319
+ res = (
320
+ f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
321
+ + f"\n number of repetitions = {self.repeat.value}"
322
+ )
323
+ if self.unpack == Unpack.UNPACK_32b_IN_16b:
324
+ res += f"\n with 32-bit to 2x 16b unpacking"
325
+ return res
326
+
327
+
328
+ @dataclass(frozen=True)
329
+ class St16x64bOp(_StBase):
330
+ """
331
+ 16x64b TMEM store Operation.
332
+
333
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
334
+ This Operation corresponds to the ``.16x64`` qualifier.
335
+ """
336
+
337
+ def _make_trait(
338
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
339
+ ) -> "St16x64bTrait":
340
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
341
+ copy_internal_type.mlir_type,
342
+ 16,
343
+ 64,
344
+ self.repeat.value,
345
+ ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
346
+ )
347
+ return St16x64bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
348
+
349
+
350
+ class St16x64bTrait(Trait):
351
+ pass
352
+
353
+
354
+ @dataclass(frozen=True)
355
+ class St16x128bOp(_StBase):
356
+ """
357
+ 16x128b TMEM store Operation.
358
+
359
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
360
+ This Operation corresponds to the ``.16x128`` qualifier.
361
+ """
362
+
363
+ def __post_init__(self) -> None:
364
+ super().__post_init__()
365
+ if self.repeat == Repetition.x128:
366
+ raise OpError(
367
+ self,
368
+ "x128 repetition is not supported",
369
+ suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
370
+ )
371
+
372
+ def _make_trait(
373
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
374
+ ) -> "St16x128bTrait":
375
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
376
+ copy_internal_type.mlir_type,
377
+ 16,
378
+ 128,
379
+ self.repeat.value,
380
+ ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
381
+ )
382
+ return St16x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
383
+
384
+
385
+ class St16x128bTrait(Trait):
386
+ pass
387
+
388
+
389
+ @dataclass(frozen=True)
390
+ class St16x256bOp(_StBase):
391
+ """
392
+ 16x256b TMEM store Operation.
393
+
394
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
395
+ This Operation corresponds to the ``.16x256`` qualifier.
396
+ """
397
+
398
+ def __post_init__(self) -> None:
399
+ super().__post_init__()
400
+ if self.repeat in (Repetition.x128, Repetition.x64):
401
+ raise OpError(
402
+ self,
403
+ "x64 and x128 repetition is not supported",
404
+ suggestion="choose one of x1, x2, x4, x8, x16, x32",
405
+ )
406
+
407
+ def _make_trait(
408
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
409
+ ) -> "St16x256bTrait":
410
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
411
+ copy_internal_type.mlir_type,
412
+ 16,
413
+ 256,
414
+ self.repeat.value,
415
+ ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
416
+ )
417
+ return St16x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
418
+
419
+
420
+ class St16x256bTrait(Trait):
421
+ pass
422
+
423
+
424
+ @dataclass(frozen=True)
425
+ class St16x32bx2Op(_StBase):
426
+ """
427
+ 16x32x2b TMEM store Operation.
428
+
429
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
430
+ This Operation corresponds to the ``.16x32x2`` qualifier.
431
+ """
432
+
433
+ def _make_trait(
434
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
435
+ ) -> "St16x32bx2Trait":
436
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
437
+ copy_internal_type.mlir_type,
438
+ 16,
439
+ 32,
440
+ self.repeat.value,
441
+ ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
442
+ )
443
+ return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
444
+
445
+
446
+ class St16x32bx2Trait(Trait):
447
+ pass
448
+
449
+
450
+ @dataclass(frozen=True)
451
+ class St32x32bOp(_StBase):
452
+ """
453
+ 32x32b TMEM store Operation.
454
+
455
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
456
+ This Operation corresponds to the ``.32x32`` qualifier.
457
+ """
458
+
459
+ def _make_trait(
460
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
461
+ ) -> "St32x32bTrait":
462
+ ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
463
+ copy_internal_type.mlir_type,
464
+ 32,
465
+ 32,
466
+ self.repeat.value,
467
+ ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
468
+ )
469
+ return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
470
+
471
+
472
+ class St32x32bTrait(Trait):
473
+ pass
474
+
475
+
476
+ @dataclass(frozen=True)
477
+ class _S2TCopyBase(CopyOp):
478
+ cta_group: CtaGroup
479
+
480
+ admissible_archs = [
481
+ "sm_100a",
482
+ "sm_100f",
483
+ ]
484
+
485
+ def __post_init__(self) -> None:
486
+ # Arch verification
487
+ arch = CuTeDSL._get_dsl().envar.arch
488
+ if arch not in self.admissible_archs:
489
+ raise OpError(
490
+ self,
491
+ f"expects arch to be one of {self.admissible_archs}, but got {arch}",
492
+ suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
493
+ )
494
+ # Verify that the user provided enum values
495
+ if not isinstance(self.cta_group, CtaGroup):
496
+ raise OpError(
497
+ self,
498
+ "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
499
+ )
500
+
501
+ def __str__(self) -> str:
502
+ res = (
503
+ f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
504
+ + f"\n CTA group = {self.cta_group}"
505
+ )
506
+
507
+ return res
508
+
509
+
510
+ @dataclass(frozen=True)
511
+ class Cp128x256bOp(_S2TCopyBase):
512
+ """
513
+ 128x256b SMEM to TMEM Copy Operation.
514
+
515
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
516
+ This Operation corresponds to the ``.128x256b`` qualifier.
517
+ """
518
+
519
+ def _make_trait(
520
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
521
+ ) -> "Cp128x256bTrait":
522
+ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
523
+ copy_internal_type.mlir_type,
524
+ 128,
525
+ 256,
526
+ self.cta_group.value,
527
+ _cute_nvgpu_ir.CopyS2TBroadcast.none,
528
+ )
529
+ return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
530
+
531
+
532
+ class Cp128x256bTrait(Trait):
533
+ pass
534
+
535
+
536
+ @dataclass(frozen=True)
537
+ class Cp128x128bOp(_S2TCopyBase):
538
+ """
539
+ 128x128b SMEM to TMEM Copy Operation.
540
+
541
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
542
+ This Operation corresponds to the ``.128x128b`` qualifier.
543
+ """
544
+
545
+ def _make_trait(
546
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
547
+ ) -> "Cp128x128bTrait":
548
+ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
549
+ copy_internal_type.mlir_type,
550
+ 128,
551
+ 128,
552
+ self.cta_group.value,
553
+ _cute_nvgpu_ir.CopyS2TBroadcast.none,
554
+ )
555
+ return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
556
+
557
+
558
+ class Cp128x128bTrait(Trait):
559
+ pass
560
+
561
+
562
+ @dataclass(frozen=True)
563
+ class Cp4x256bOp(_S2TCopyBase):
564
+ """
565
+ 4x256b SMEM to TMEM Copy Operation.
566
+
567
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
568
+ This Operation corresponds to the ``.4x256b`` qualifier.
569
+ """
570
+
571
+ def _make_trait(
572
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
573
+ ) -> "Cp4x256bTrait":
574
+ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
575
+ copy_internal_type.mlir_type,
576
+ 4,
577
+ 256,
578
+ self.cta_group.value,
579
+ _cute_nvgpu_ir.CopyS2TBroadcast.none,
580
+ )
581
+ return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
582
+
583
+
584
+ class Cp4x256bTrait(Trait):
585
+ pass
586
+
587
+
588
+ @dataclass(frozen=True)
589
+ class Cp4x32x128bOp(_S2TCopyBase):
590
+ """
591
+ 32x128b SMEM to TMEM Copy Operation.
592
+
593
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
594
+ This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled.
595
+ """
596
+
597
+ def _make_trait(
598
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
599
+ ) -> "Cp4x32x128bTrait":
600
+ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
601
+ copy_internal_type.mlir_type,
602
+ 32,
603
+ 128,
604
+ self.cta_group.value,
605
+ _cute_nvgpu_ir.CopyS2TBroadcast.x4,
606
+ )
607
+ return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
608
+
609
+
610
+ class Cp4x32x128bTrait(Trait):
611
+ pass
612
+
613
+
614
+ @dataclass(frozen=True)
615
+ class Cp2x64x128b0213Op(_S2TCopyBase):
616
+ """
617
+ 64x128b SMEM to TMEM Copy Operation.
618
+
619
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
620
+ This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled.
621
+ """
622
+
623
+ def _make_trait(
624
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
625
+ ) -> "Cp2x64x128b0213Trait":
626
+ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
627
+ copy_internal_type.mlir_type,
628
+ 64,
629
+ 128,
630
+ self.cta_group.value,
631
+ _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213,
632
+ )
633
+ return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
634
+
635
+
636
+ class Cp2x64x128b0213Trait(Trait):
637
+ pass
638
+
639
+
640
+ @dataclass(frozen=True)
641
+ class Cp2x64x128b0123Op(_S2TCopyBase):
642
+ """
643
+ 64x128b SMEM to TMEM Copy Operation.
644
+
645
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
646
+ This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled.
647
+ """
648
+
649
+ def _make_trait(
650
+ self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
651
+ ) -> "Cp2x64x128b0123Trait":
652
+ ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
653
+ copy_internal_type.mlir_type,
654
+ 64,
655
+ 128,
656
+ self.cta_group.value,
657
+ _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123,
658
+ )
659
+ return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
660
+
661
+
662
+ class Cp2x64x128b0123Trait(Trait):
663
+ pass
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from typing import overload, Type, Tuple, Union
13
+
14
+ from cutlass.cutlass_dsl import dsl_user_op
15
+
16
+ import cutlass._mlir.dialects.cute as _cute_ir
17
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
18
+ from cutlass._mlir.dialects import nvvm
19
+
20
+ from ...typing import (
21
+ Shape,
22
+ IntTuple,
23
+ Layout,
24
+ Tensor,
25
+ Int,
26
+ Numeric,
27
+ NumericMeta,
28
+ Int16,
29
+ Int32,
30
+ )
31
+ from ... import core
32
+ from .mma import SmemLayoutAtomKind, CtaGroup
33
+ from .copy import (
34
+ Pack,
35
+ Unpack,
36
+ Ld16x64bOp,
37
+ Ld16x128bOp,
38
+ Ld16x256bOp,
39
+ Ld16x32bx2Op,
40
+ Ld32x32bOp,
41
+ St16x64bOp,
42
+ St16x128bOp,
43
+ St16x256bOp,
44
+ St16x32bx2Op,
45
+ St32x32bOp,
46
+ )
47
+
48
+
49
+ ####################################################################################################
50
+ #
51
+ # Helper functions for MMA
52
+ #
53
+ ####################################################################################################
54
+
55
+
56
+ @dsl_user_op
57
+ def make_smem_layout_atom(
58
+ kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
59
+ ) -> core.ComposedLayout:
60
+ """
61
+ Makes a SMEM layout Atom.
62
+
63
+ This function creates a composed layout in unit of elements consistent with the requested layout
64
+ Atom kind and element data type.
65
+
66
+ :param kind: The kind of layout Atom
67
+ :type kind: SmemLayoutAtomKind
68
+ :param element_type: The element data type to construct the layout for
69
+ :type element_type: Type[Numeric]
70
+ :return: The SMEM layout atom
71
+ :rtype: core.ComposedLayout
72
+ """
73
+ if not isinstance(element_type, NumericMeta):
74
+ raise TypeError(f"element_type must be a Numeric, but got {element_type}")
75
+
76
+ if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
77
+ num_contiguous_bits = 128
78
+ sw = core.make_swizzle(0, 4, 3)
79
+ elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
80
+ num_contiguous_bits = 256
81
+ sw = core.make_swizzle(1, 4, 3)
82
+ elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
83
+ num_contiguous_bits = 512
84
+ sw = core.make_swizzle(2, 4, 3)
85
+ elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
86
+ num_contiguous_bits = 1024
87
+ sw = core.make_swizzle(3, 4, 3)
88
+ elif kind == SmemLayoutAtomKind.MN_SW128_32B:
89
+ num_contiguous_bits = 1024
90
+ sw = core.make_swizzle(2, 5, 2)
91
+ else:
92
+ raise ValueError("unrecognized SMEM layout atom kind")
93
+ num_contiguous_elems = num_contiguous_bits // element_type.width
94
+
95
+ if kind in (
96
+ SmemLayoutAtomKind.MN_INTER,
97
+ SmemLayoutAtomKind.MN_SW32,
98
+ SmemLayoutAtomKind.MN_SW64,
99
+ SmemLayoutAtomKind.MN_SW128,
100
+ SmemLayoutAtomKind.MN_SW128_32B,
101
+ ):
102
+ # M/N-major layout
103
+ return core.make_composed_layout(
104
+ sw,
105
+ 0,
106
+ core.make_layout(
107
+ (num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
108
+ ),
109
+ loc=loc,
110
+ ip=ip,
111
+ )
112
+ else:
113
+ # K-major layout
114
+ return core.make_composed_layout(
115
+ sw,
116
+ 0,
117
+ core.make_layout(
118
+ (8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
119
+ ),
120
+ loc=loc,
121
+ ip=ip,
122
+ )
123
+
124
+
125
+ @overload
126
+ def tile_to_mma_shape(
127
+ atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
128
+ ) -> Layout: ...
129
+
130
+
131
+ @overload
132
+ def tile_to_mma_shape(
133
+ atom: core.ComposedLayout,
134
+ mma_tile_shape: Shape,
135
+ order: IntTuple = None,
136
+ *,
137
+ loc=None,
138
+ ip=None,
139
+ ) -> core.ComposedLayout: ...
140
+
141
+
142
+ @dsl_user_op
143
+ def tile_to_mma_shape(
144
+ atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
145
+ ):
146
+ """
147
+ Tiles a layout to an MMA shape.
148
+ """
149
+ # Default order is colexicographical
150
+ if order is None:
151
+ order = tuple(range(core.rank(mma_tile_shape) - 1))
152
+ if core.rank(order) != core.rank(mma_tile_shape) - 1:
153
+ raise ValueError(
154
+ f"rank(order)={core.rank(order)} must be equal to "
155
+ f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}"
156
+ )
157
+ order_val = core._pack_int_tuple(order, loc=loc, ip=ip)
158
+ mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip)
159
+
160
+ if not (
161
+ core.is_static(atom)
162
+ and core.is_static(mma_tile_shape_val)
163
+ and core.is_static(order_val)
164
+ ):
165
+ raise ValueError("tile_to_mma_shape only supports static inputs")
166
+
167
+ res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val)
168
+ return _cute_ir.static(res_ty, loc=loc, ip=ip)
169
+
170
+
171
+ @dsl_user_op
172
+ def commit(
173
+ mbar_ptr: core.Pointer,
174
+ mask=None,
175
+ cta_group: CtaGroup = CtaGroup.ONE,
176
+ *,
177
+ loc=None,
178
+ ip=None,
179
+ ) -> None:
180
+ """
181
+ Perform an arrive operation on a mbarrier upon completion of previous MMA operations.
182
+
183
+ :param mbar_ptr: A pointer to the mbarrier in SMEM
184
+ :type mbar_ptr: Pointer
185
+ :param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to
186
+ :type mask: Int
187
+ """
188
+ if cta_group == CtaGroup.ONE:
189
+ group = nvvm.Tcgen05GroupKind.CTA_1
190
+ else:
191
+ assert cta_group == CtaGroup.TWO
192
+ group = nvvm.Tcgen05GroupKind.CTA_2
193
+
194
+ mbar_ptr = mbar_ptr.llvm_ptr
195
+ if mask is not None:
196
+ mask = Int16(mask).ir_value(loc=loc, ip=ip)
197
+ nvvm.tcgen05_commit_arrive(
198
+ mbar_ptr, multicast_mask=mask, group=group, loc=loc, ip=ip
199
+ )
200
+ else:
201
+ nvvm.tcgen05_commit_arrive(mbar_ptr, group=group, loc=loc, ip=ip)
202
+ return
203
+
204
+
205
+ ####################################################################################################
206
+ #
207
+ # Helper functions for Copies
208
+ #
209
+ ####################################################################################################
210
+
211
+
212
+ def is_tmem_load(atom: core.CopyAtom) -> bool:
213
+ """
214
+ Returns whether a CopyAtom instance is a TMEM load.
215
+ """
216
+ return isinstance(
217
+ atom.op,
218
+ (
219
+ Ld16x64bOp,
220
+ Ld16x128bOp,
221
+ Ld16x256bOp,
222
+ Ld16x32bx2Op,
223
+ Ld32x32bOp,
224
+ ),
225
+ )
226
+
227
+
228
+ def is_tmem_store(atom: core.CopyAtom) -> bool:
229
+ """
230
+ Returns whether a CopyAtom instance is a TMEM store.
231
+ """
232
+ return isinstance(
233
+ atom.op,
234
+ (
235
+ St16x64bOp,
236
+ St16x128bOp,
237
+ St16x256bOp,
238
+ St16x32bx2Op,
239
+ St32x32bOp,
240
+ ),
241
+ )
242
+
243
+
244
+ def get_tmem_copy_properties(
245
+ atom: core.CopyAtom,
246
+ ) -> Tuple[int, int, int, Union[Pack, Unpack]]:
247
+ """
248
+ Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions,
249
+ and whether packing/unpacking is used).
250
+ """
251
+ if isinstance(atom.op, (Ld16x64bOp, St16x64bOp)):
252
+ num_dp, num_bits = 16, 64
253
+ elif isinstance(atom.op, (Ld16x128bOp, St16x128bOp)):
254
+ num_dp, num_bits = 16, 128
255
+ elif isinstance(atom.op, (Ld16x256bOp, St16x256bOp)):
256
+ num_dp, num_bits = 16, 256
257
+ elif isinstance(atom.op, (Ld16x32bx2Op, St16x32bx2Op)):
258
+ num_dp, num_bits = 16, 32
259
+ elif isinstance(atom.op, (Ld32x32bOp, St32x32bOp)):
260
+ num_dp, num_bits = 32, 32
261
+ else:
262
+ raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}")
263
+ if is_tmem_load(atom):
264
+ return num_dp, num_bits, atom.op.repeat.value, atom.op.pack
265
+ else:
266
+ assert is_tmem_store(atom), "atom must be a TMEM store"
267
+ return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack
268
+
269
+
270
+ @dsl_user_op
271
+ def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int:
272
+ """
273
+ Computes the TMEM column offset given a TMEM tensor.
274
+
275
+ :param tmem_tensor: The TMEM tensor to use to compute the columns offset
276
+ :type tmem_tensor: Tensor
277
+ :return: The columns offset
278
+ :rtype: Int
279
+ """
280
+ tmem_col_mask = 0x0000FFFF
281
+ offset = (
282
+ core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip)
283
+ & tmem_col_mask
284
+ )
285
+ if isinstance(offset, int):
286
+ return offset
287
+ return Int32(offset, loc=loc, ip=ip)
288
+
289
+
290
+ @dsl_user_op
291
+ def make_tmem_copy(
292
+ atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
293
+ ) -> core.TiledCopy:
294
+ """
295
+ Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
296
+ """
297
+ tiled_copy_val = _cute_nvgpu_ir.atom_make_tmem_copy(
298
+ atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
299
+ )
300
+ new_trait = type(atom._trait)(tiled_copy_val)
301
+ return core.TiledCopy(atom.op, new_trait)
302
+
303
+
304
+ @dsl_user_op
305
+ def make_s2t_copy(
306
+ atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
307
+ ) -> core.TiledCopy:
308
+ """
309
+ Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
310
+ """
311
+ tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy(
312
+ atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
313
+ )
314
+ new_trait = type(atom._trait)(tiled_copy_val)
315
+ return core.TiledCopy(atom.op, new_trait)
316
+
317
+
318
+ @dsl_user_op
319
+ def get_s2t_smem_desc_tensor(
320
+ atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None
321
+ ) -> Tensor:
322
+ """
323
+ Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor.
324
+ """
325
+ smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view(
326
+ atom._trait.value, smem_tensor.value, loc=loc, ip=ip
327
+ )
328
+ return smem_desc_tensor
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py ADDED
@@ -0,0 +1,1041 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import enum
13
+ from dataclasses import dataclass
14
+ from typing import Type
15
+
16
+ from cutlass.cutlass_dsl import CuTeDSL, T
17
+
18
+ import cutlass._mlir.dialects.cute as _cute_ir
19
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
20
+ from cutlass._mlir import ir
21
+
22
+ from ..common import OpError
23
+ from ... import core
24
+ from ...core import Trait, _pack_shape, rank, depth, _Tensor
25
+ from ...typing import (
26
+ Shape,
27
+ Float4E2M1FN,
28
+ Float8E8M0FNU,
29
+ Float8E5M2,
30
+ Float8E4M3FN,
31
+ Float16,
32
+ BFloat16,
33
+ Float32,
34
+ TFloat32,
35
+ Boolean,
36
+ Int8,
37
+ Uint8,
38
+ Int32,
39
+ Numeric,
40
+ AddressSpace,
41
+ Pointer,
42
+ )
43
+
44
+
45
+ ####################################################################################################
46
+ #
47
+ # MMA Ops and Traits
48
+ #
49
+ ####################################################################################################
50
+
51
+
52
+ class OperandMajorMode(enum.Enum):
53
+ """
54
+ An enumeration for the majorness of the input operands of the MMA.
55
+ """
56
+
57
+ MN = _cute_ir.MajorMode.mn
58
+ K = _cute_ir.MajorMode.k
59
+
60
+ def __str__(self) -> str:
61
+ return f"{self.__class__.__name__}.{self.name}"
62
+
63
+ def __repr__(self) -> str:
64
+ return f"<{self.__class__.__name__}.{self.name}>"
65
+
66
+ @classmethod
67
+ def _missing_(cls, value):
68
+ if isinstance(value, str):
69
+ value = value.upper()
70
+ if value == "MN":
71
+ return OperandMajorMode.MN
72
+ elif value == "K":
73
+ return OperandMajorMode.K
74
+
75
+ def _to_ir(self) -> _cute_ir.MajorMode:
76
+ return self.value
77
+
78
+
79
+ class OperandSource(enum.Enum):
80
+ """
81
+ An enumeration for the source memory location of the A input operand of the MMA.
82
+ """
83
+
84
+ TMEM = _cute_ir.MmaFragKind.tmem
85
+ SMEM = _cute_ir.MmaFragKind.smem_desc
86
+
87
+ def __str__(self) -> str:
88
+ return f"{self.__class__.__name__}.{self.name}"
89
+
90
+ def __repr__(self) -> str:
91
+ return f"<{self.__class__.__name__}.{self.name}>"
92
+
93
+ def _to_ir(self) -> _cute_ir.MmaFragKind:
94
+ return self.value
95
+
96
+
97
+ class CtaGroup(enum.Enum):
98
+ """
99
+ An enumeration for the ``cta_group`` qualifier of the MMA.
100
+ """
101
+
102
+ ONE = 1
103
+ TWO = 2
104
+
105
+ def __str__(self) -> str:
106
+ return f"{self.__class__.__name__}.{self.name}"
107
+
108
+ def __repr__(self) -> str:
109
+ return f"<{self.__class__.__name__}.{self.name}>"
110
+
111
+ class Field(enum.Enum):
112
+ """
113
+ An enumeration for the fields of the MMA Atom that can be modified at runtime.
114
+ """
115
+
116
+ NEGATE_A = "neg_a"
117
+ NEGATE_B = "neg_b"
118
+ ACCUMULATE = "accum_c"
119
+ SFA = "sf_a"
120
+ SFB = "sf_b"
121
+
122
+ def __str__(self) -> str:
123
+ return f"{self.__class__.__name__}.{self.name}"
124
+
125
+ def __repr__(self) -> str:
126
+ return f"<{self.__class__.__name__}.{self.name}>"
127
+
128
+ def _to_ir_field_name(self) -> str:
129
+ return self.value
130
+
131
+
132
+ # Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code
133
+ @dataclass(frozen=True)
134
+ class MmaOp(core.MmaOp):
135
+ a_dtype: Type[Numeric]
136
+ b_dtype: Type[Numeric]
137
+ acc_dtype: Type[Numeric]
138
+ shape_mnk: Shape
139
+ cta_group: CtaGroup
140
+ a_src: OperandSource
141
+ a_major_mode: OperandMajorMode
142
+ b_major_mode: OperandMajorMode
143
+
144
+ admissible_archs = [
145
+ "sm_100a",
146
+ "sm_100f",
147
+ ]
148
+
149
+ def __post_init__(self) -> None:
150
+ # Verify arch
151
+ arch = CuTeDSL._get_dsl().envar.arch
152
+ if arch not in self.admissible_archs:
153
+ raise OpError(
154
+ self,
155
+ f"expects arch to be one of {self.admissible_archs}, but got {arch}",
156
+ suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
157
+ )
158
+ # Verify that the user provided enum values
159
+ if not isinstance(self.cta_group, CtaGroup):
160
+ raise OpError(
161
+ self,
162
+ "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
163
+ )
164
+ if not isinstance(self.a_src, OperandSource):
165
+ raise OpError(
166
+ self,
167
+ "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
168
+ )
169
+ if not isinstance(self.a_major_mode, OperandMajorMode):
170
+ raise OpError(
171
+ self,
172
+ "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
173
+ )
174
+ if not isinstance(self.b_major_mode, OperandMajorMode):
175
+ raise OpError(
176
+ self,
177
+ "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
178
+ )
179
+ # Verify the instruction shape
180
+ if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
181
+ raise OpError(
182
+ self,
183
+ f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
184
+ f"but got {self.shape_mnk}",
185
+ )
186
+ m, n = self.shape_mnk[0], self.shape_mnk[1]
187
+ if self.cta_group == CtaGroup.ONE:
188
+ if m not in [64, 128]:
189
+ raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}")
190
+ if m == 64:
191
+ if (n < 8) or (n > 256) or (n % 8 != 0):
192
+ raise OpError(
193
+ self,
194
+ f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
195
+ )
196
+ elif m == 128:
197
+ if (n < 16) or (n > 256) or (n % 16 != 0):
198
+ raise OpError(
199
+ self,
200
+ f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}",
201
+ )
202
+ else:
203
+ if m not in [128, 256]:
204
+ raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
205
+ if (n < 32) or (n > 256) or (n % 32 != 0):
206
+ raise OpError(
207
+ self,
208
+ f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}",
209
+ )
210
+
211
+ def __str__(self) -> str:
212
+ return (
213
+ self.__class__.descriptive_name # type: ignore
214
+ + f"\n A data type = {self.a_dtype}"
215
+ + f"\n B data type = {self.b_dtype}"
216
+ + f"\n Accumulator data type = {self.acc_dtype}"
217
+ + f"\n CTA group = {self.cta_group}"
218
+ + f"\n A source location = {self.a_src}"
219
+ + f"\n A major mode = {self.a_major_mode}"
220
+ + f"\n B major mode = {self.b_major_mode}"
221
+ + f"\n Instruction shape MNK = {self.shape_mnk}"
222
+ )
223
+
224
+ def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
225
+ if input.memspace == AddressSpace.smem and isinstance(
226
+ input.layout.type, _cute_ir.ComposedLayoutType
227
+ ):
228
+ raise OpError(
229
+ self,
230
+ f"Expected affine layout for {self._make_trait()}'s operand A, "
231
+ f"but got composed layout instead: {input.layout}"
232
+ f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
233
+ )
234
+ return True
235
+
236
+ def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
237
+ if input.memspace == AddressSpace.smem and isinstance(
238
+ input.layout.type, _cute_ir.ComposedLayoutType
239
+ ):
240
+ raise OpError(
241
+ self,
242
+ f"Expected affine layout for {self._make_trait()}'s operand B, "
243
+ f"but got composed layout instead: {input.layout}"
244
+ f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
245
+ )
246
+ return True
247
+
248
+
249
+ class MmaTrait(Trait):
250
+ admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]
251
+
252
+ def set(self, field, value, *, loc=None, ip=None) -> None:
253
+ if field not in self.admissible_fields:
254
+ raise ValueError(
255
+ f"expects field to be one of {self.admissible_fields}, but got {field}"
256
+ )
257
+ field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>"
258
+ attr = ir.Attribute.parse(field_name)
259
+ self.value = _cute_nvgpu_ir.atom_set_value(
260
+ self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
261
+ )
262
+
263
+
264
+ # Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code
265
+ @dataclass(frozen=True)
266
+ class BlockScaledMmaOp(core.MmaOp):
267
+ a_dtype: Type[Numeric]
268
+ b_dtype: Type[Numeric]
269
+ acc_dtype: Float32
270
+ sf_dtype: Type[Numeric]
271
+ sf_vec_size: int
272
+ shape_mnk: Shape
273
+ cta_group: CtaGroup
274
+ a_src: OperandSource
275
+ a_major_mode: OperandMajorMode
276
+ b_major_mode: OperandMajorMode
277
+
278
+ admissible_archs = [
279
+ "sm_100a",
280
+ ]
281
+
282
+ def __post_init__(self) -> None:
283
+ # Verify arch
284
+ arch = CuTeDSL._get_dsl().envar.arch
285
+ if arch not in self.admissible_archs:
286
+ raise OpError(
287
+ self,
288
+ f"expects arch to be one of {self.admissible_archs}, but got {arch}",
289
+ suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
290
+ )
291
+ # Verify that the user provided enum values
292
+ if not isinstance(self.cta_group, CtaGroup):
293
+ raise OpError(
294
+ self,
295
+ "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
296
+ )
297
+ if not isinstance(self.a_src, OperandSource):
298
+ raise OpError(
299
+ self,
300
+ "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
301
+ )
302
+ if not isinstance(self.a_major_mode, OperandMajorMode):
303
+ raise OpError(
304
+ self,
305
+ "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
306
+ )
307
+ if not isinstance(self.b_major_mode, OperandMajorMode):
308
+ raise OpError(
309
+ self,
310
+ "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
311
+ )
312
+ # Verify the instruction shape
313
+ if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
314
+ raise OpError(
315
+ self,
316
+ f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
317
+ f"but got {self.shape_mnk}",
318
+ )
319
+ m, n = self.shape_mnk[0], self.shape_mnk[1]
320
+ if self.cta_group == CtaGroup.ONE:
321
+ if m != 128:
322
+ raise OpError(self, f"expects the M-mode to be 128, but got {m}")
323
+
324
+ if (n < 8) or (n > 256) or (n % 8 != 0):
325
+ raise OpError(
326
+ self,
327
+ f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
328
+ )
329
+ else:
330
+ if m not in [128, 256]:
331
+ raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
332
+ if (n < 16) or (n > 256) or (n % 16 != 0):
333
+ raise OpError(
334
+ self,
335
+ f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
336
+ )
337
+ if self.sf_vec_size not in [16, 32]:
338
+ raise OpError(
339
+ self,
340
+ f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}",
341
+ )
342
+
343
+ def __str__(self) -> str:
344
+ return (
345
+ self.__class__.descriptive_name # type: ignore
346
+ + f"\n A data type = {self.a_dtype}"
347
+ + f"\n B data type = {self.b_dtype}"
348
+ + f"\n Accumulator data type = {self.acc_dtype}"
349
+ + f"\n Scale factor data type = {self.sf_dtype}"
350
+ + f"\n Scale factor vector size = {self.sf_vec_size}"
351
+ + f"\n CTA group = {self.cta_group}"
352
+ + f"\n A source location = {self.a_src}"
353
+ + f"\n A major mode = {self.a_major_mode}"
354
+ + f"\n B major mode = {self.b_major_mode}"
355
+ + f"\n Instruction shape MNK = {self.shape_mnk}"
356
+ )
357
+
358
+ def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
359
+ if input.memspace == AddressSpace.smem and isinstance(
360
+ input.layout.type, _cute_ir.ComposedLayoutType
361
+ ):
362
+ raise OpError(
363
+ self,
364
+ f"Expected affine layout for {self._make_trait()}'s operand A, "
365
+ f"but got composed layout instead: {input.layout}"
366
+ f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
367
+ )
368
+ return True
369
+
370
+ def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
371
+ if input.memspace == AddressSpace.smem and isinstance(
372
+ input.layout.type, _cute_ir.ComposedLayoutType
373
+ ):
374
+ raise OpError(
375
+ self,
376
+ f"Expected affine layout for {self._make_trait()}'s operand B, "
377
+ f"but got composed layout instead: {input.layout}"
378
+ f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
379
+ )
380
+ return True
381
+
382
+
383
+ class BlockScaledMmaTraits(Trait):
384
+ admissible_fields = [
385
+ Field.ACCUMULATE,
386
+ Field.NEGATE_A,
387
+ Field.NEGATE_B,
388
+ Field.SFA,
389
+ Field.SFB,
390
+ ]
391
+
392
+ def set(self, field, value, *, loc=None, ip=None) -> None:
393
+ if field not in self.admissible_fields:
394
+ raise ValueError(
395
+ f"expects field to be one of {self.admissible_fields}, but got {field}"
396
+ )
397
+ if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]:
398
+ value = Boolean(value).ir_value(loc=loc, ip=ip)
399
+ elif field in [Field.SFA, Field.SFB]:
400
+ if not isinstance(value, Pointer):
401
+ raise ValueError(
402
+ f"expects value to be a pointer for {field}, but got {type(value).__name__}"
403
+ )
404
+ value = value.value
405
+
406
+ field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>"
407
+ attr = ir.Attribute.parse(field_name)
408
+ self.value = _cute_nvgpu_ir.atom_set_value(
409
+ self.value, attr, value, loc=loc, ip=ip
410
+ )
411
+
412
+
413
+ #
414
+ # TF32 MMA
415
+ #
416
+
417
+
418
+ @dataclass(frozen=True)
419
+ class MmaTF32Op(MmaOp):
420
+ """
421
+ TF32 tcgen05 MMA Operation.
422
+
423
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
424
+ This Operation corresponds to the ``.kind::tf32`` qualifier.
425
+ """
426
+
427
+ descriptive_name = "tcgen05 TF32 MMA Operation"
428
+
429
+ def __init__(
430
+ self,
431
+ instruction_shape: Shape,
432
+ cta_group: CtaGroup,
433
+ a_src: OperandSource,
434
+ a_major_mode: OperandMajorMode,
435
+ b_major_mode: OperandMajorMode,
436
+ ) -> None:
437
+ super().__init__(
438
+ TFloat32,
439
+ TFloat32,
440
+ Float32,
441
+ instruction_shape,
442
+ cta_group,
443
+ a_src,
444
+ a_major_mode,
445
+ b_major_mode,
446
+ )
447
+ self._verify()
448
+
449
+ def _verify(self) -> None:
450
+ # Verify the instruction shape
451
+ instruction_k = 8
452
+ if rank(self.shape_mnk) == 2:
453
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
454
+ if self.shape_mnk[2] != instruction_k:
455
+ raise OpError(
456
+ self,
457
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
458
+ f"but got {self.shape_mnk[2]}",
459
+ )
460
+
461
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait":
462
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
463
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
464
+ shape_mnk.type.attribute,
465
+ self.cta_group.value,
466
+ self.a_major_mode._to_ir(),
467
+ self.b_major_mode._to_ir(),
468
+ self.a_dtype.mlir_type,
469
+ self.b_dtype.mlir_type,
470
+ self.acc_dtype.mlir_type,
471
+ self.a_src._to_ir(),
472
+ 0,
473
+ )
474
+ return MmaTF32Trait(
475
+ _cute_nvgpu_ir.make_sm100_mma(
476
+ ty,
477
+ Boolean(False).ir_value(loc=loc, ip=ip),
478
+ Boolean(False).ir_value(loc=loc, ip=ip),
479
+ Boolean(False).ir_value(loc=loc, ip=ip),
480
+ loc=loc,
481
+ ip=ip,
482
+ )
483
+ )
484
+
485
+
486
+ class MmaTF32Trait(MmaTrait):
487
+ pass
488
+
489
+
490
+ #
491
+ # F16/BF16 MMA
492
+ #
493
+
494
+
495
+ @dataclass(frozen=True)
496
+ class MmaF16BF16Op(MmaOp):
497
+ """
498
+ F16/BF16 tcgen05 MMA Operation.
499
+
500
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
501
+ This Operation corresponds to the ``.kind::f16`` qualifier.
502
+ """
503
+
504
+ descriptive_name = "tcgen05 F16/BF16 MMA Operation"
505
+
506
+ def __init__(
507
+ self,
508
+ ab_dtype: Type[Numeric],
509
+ acc_dtype: Type[Numeric],
510
+ instruction_shape: Shape,
511
+ cta_group: CtaGroup,
512
+ a_src: OperandSource,
513
+ a_major_mode: OperandMajorMode,
514
+ b_major_mode: OperandMajorMode,
515
+ ) -> None:
516
+ super().__init__(
517
+ ab_dtype,
518
+ ab_dtype,
519
+ acc_dtype,
520
+ instruction_shape,
521
+ cta_group,
522
+ a_src,
523
+ a_major_mode,
524
+ b_major_mode,
525
+ )
526
+ self._verify()
527
+
528
+ def _verify(self) -> None:
529
+ # Input data type verification
530
+ if self.a_dtype not in [Float16, BFloat16]:
531
+ raise OpError(
532
+ self,
533
+ "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
534
+ )
535
+ assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
536
+ # Accumulator data type verification
537
+ if self.acc_dtype not in [Float16, Float32]:
538
+ raise OpError(
539
+ self,
540
+ "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
541
+ )
542
+ # Instruction shape verification
543
+ instruction_k = 16
544
+ if rank(self.shape_mnk) == 2:
545
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
546
+ if self.shape_mnk[2] != instruction_k:
547
+ raise OpError(
548
+ self,
549
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
550
+ f"but got {self.shape_mnk[2]}",
551
+ )
552
+
553
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
554
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
555
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
556
+ shape_mnk.type.attribute,
557
+ self.cta_group.value,
558
+ self.a_major_mode._to_ir(),
559
+ self.b_major_mode._to_ir(),
560
+ self.a_dtype.mlir_type,
561
+ self.b_dtype.mlir_type,
562
+ self.acc_dtype.mlir_type,
563
+ self.a_src._to_ir(),
564
+ 0,
565
+ )
566
+ return MmaF16BF16Trait(
567
+ _cute_nvgpu_ir.make_sm100_mma(
568
+ ty,
569
+ Boolean(False).ir_value(loc=loc, ip=ip),
570
+ Boolean(False).ir_value(loc=loc, ip=ip),
571
+ Boolean(False).ir_value(loc=loc, ip=ip),
572
+ loc=loc,
573
+ ip=ip,
574
+ )
575
+ )
576
+
577
+
578
+ class MmaF16BF16Trait(MmaTrait):
579
+ pass
580
+
581
+
582
+ #
583
+ # I8 MMA
584
+ #
585
+
586
+
587
+ @dataclass(frozen=True)
588
+ class MmaI8Op(MmaOp):
589
+ """
590
+ I8 tcgen05 MMA Operation.
591
+
592
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
593
+ This Operation corresponds to the ``.kind::i8`` qualifier.
594
+ """
595
+
596
+ descriptive_name = "tcgen05 I8 MMA Operation"
597
+
598
+ def __init__(
599
+ self,
600
+ ab_dtype: Type[Numeric],
601
+ instruction_shape: Shape,
602
+ cta_group: CtaGroup,
603
+ a_src: OperandSource,
604
+ a_major_mode: OperandMajorMode,
605
+ b_major_mode: OperandMajorMode,
606
+ ) -> None:
607
+ super().__init__(
608
+ ab_dtype,
609
+ ab_dtype,
610
+ Int32,
611
+ instruction_shape,
612
+ cta_group,
613
+ a_src,
614
+ a_major_mode,
615
+ b_major_mode,
616
+ )
617
+ self._verify()
618
+
619
+ def _verify(self) -> None:
620
+ # Input data type verification
621
+ if self.a_dtype not in [Int8, Uint8]:
622
+ raise OpError(
623
+ self,
624
+ "expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8",
625
+ )
626
+ assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
627
+ # Instruction shape verification
628
+ instruction_k = 32
629
+ if rank(self.shape_mnk) == 2:
630
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
631
+ if self.shape_mnk[2] != instruction_k:
632
+ raise OpError(
633
+ self,
634
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
635
+ f"but got {self.shape_mnk[2]}",
636
+ )
637
+
638
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait":
639
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
640
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
641
+ shape_mnk.type.attribute,
642
+ self.cta_group.value,
643
+ self.a_major_mode._to_ir(),
644
+ self.b_major_mode._to_ir(),
645
+ (T.si8() if self.a_dtype.signed else T.ui8()),
646
+ (T.si8() if self.b_dtype.signed else T.ui8()),
647
+ T.si32(),
648
+ self.a_src._to_ir(),
649
+ 0,
650
+ )
651
+ return MmaI8Trait(
652
+ _cute_nvgpu_ir.make_sm100_mma(
653
+ ty,
654
+ Boolean(False).ir_value(loc=loc, ip=ip),
655
+ Boolean(False).ir_value(loc=loc, ip=ip),
656
+ Boolean(False).ir_value(loc=loc, ip=ip),
657
+ loc=loc,
658
+ ip=ip,
659
+ )
660
+ )
661
+
662
+
663
+ class MmaI8Trait(MmaTrait):
664
+ pass
665
+
666
+
667
+ #
668
+ # F8F6F4 MMA
669
+ #
670
+
671
+
672
+ @dataclass(frozen=True)
673
+ class MmaFP8Op(MmaOp):
674
+ """
675
+ F8 tcgen05 MMA Operation.
676
+
677
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
678
+ """
679
+
680
+ descriptive_name = "tcgen05 F8 MMA Operation"
681
+
682
+ def __init__(
683
+ self,
684
+ ab_dtype: Type[Numeric],
685
+ acc_dtype: Type[Numeric],
686
+ instruction_shape: Shape,
687
+ cta_group: CtaGroup,
688
+ a_src: OperandSource,
689
+ a_major_mode: OperandMajorMode,
690
+ b_major_mode: OperandMajorMode,
691
+ ) -> None:
692
+
693
+ super().__init__(
694
+ ab_dtype,
695
+ ab_dtype,
696
+ acc_dtype,
697
+ instruction_shape,
698
+ cta_group,
699
+ a_src,
700
+ a_major_mode,
701
+ b_major_mode,
702
+ )
703
+ self._verify()
704
+
705
+ def _verify(self) -> None:
706
+ # Input data type verification
707
+ if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
708
+ raise OpError(
709
+ self,
710
+ "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
711
+ )
712
+ assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
713
+ # Accumulator data type verification
714
+ if self.acc_dtype not in [Float16, Float32]:
715
+ raise OpError(
716
+ self,
717
+ "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
718
+ )
719
+ # Instruction shape verification
720
+ instruction_k = 32
721
+ if rank(self.shape_mnk) == 2:
722
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
723
+ if self.shape_mnk[2] != instruction_k:
724
+ raise OpError(
725
+ self,
726
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
727
+ f"but got {self.shape_mnk[2]}",
728
+ )
729
+
730
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait":
731
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
732
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
733
+ shape_mnk.type.attribute,
734
+ self.cta_group.value,
735
+ self.a_major_mode._to_ir(),
736
+ self.b_major_mode._to_ir(),
737
+ self.a_dtype.mlir_type,
738
+ self.b_dtype.mlir_type,
739
+ self.acc_dtype.mlir_type,
740
+ self.a_src._to_ir(),
741
+ 0,
742
+ )
743
+ return MmaFP8Trait(
744
+ _cute_nvgpu_ir.make_sm100_mma(
745
+ ty,
746
+ Boolean(False).ir_value(loc=loc, ip=ip),
747
+ Boolean(False).ir_value(loc=loc, ip=ip),
748
+ Boolean(False).ir_value(loc=loc, ip=ip),
749
+ loc=loc,
750
+ ip=ip,
751
+ )
752
+ )
753
+
754
+
755
+ class MmaFP8Trait(MmaTrait):
756
+ pass
757
+
758
+
759
+ #
760
+ # MXF8F6F4 MMA
761
+ #
762
+
763
+
764
+ @dataclass(frozen=True)
765
+ class MmaMXF8Op(BlockScaledMmaOp):
766
+ """
767
+ MXF8 tcgen05 BlockScaled MMA Operation.
768
+
769
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
770
+ This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier.
771
+ """
772
+
773
+ descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation"
774
+
775
+ def __init__(
776
+ self,
777
+ ab_dtype: Type[Numeric],
778
+ instruction_shape: Shape,
779
+ cta_group: CtaGroup,
780
+ a_src: OperandSource,
781
+ a_major_mode: OperandMajorMode,
782
+ b_major_mode: OperandMajorMode,
783
+ ) -> None:
784
+ super().__init__(
785
+ ab_dtype,
786
+ ab_dtype,
787
+ Float32,
788
+ Float8E8M0FNU,
789
+ 32,
790
+ instruction_shape,
791
+ cta_group,
792
+ a_src,
793
+ a_major_mode,
794
+ b_major_mode,
795
+ )
796
+ self._verify()
797
+
798
+ def _verify(self) -> None:
799
+ # Input data type verification
800
+ if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
801
+ raise OpError(
802
+ self,
803
+ "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
804
+ )
805
+ assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
806
+ # Instruction shape verification
807
+ instruction_k = 32
808
+ if rank(self.shape_mnk) == 2:
809
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
810
+ if self.shape_mnk[2] != instruction_k:
811
+ raise OpError(
812
+ self,
813
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
814
+ f"but got {self.shape_mnk[2]}",
815
+ )
816
+
817
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait":
818
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
819
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
820
+ shape_mnk.type.attribute,
821
+ self.cta_group.value,
822
+ self.a_major_mode._to_ir(),
823
+ self.b_major_mode._to_ir(),
824
+ self.a_dtype.mlir_type,
825
+ self.b_dtype.mlir_type,
826
+ self.acc_dtype.mlir_type,
827
+ self.sf_dtype.mlir_type,
828
+ self.a_src._to_ir(),
829
+ self.sf_vec_size,
830
+ )
831
+ return MmaMXF8Trait(
832
+ _cute_nvgpu_ir.make_sm100_mma_bs(
833
+ ty,
834
+ Boolean(False).ir_value(loc=loc, ip=ip),
835
+ Boolean(False).ir_value(loc=loc, ip=ip),
836
+ Boolean(False).ir_value(loc=loc, ip=ip),
837
+ core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
838
+ core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
839
+ loc=loc,
840
+ ip=ip,
841
+ )
842
+ )
843
+
844
+
845
+ class MmaMXF8Trait(BlockScaledMmaTraits):
846
+ pass
847
+
848
+
849
+ #
850
+ # MXF4 MMA
851
+ #
852
+
853
+
854
+ @dataclass(frozen=True)
855
+ class MmaMXF4Op(BlockScaledMmaOp):
856
+ """
857
+ MXF4 tcgen05 BlockScaled MMA Operation.
858
+
859
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
860
+ This Operation corresponds to the ``.kind::mxf4`` qualifier.
861
+ """
862
+
863
+ descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation"
864
+
865
+ def __init__(
866
+ self,
867
+ instruction_shape: Shape,
868
+ cta_group: CtaGroup,
869
+ a_src: OperandSource,
870
+ ) -> None:
871
+ super().__init__(
872
+ Float4E2M1FN,
873
+ Float4E2M1FN,
874
+ Float32,
875
+ Float8E8M0FNU,
876
+ 32,
877
+ instruction_shape,
878
+ cta_group,
879
+ a_src,
880
+ OperandMajorMode.K,
881
+ OperandMajorMode.K,
882
+ )
883
+ self._verify()
884
+
885
+ def _verify(self) -> None:
886
+ # Instruction shape verification
887
+ instruction_k = 64
888
+ if rank(self.shape_mnk) == 2:
889
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
890
+ if self.shape_mnk[2] != instruction_k:
891
+ raise OpError(
892
+ self,
893
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
894
+ f"but got {self.shape_mnk[2]}",
895
+ )
896
+
897
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait":
898
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
899
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
900
+ shape_mnk.type.attribute,
901
+ self.cta_group.value,
902
+ self.a_major_mode._to_ir(),
903
+ self.b_major_mode._to_ir(),
904
+ self.a_dtype.mlir_type,
905
+ self.b_dtype.mlir_type,
906
+ self.acc_dtype.mlir_type,
907
+ self.sf_dtype.mlir_type,
908
+ self.a_src._to_ir(),
909
+ self.sf_vec_size,
910
+ )
911
+ return MmaMXF4Trait(
912
+ _cute_nvgpu_ir.make_sm100_mma_bs(
913
+ ty,
914
+ Boolean(False).ir_value(loc=loc, ip=ip),
915
+ Boolean(False).ir_value(loc=loc, ip=ip),
916
+ Boolean(False).ir_value(loc=loc, ip=ip),
917
+ core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
918
+ core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
919
+ loc=loc,
920
+ ip=ip,
921
+ )
922
+ )
923
+
924
+
925
+ class MmaMXF4Trait(BlockScaledMmaTraits):
926
+ pass
927
+
928
+
929
+ #
930
+ # MXF4NVF4 MMA
931
+ #
932
+
933
+
934
+ @dataclass(frozen=True)
935
+ class MmaMXF4NVF4Op(BlockScaledMmaOp):
936
+ """
937
+ MXF4NVF4 tcgen05 BlockScaled MMA Operation.
938
+
939
+ See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
940
+ This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier.
941
+ """
942
+
943
+ descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation"
944
+
945
+ def __init__(
946
+ self,
947
+ sf_dtype: Type[Numeric],
948
+ instruction_shape: Shape,
949
+ cta_group: CtaGroup,
950
+ a_src: OperandSource,
951
+ ) -> None:
952
+ super().__init__(
953
+ Float4E2M1FN,
954
+ Float4E2M1FN,
955
+ Float32,
956
+ sf_dtype,
957
+ 16,
958
+ instruction_shape,
959
+ cta_group,
960
+ a_src,
961
+ OperandMajorMode.K,
962
+ OperandMajorMode.K,
963
+ )
964
+ self._verify()
965
+
966
+ def _verify(self) -> None:
967
+ # Scale Factor data type verification
968
+ if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]:
969
+ raise OpError(
970
+ self,
971
+ "expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU",
972
+ )
973
+ # Instruction shape verification
974
+ instruction_k = 64
975
+ if rank(self.shape_mnk) == 2:
976
+ object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
977
+ if self.shape_mnk[2] != instruction_k:
978
+ raise OpError(
979
+ self,
980
+ f"expects the instruction extent in the K-mode to be {instruction_k}, "
981
+ f"but got {self.shape_mnk[2]}",
982
+ )
983
+
984
+ def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait":
985
+ shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
986
+ ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
987
+ shape_mnk.type.attribute,
988
+ self.cta_group.value,
989
+ self.a_major_mode._to_ir(),
990
+ self.b_major_mode._to_ir(),
991
+ self.a_dtype.mlir_type,
992
+ self.b_dtype.mlir_type,
993
+ self.acc_dtype.mlir_type,
994
+ self.sf_dtype.mlir_type,
995
+ self.a_src._to_ir(),
996
+ self.sf_vec_size,
997
+ )
998
+ return MmaMXF4NVF4Trait(
999
+ _cute_nvgpu_ir.make_sm100_mma_bs(
1000
+ ty,
1001
+ Boolean(False).ir_value(loc=loc, ip=ip),
1002
+ Boolean(False).ir_value(loc=loc, ip=ip),
1003
+ Boolean(False).ir_value(loc=loc, ip=ip),
1004
+ core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
1005
+ core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
1006
+ loc=loc,
1007
+ ip=ip,
1008
+ )
1009
+ )
1010
+
1011
+
1012
+ class MmaMXF4NVF4Trait(BlockScaledMmaTraits):
1013
+ pass
1014
+
1015
+ ####################################################################################################
1016
+ #
1017
+ # SMEM layout atoms
1018
+ #
1019
+ ####################################################################################################
1020
+
1021
+
1022
+ class SmemLayoutAtomKind(enum.Enum):
1023
+ """
1024
+ Enum class for the kinds of SMEM layout atoms for SM100.
1025
+
1026
+ Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be
1027
+ used to construct an SMEM layout using blocked product for operand A or B such that the
1028
+ resulting layout is legal for both TMA and UMMA.
1029
+
1030
+ Note that there are other ways of creating legal layouts for operand A and B.
1031
+ """
1032
+
1033
+ MN_INTER = enum.auto()
1034
+ MN_SW32 = enum.auto()
1035
+ MN_SW64 = enum.auto()
1036
+ MN_SW128 = enum.auto()
1037
+ MN_SW128_32B = enum.auto()
1038
+ K_INTER = enum.auto()
1039
+ K_SW32 = enum.auto()
1040
+ K_SW64 = enum.auto()
1041
+ K_SW128 = enum.auto()
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
# Re-export the warp-level copy and MMA Ops as this package's public API.
from .copy import *
from .mma import *


# __all__ is required here for documentation generation
__all__ = [
    # mma.py
    "MmaF16BF16Op",
    # copy.py
    "LdMatrix8x8x16bOp",
    "LdMatrix16x16x8bOp",
    "StMatrix8x8x16bOp",
    "StMatrix16x8x8bOp",
]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Type
14
+
15
+ import cutlass._mlir.dialects.cute as _cute_ir
16
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
17
+ from cutlass._mlir import ir
18
+
19
+ from ..common import OpError
20
+ from ...core import CopyOp, Trait, _pack_shape
21
+ from ...typing import Numeric
22
+
23
+
24
@dataclass(frozen=True)
class BaseOp(CopyOp):
    """
    Shared base for the warp-wide ``ldmatrix``/``stmatrix`` copy Operations.

    :param transpose:    whether the matrix-transposing instruction variant is used
    :param num_matrices: number of matrices moved per instruction
    """

    transpose: bool = False
    num_matrices: int = 1

    def __post_init__(self) -> None:
        # Only 'transpose' is validated here; subclasses constrain 'num_matrices'.
        if not isinstance(self.transpose, bool):
            raise OpError(
                self,
                "expects the 'transpose' Op parameter to be a bool instance",
            )

    def __str__(self) -> str:
        # Strip the trailing "Op" from the class name for display purposes.
        description = f"{self.__class__.__name__[:-2]} Copy Operation"
        description += f"\n number of matrices = {self.num_matrices}"
        if self.transpose:
            description += f"\n transposed"
        return description
44
+
45
+
46
@dataclass(frozen=True)
class LdMatrix8x8x16bOp(BaseOp):
    """
    8x8 ``ldmatrix`` Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
    This operation corresponds to the ``.m8n8`` qualifier.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.num_matrices not in (1, 2, 4):
            raise OpError(
                self,
                "expects the 'num_matrices' Op parameter to be one of [1,2,4]",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "LdMatrix8x8x16bTrait":
        # The instruction operates on 8x8 tiles with a 16b element pattern.
        tile = _pack_shape((8, 8), loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
            copy_internal_type.mlir_type,
            tile.type.attribute,
            _cute_nvgpu_ir.LdsmSzPattern.u16,
            self.num_matrices,
            # The transposed variant is encoded as an optional unit attribute.
            ir.UnitAttr.get() if self.transpose else None,
        )
        return LdMatrix8x8x16bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
75
+
76
+
77
class LdMatrix8x8x16bTrait(Trait):
    """Trait produced by :class:`LdMatrix8x8x16bOp`; behavior inherited from Trait."""

    pass
79
+
80
+
81
@dataclass(frozen=True)
class LdMatrix16x16x8bOp(BaseOp):
    """
    16x16 8-bit ``ldmatrix`` Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
    This operation corresponds to the ``.m16n16`` and the ``.b16`` qualifiers.
    """

    def __init__(self, num_matrices: int) -> None:
        # This variant only exists in transposed form, so 'transpose' is
        # forced to True rather than exposed to the caller.
        super().__init__(transpose=True, num_matrices=num_matrices)
        self._verify()

    def _verify(self):
        assert self.transpose, "transpose must be True"
        if self.num_matrices in (1, 2):
            return
        raise OpError(
            self,
            "expects the 'num_matrices' Op parameter to be one of [1,2]",
        )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "LdMatrix16x16x8bTrait":
        # 16x16 tiles with an 8b element pattern; always transposed.
        tile = _pack_shape((16, 16), loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
            copy_internal_type.mlir_type,
            tile.type.attribute,
            _cute_nvgpu_ir.LdsmSzPattern.u8,
            self.num_matrices,
            ir.UnitAttr.get(),
        )
        return LdMatrix16x16x8bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
114
+
115
+
116
class LdMatrix16x16x8bTrait(Trait):
    """Trait produced by :class:`LdMatrix16x16x8bOp`; behavior inherited from Trait."""

    pass
118
+
119
+
120
@dataclass(frozen=True)
class StMatrix8x8x16bOp(BaseOp):
    """
    8x8 ``stmatrix`` Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
    This operation corresponds to the ``m8n8`` qualifier.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.num_matrices not in (1, 2, 4):
            raise OpError(
                self,
                "expects the 'num_matrices' Op parameter to be one of [1,2,4]",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "StMatrix8x8x16bTrait":
        # stmatrix stores 8x8 tiles.
        tile = _pack_shape((8, 8), loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
            copy_internal_type.mlir_type,
            tile.type.attribute,
            self.num_matrices,
            # The transposed variant is encoded as an optional unit attribute.
            ir.UnitAttr.get() if self.transpose else None,
        )
        return StMatrix8x8x16bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
148
+
149
+
150
class StMatrix8x8x16bTrait(Trait):
    """Trait produced by :class:`StMatrix8x8x16bOp`; behavior inherited from Trait."""

    pass
152
+
153
+
154
@dataclass(frozen=True)
class StMatrix16x8x8bOp(BaseOp):
    """
    16x8 ``stmatrix`` Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
    This operation corresponds to the ``m16n8`` qualifier.
    """

    def __init__(self, num_matrices: int) -> None:
        # The m16n8 variant only exists in transposed form, so 'transpose' is
        # forced to True rather than exposed to the caller.
        super().__init__(transpose=True, num_matrices=num_matrices)
        self._verify()

    def _verify(self):
        # BUGFIX: the transpose invariant must be checked unconditionally, not
        # only when 'num_matrices' is invalid (mirrors LdMatrix16x16x8bOp._verify,
        # where the assert precedes the num_matrices check).
        assert self.transpose, "transpose must be True"
        if self.num_matrices not in [1, 2, 4]:
            raise OpError(
                self,
                "expects the 'num_matrices' Op parameter to be one of [1,2,4]",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "StMatrix16x8x8bTrait":
        mode = _pack_shape((16, 8), loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
            copy_internal_type.mlir_type,
            mode.type.attribute,
            self.num_matrices,
            # Always the transposed variant for m16n8.
            ir.UnitAttr.get(),
        )
        return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
186
+
187
+
188
class StMatrix16x8x8bTrait(Trait):
    """Trait produced by :class:`StMatrix16x8x8bOp`; behavior inherited from Trait."""

    pass
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from dataclasses import dataclass
13
+ from typing import Type
14
+
15
+ import cutlass._mlir.dialects.cute as _cute_ir
16
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
17
+
18
+ from ..common import OpError
19
+ from ...core import MmaOp, Trait, _pack_shape, _Tensor
20
+ from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace
21
+
22
+
23
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
    """
    F16/BF16 warp-level MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
    This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
    """

    # A/B operand data type (Float16 or BFloat16; A and B share one dtype)
    ab_dtype: Type[Numeric]
    # Accumulator data type (Float16 or Float32)
    acc_dtype: Type[Numeric]
    # Instruction shape as an (M, N, K) tuple
    shape_mnk: Shape

    def __post_init__(self) -> None:
        # Validate dtypes and shape at construction time so invalid Ops fail early.
        if self.ab_dtype not in [Float16, BFloat16]:
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
            )
        if self.acc_dtype not in [Float16, Float32]:
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        # BF16 inputs only support F32 accumulation.
        if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32):
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
            )
        if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]:
            raise OpError(
                self,
                "expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
        # Lower to an SM80 MMA atom with the validated shape and dtypes
        # (ab_dtype is used for both the A and B operand types).
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM80Type.get(
            shape_mnk.type.attribute,
            self.ab_dtype.mlir_type,
            self.ab_dtype.mlir_type,
            self.acc_dtype.mlir_type,
        )
        return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip))

    def __str__(self) -> str:
        return (
            "warp-level F16/BF16 MMA Operation"
            + f"\n A/B data type = {self.ab_dtype}"
            + f"\n Accumulator data type = {self.acc_dtype}"
            + f"\n Instruction shape MNK = {self.shape_mnk}"
        )

    def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
        # No additional constraints on the A fragment for this Op.
        pass

    def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
        # No additional constraints on the B fragment for this Op.
        pass
81
+
82
class MmaF16BF16Trait(Trait):
    """Trait produced by the warp-level :class:`MmaF16BF16Op`; behavior inherited from Trait."""

    pass
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
# Re-export the warpgroup (SM90) MMA Ops and SMEM-layout/wgmma helpers.
from .mma import *
from .helpers import *

# __all__ is required here for documentation generation
__all__ = [
    # mma.py
    "OperandMajorMode",
    "OperandSource",
    "Field",
    "MmaF16BF16Op",
    "MmaF8Op",
    "SmemLayoutAtomKind",
    # helpers.py
    "make_smem_layout_atom",
    "fence",
    "commit_group",
    "wait_group",
]
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from typing import Type
13
+
14
+ from cutlass.cutlass_dsl import dsl_user_op
15
+
16
+ from cutlass._mlir.dialects import nvvm
17
+
18
+ from ...typing import Numeric, NumericMeta
19
+ from ... import core
20
+ from .mma import SmemLayoutAtomKind
21
+
22
+
23
@dsl_user_op
def make_smem_layout_atom(
    kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
) -> core.ComposedLayout:
    """
    Makes a SMEM layout Atom.

    The result is a composed layout, expressed in units of elements, that is
    consistent with the requested layout Atom kind and element data type.

    :param kind: The kind of layout Atom
    :type kind: SmemLayoutAtomKind
    :param element_type: The element data type to construct the layout for
    :type element_type: Type[Numeric]
    :return: The SMEM layout atom
    :rtype: core.ComposedLayout
    """
    if not isinstance(element_type, NumericMeta):
        raise TypeError(f"element_type must be a Numeric, but got {element_type}")

    # Map each kind to its swizzle exponent B; the contiguous swizzled region
    # covers 128 * 2^B bits (INTER = no swizzle, SW32/64/128 = 32B/64B/128B).
    swizzle_exponent = {
        SmemLayoutAtomKind.MN_INTER: 0,
        SmemLayoutAtomKind.K_INTER: 0,
        SmemLayoutAtomKind.MN_SW32: 1,
        SmemLayoutAtomKind.K_SW32: 1,
        SmemLayoutAtomKind.MN_SW64: 2,
        SmemLayoutAtomKind.K_SW64: 2,
        SmemLayoutAtomKind.MN_SW128: 3,
        SmemLayoutAtomKind.K_SW128: 3,
    }
    if kind not in swizzle_exponent:
        raise ValueError("unrecognized SMEM layout atom kind")
    b = swizzle_exponent[kind]
    sw = core.make_swizzle(b, 4, 3)
    num_contiguous_elems = (128 << b) // element_type.width

    # MN_* kinds are M/N-major; K_* kinds are K-major. The atom is always
    # (contiguous extent) x 8 with the contiguous mode carrying stride 1.
    mn_major = kind in (
        SmemLayoutAtomKind.MN_INTER,
        SmemLayoutAtomKind.MN_SW32,
        SmemLayoutAtomKind.MN_SW64,
        SmemLayoutAtomKind.MN_SW128,
    )
    if mn_major:
        shape = (num_contiguous_elems, 8)
        stride = (1, num_contiguous_elems)
    else:
        shape = (8, num_contiguous_elems)
        stride = (num_contiguous_elems, 1)
    return core.make_composed_layout(
        sw,
        0,
        core.make_layout(shape, stride=stride),
        loc=loc,
        ip=ip,
    )
86
+
87
+
88
@dsl_user_op
def fence(*, loc=None, ip=None) -> None:
    """
    Issues a ``wgmma.fence`` for the calling warpgroup.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-fence>`__.
    """
    # BUGFIX: forward the caller's loc/ip instead of passing literal None
    # (previously `loc=None, ip=None`, which dropped source-location info).
    nvvm.wgmma_fence_aligned(loc=loc, ip=ip)
94
+
95
+
96
@dsl_user_op
def commit_group(*, loc=None, ip=None) -> None:
    """
    Commits the outstanding wgmma operations into a wgmma-group (``wgmma.commit_group``).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group>`__.
    """
    nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip)
102
+
103
+
104
@dsl_user_op
def wait_group(group, *, loc=None, ip=None) -> None:
    """
    Waits on wgmma-groups (``wgmma.wait_group``).

    :param group: value forwarded to the NVVM op; per the PTX doc this is the
                  number of most recent wgmma-groups allowed to remain pending.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-wait-group>`__.
    """
    nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip)
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import enum
13
+ from dataclasses import dataclass
14
+ from typing import Type
15
+
16
+ from cutlass.cutlass_dsl import CuTeDSL
17
+
18
+ import cutlass._mlir.dialects.cute as _cute_ir
19
+ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
20
+ from cutlass._mlir import ir
21
+
22
+ from ..common import OpError
23
+ from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
24
+ from ...typing import (
25
+ Shape,
26
+ Float16,
27
+ BFloat16,
28
+ Float32,
29
+ Boolean,
30
+ Float8E5M2,
31
+ Float8E4M3FN,
32
+ Numeric,
33
+ AddressSpace,
34
+ )
35
+
36
+
37
+ ####################################################################################################
38
+ #
39
+ # MMA Ops and Traits
40
+ #
41
+ ####################################################################################################
42
+
43
+
44
class OperandMajorMode(enum.Enum):
    """
    An enumeration for the majorness of the input operands of the MMA.

    Members wrap the corresponding CuTe IR MajorMode values; construction from
    the strings "MN"/"K" (case-insensitive) is supported via ``_missing_``.
    """

    MN = _cute_ir.MajorMode.mn
    K = _cute_ir.MajorMode.k

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    @classmethod
    def _missing_(cls, value):
        # Allow OperandMajorMode("mn") / OperandMajorMode("K") style lookups.
        if isinstance(value, str):
            normalized = value.upper()
            if normalized == "MN":
                return cls.MN
            if normalized == "K":
                return cls.K
        # Anything else: signal "not found" to the enum machinery.
        return None

    def _to_ir(self) -> _cute_ir.MajorMode:
        return self.value
69
+
70
+
71
class OperandSource(enum.Enum):
    """
    An enumeration for the source memory location of the A input operand of the MMA.
    """

    # A supplied as a register fragment (IR: rmem)
    RMEM = _cute_ir.MmaFragKind.rmem
    # A supplied via an SMEM descriptor (IR: smem_desc)
    SMEM = _cute_ir.MmaFragKind.smem_desc

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MmaFragKind:
        # Expose the wrapped IR fragment kind.
        return self.value
87
+
88
+
89
class Field(enum.Enum):
    """
    An enumeration for the fields of the MMA Atom that can be modified at runtime.
    """

    # IR-level field name used when building the atom_mma_field attribute
    ACCUMULATE = "accum_c"

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir_field_name(self) -> str:
        # The enum value doubles as the IR attribute field name.
        return self.value
104
+
105
+
106
@dataclass(frozen=True)
class MmaOp(MmaOp):
    """
    Base class for warpgroup (SM90) MMA Operations.

    NOTE: this class deliberately shadows the imported ``MmaOp`` base from
    ``...core`` — subclasses in this module derive from this specialization.
    """

    # A operand data type
    a_dtype: Type[Numeric]
    # B operand data type
    b_dtype: Type[Numeric]
    # Accumulator data type
    acc_dtype: Type[Numeric]
    # Instruction shape, rank-2 (M, N) or rank-3 (M, N, K)
    shape_mnk: Shape
    # Memory location of the A operand (RMEM or SMEM)
    a_src: OperandSource
    # Majorness of A
    a_major_mode: OperandMajorMode
    # Majorness of B
    b_major_mode: OperandMajorMode

    # Only sm_90a supports these wgmma Ops
    admissible_archs = ["sm_90a"]

    def __post_init__(self) -> None:
        # Verify arch
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # Verify that the user provided enum values
        if not isinstance(self.a_src, OperandSource):
            raise OpError(
                self,
                "expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance",
            )
        if not isinstance(self.a_major_mode, OperandMajorMode):
            raise OpError(
                self,
                "expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
            )
        if not isinstance(self.b_major_mode, OperandMajorMode):
            raise OpError(
                self,
                "expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
            )
        # Verify instruction shape: flat (depth-1) tuple of rank 2 or 3
        if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
            raise OpError(
                self,
                f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
                f"but got {self.shape_mnk}",
            )
        # M is fixed at 64; N must be a multiple of 8 in [8, 256]
        m, n = self.shape_mnk[0], self.shape_mnk[1]
        if m != 64:
            raise OpError(self, f"expects the M-mode to be 64, but got {m}")
        if (n < 8) or (n > 256) or (n % 8 != 0):
            raise OpError(
                self,
                f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. but got {n}",
            )

    def __str__(self) -> str:
        return (
            self.__class__.descriptive_name  # type: ignore
            + f"\n A data type = {self.a_dtype}"
            + f"\n B data type = {self.b_dtype}"
            + f"\n Accumulator data type = {self.acc_dtype}"
            + f"\n A source location = {self.a_src}"
            + f"\n A major mode = {self.a_major_mode}"
            + f"\n B major mode = {self.b_major_mode}"
            + f"\n Instruction shape MNK = {self.shape_mnk}"
        )

    def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
        # Composed (swizzled) SMEM layouts are rejected: the swizzle must be
        # moved onto the pointer via recast_ptr, as the error message explains.
        if input.memspace == AddressSpace.smem and isinstance(
            input.layout.type, _cute_ir.ComposedLayoutType
        ):
            raise OpError(
                self,
                f"Expected affine layout for {self._make_trait()}'s operand A, "
                f"but got composed layout instead: {input.layout}"
                f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
            )
        return True

    def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
        # Same constraint as for operand A.
        if input.memspace == AddressSpace.smem and isinstance(
            input.layout.type, _cute_ir.ComposedLayoutType
        ):
            raise OpError(
                self,
                f"Expected affine layout for {self._make_trait()}'s operand B, "
                f"but got composed layout instead: {input.layout}"
                f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
            )
        return True
194
+
195
+
196
class MmaTrait(Trait):
    """Runtime-modifiable SM90 MMA atom trait; only ACCUMULATE may be set."""

    admissible_fields = [Field.ACCUMULATE]

    def set(self, field, value, *, loc=None, ip=None) -> None:
        # Reject anything other than the accumulate flag.
        if field not in self.admissible_fields:
            raise ValueError(
                f"invalid field, must be {Field.ACCUMULATE}, but got {field}"
            )
        attr = ir.Attribute.parse(
            f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>"
        )
        self.value = _cute_nvgpu_ir.atom_set_value(
            self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
        )
210
+
211
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
    """
    F16/BF16 warpgroup MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
    This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
    """

    descriptive_name = "warpgroup F16/BF16 MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        instruction_shape: Shape,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # A and B share one dtype for this Op, so ab_dtype is passed for both.
        super().__init__(
            ab_dtype,
            ab_dtype,
            acc_dtype,
            instruction_shape,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        # Input data type verification
        if self.a_dtype not in [Float16, BFloat16]:
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
            )
        assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
        # Accumulator data type verification
        if self.acc_dtype not in [Float16, Float32]:
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32):
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
            )
        # Verify the instruction shape; K is fixed at 16 for F16/BF16 wgmma.
        instruction_k = 16
        if rank(self.shape_mnk) == 2:
            # Frozen dataclass: append the default K extent via object.__setattr__.
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
        if self.shape_mnk[2] != instruction_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {instruction_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
        # Lower to an SM90 MMA atom; the trailing Boolean(False) initializes
        # the runtime-settable accumulate flag (see MmaTrait.set).
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
            shape_mnk.type.attribute,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
        )
        return MmaF16BF16Trait(
            _cute_nvgpu_ir.make_sm90_mma(
                ty,
                Boolean(False).ir_value(loc=loc, ip=ip),
                loc=loc,
                ip=ip,
            )
        )
291
+
292
+
293
class MmaF16BF16Trait(MmaTrait):
    """Trait produced by the warpgroup :class:`MmaF16BF16Op`; behavior inherited from MmaTrait."""

    pass
295
+
296
+
297
@dataclass(frozen=True)
class MmaF8Op(MmaOp):
    """
    F8 warpgroup MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
    This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.
    """

    descriptive_name = "warpgroup F8 MMA Operation"

    def __init__(
        self,
        a_dtype: Type[Numeric],
        b_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        instruction_shape: Shape,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # Unlike the F16/BF16 Op, A and B may use different F8 dtypes.
        super().__init__(
            a_dtype,
            b_dtype,
            acc_dtype,
            instruction_shape,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self):
        # Input data type verification
        if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
            raise OpError(
                self,
                "expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        if self.b_dtype not in [Float8E5M2, Float8E4M3FN]:
            raise OpError(
                self,
                "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        # Accumulator data type verification
        if self.acc_dtype not in [Float16, Float32]:
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        # Verify the instruction shape; K is fixed at 32 for F8 wgmma.
        instruction_k = 32
        if rank(self.shape_mnk) == 2:
            # Frozen dataclass: append the default K extent via object.__setattr__.
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
        if self.shape_mnk[2] != instruction_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {instruction_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait":
        # Lower to an SM90 MMA atom; Boolean(False) initializes the
        # runtime-settable accumulate flag (see MmaTrait.set).
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
            shape_mnk.type.attribute,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
        )
        return MmaF8Trait(
            _cute_nvgpu_ir.make_sm90_mma(
                ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
            )
        )
374
+
375
+
376
class MmaF8Trait(MmaTrait):
    """Trait produced by :class:`MmaF8Op`; behavior inherited from MmaTrait."""

    pass
378
+
379
+
380
+ ####################################################################################################
381
+ #
382
+ # SMEM layout atoms
383
+ #
384
+ ####################################################################################################
385
+
386
+
387
class SmemLayoutAtomKind(enum.Enum):
    """
    Enum class for the kinds of SMEM layout atoms for SM90.

    Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can
    be used to construct an SMEM layout using blocked product for operand A or B such that the
    resulting layout is legal for both TMA and UMMA.

    Note that there are other ways of creating legal layouts for operand A and B.
    """

    # M/N-major kinds: no swizzle, then 32B/64B/128B swizzles
    # (see make_smem_layout_atom in helpers.py for the mapping).
    MN_INTER = enum.auto()
    MN_SW32 = enum.auto()
    MN_SW64 = enum.auto()
    MN_SW128 = enum.auto()
    # K-major kinds: no swizzle, then 32B/64B/128B swizzles.
    K_INTER = enum.auto()
    K_SW32 = enum.auto()
    K_SW64 = enum.auto()
    K_SW128 = enum.auto()
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import ctypes
13
+ from functools import lru_cache
14
+ import itertools
15
+ import operator
16
+ from time import time
17
+ from typing import Union
18
+
19
+ # MLIR modules imports
20
+ from cutlass._mlir import ir
21
+ import cutlass._mlir.dialects.cute as _cute_ir
22
+
23
+ from cutlass.base_dsl.dsl import is_dynamic_expression
24
+ from cutlass.cutlass_dsl import JitArgAdapterRegistry
25
+
26
+ # Local modules imports
27
+ from .typing import (
28
+ AddressSpace,
29
+ Tensor,
30
+ Type,
31
+ Pointer,
32
+ Boolean,
33
+ Numeric,
34
+ Float4E2M1FN,
35
+ Int64,
36
+ Int32,
37
+ Int16,
38
+ Int8,
39
+ Uint64,
40
+ Uint32,
41
+ Uint16,
42
+ Uint8,
43
+ Float64,
44
+ Float32,
45
+ Float16,
46
+ BFloat16,
47
+ Float8E5M2,
48
+ )
49
+ from . import core
50
+ from .core import _Tensor as CoreTensor
51
+
52
+
53
class _Pointer(Pointer):
    """Runtime representation of a pointer that can inter-operate with various data structures,
    including numpy arrays and device memory.

    :param pointer: The pointer to the data
    :type pointer: int or pointer-like object
    :param dtype: Data type of the elements pointed to
    :type dtype: Type
    :param mem_space: Memory space where the pointer resides, defaults to generic
    :type mem_space: _cute_ir.AddressSpace, optional
    :param assumed_align: Assumed alignment of input pointer in bytes, defaults to None
    :type assumed_align: int, optional

    :ivar _pointer: The underlying pointer
    :ivar _dtype: Data type of the elements
    :ivar _addr_space: Memory space of the pointer
    :ivar _assumed_align: Alignment of the pointer in bytes
    :ivar _desc: C-type descriptor for the pointer
    :ivar _c_pointer: C-compatible pointer representation
    """

    def __init__(
        self,
        pointer,
        dtype,
        mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
        assumed_align=None,
    ):
        self._pointer = pointer
        self._dtype = dtype
        self._addr_space = mem_space

        # Default alignment is the natural (element-size) alignment in bytes.
        if assumed_align is None:
            self._assumed_align = dtype.width // 8
        else:
            self._assumed_align = assumed_align

        # C descriptor is built lazily in __c_pointers__().
        self._c_pointer = None
        # NOTE(review): `assert` is stripped under `python -O`; alignment would
        # then go unchecked — consider raising ValueError instead.
        assert (
            int(self._pointer) % self._assumed_align == 0
        ), f"pointer must be {self._assumed_align} bytes aligned"

    def size_in_bytes(self) -> int:
        """Return the size in bytes of the C descriptor (a ``void*``).

        Side effect: (re)builds ``self._desc`` from the current address.
        """
        self._desc = ctypes.c_void_p(int(self._pointer))
        return ctypes.sizeof(self._desc)

    def __get_mlir_types__(self):
        # Single MLIR type: the cute pointer type for this dtype/space/align.
        return [self.mlir_type]

    def __c_pointers__(self):
        # Lazily build the ctypes descriptor and cache the address of it;
        # the returned value is the address of the void* slot, not the data.
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        assert len(values) == 1
        return values[0]

    def __extract_mlir_values__(self):
        # NOTE(review): returns None inside the list until __c_pointers__()
        # has been called at least once — confirm callers rely on that order.
        return [self._c_pointer]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        """Element type pointed to."""
        return self._dtype

    @property
    def memspace(self):
        """Address space of the pointer."""
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        # Alignment refinement only makes sense on IR values, not at runtime.
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        """Return True iff ``expected_py_type`` denotes a Pointer."""
        if expected_py_type is Pointer:
            return True
        elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer:
            return True

        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()
146
+
147
+
148
class _Tensor(Tensor):
    """Runtime tensor backed by DLPack data.

    Wraps an object obtained via the DLPack protocol and lazily builds a
    ``DLTensorWrapper`` the first time a wrapper-backed member is accessed,
    keeping the construction cost out of the critical JIT-call path.

    :param tensor: DLPack-compatible object (or already-exported DLPack data)
    :param assumed_align: Assumed alignment in bytes, defaults to None
    """

    def __init__(
        self,
        tensor,
        assumed_align=None,
    ):
        # If tensor is already a DLPack object, use it directly
        # NOTE(review): the check keeps objects with __dlpack_device__ but
        # without __dlpack__; a bare DLPack capsule has neither attribute and
        # would hit the else branch — confirm the intended input set.
        if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
            self._dlpack_data = tensor
        else:
            self._dlpack_data = tensor.__dlpack__()
        self._dltensor_wrapper = None  # built lazily, see lazily_load_dltensor
        self._assumed_align = assumed_align
        self._is_dynamic = False
        self._memref_desc = None  # built in __c_pointers__
        self._dtype = None  # cached/overridden element type

    @property
    def __class__(self) -> Type[Tensor]:
        # Cheat to let `type(_Tensor())` to return cute.Tensor
        return Tensor

    @staticmethod
    def lazily_load_dltensor(func):
        """Decorator to lazily load the DLTensorWrapper.

        This decorator loads the DLTensorWrapper when needed,
        avoiding overhead in the critical path of calling JIT functions.
        """

        def wrapper(self, *args, **kwargs):
            if self._dltensor_wrapper is None:
                self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data)
            return func(self, *args, **kwargs)

        return wrapper

    @lazily_load_dltensor
    def mark_layout_dynamic(self, leading_dim: int | None = None):
        """Marks the tensor layout as dynamic based on the leading dimension.

        :param leading_dim: The leading dimension of the layout, defaults to None
        :type leading_dim: int, optional

        When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
        The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
        if the layout cannot be automatically deduced.

        When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
        stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
        with the existing layout by checking that the corresponding stride of that dimension is 1.

        Limitation: only support flat layout for now. Will work on supporting nested layout in the future.

        :return: The tensor with dynamic layout
        :rtype: _Tensor
        """
        self._dltensor_wrapper.mark_layout_dynamic(leading_dim)
        return self

    @lazily_load_dltensor
    def mark_compact_shape_dynamic(
        self,
        mode: int,
        stride_order: tuple[int, ...] | None = None,
        divisibility: int = 1,
    ):
        """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides.

        :param mode: The mode of the compact shape, defaults to 0
        :type mode: int
        :param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None.
            Indicates the order of the modes (dimensions) if the current layout were converted to row-major order.
            It starts from the outermost to the innermost dimension.
        :type stride_order: tuple[int, ...], optional
        :param divisibility: The divisibility constraint for the compact shape, defaults to 1
        :type divisibility: int, optional
        :return: The tensor with dynamic compact shape
        :rtype: _Tensor

        If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout.
        Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout).
        An error is raised if automatic deduction fails.

        If ``stride_order`` is explicitly specified, it does the consistency check with the layout.

        For example:
        - Layout: (4,2):(1,4) has stride_order: (1,0) indicates the innermost dimension is 0(`4:1`), the outermost dimension is 1(`2:4`)
        - Layout: (5,3,2,4):(3,1,15,30) has stride_order: (3,2,0,1) indicates the innermost dimension is 1(`3:1`), the outermost dimension is 3(`4:30`).

        Using `torch.Tensor.dim_order()` to get the stride order of the torch tensor.
        .. code-block:: python
            a = torch.empty(3, 4)
            t = cute.runtime.from_dlpack(a)
            t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order())
        """
        self._dltensor_wrapper.mark_compact_shape_dynamic(
            mode, stride_order, divisibility
        )
        return self

    @property
    @lazily_load_dltensor
    def element_type(self) -> Type[Numeric]:
        """Element type, read from the DLTensor on first access and cached."""
        if self._dtype is None:
            self._dtype = self._dltensor_wrapper.dtype
        return self._dtype

    @element_type.setter
    def element_type(self, new_type):
        """Set the element type of the tensor.

        :warning: This API is added for narrow precision before we have a clean `recast_tensor` story.

        :note: It is only used for the case that frameworks don't natively support narrow precision but we get tensor
            from frameworks with storage type like uint8.

        **Example**:

        .. code-block:: python

            # Create a tensor from a numpy array
            import numpy as np
            from cutlass.cute import from_dlpack

            # Create a tensor with Float32 elements
            a = np.zeros(shape, dtype=np.uint8)
            tensor = from_dlpack(a)

            # Change the element type to Float4E2M1FN even storage type is uint8
            tensor.element_type = cutlass.Float4E2M1FN

            src = from_dlpack(... data tensor ...)
            # convert and initialize narrow precision tensor
            cute.testing.convert(src, tensor)
        """
        self._dtype = new_type

    @property
    @lazily_load_dltensor
    def memspace(self):
        """Address space reported by the DLTensor wrapper."""
        return self._dltensor_wrapper.address_space

    @property
    @lazily_load_dltensor
    def size_in_bytes(self) -> int:
        """Total size in bytes reported by the DLTensor wrapper."""
        return self._dltensor_wrapper.size_in_bytes()

    @property
    @lazily_load_dltensor
    def mlir_type(self) -> ir.Type:
        """MLIR type for this tensor, using the (possibly overridden) dtype."""
        return self._dltensor_wrapper.get_type(
            self.element_type.mlir_type, self._assumed_align
        )

    @lazily_load_dltensor
    def __str__(self) -> str:
        return f"Tensor<0x{self._dltensor_wrapper.str}>"

    def __repr__(self):
        return self.__str__()

    def __setitem__(self, crd, value):
        # Host-side runtime tensors are opaque handles, not indexable.
        raise TypeError(f"runtime._Tensor is not indexable")

    def __getitem__(self, crd):
        raise TypeError(f"runtime._Tensor is not indexable")

    @property
    @lazily_load_dltensor
    def iterator(self):
        """A ``_Pointer`` over the data pointer with this tensor's dtype,
        memory space, and assumed alignment."""
        return _Pointer(
            self._dltensor_wrapper.data_ptr,
            self.element_type,
            self.memspace,
            self._assumed_align,
        )

    @property
    def layout(self):
        raise NotImplementedError(
            f"layout property is not supported in runtime, support in future"
        )

    @property
    @lazily_load_dltensor
    def shape(self):
        """Shape tuple reported by the DLTensor wrapper."""
        return self._dltensor_wrapper.shape

    @property
    @lazily_load_dltensor
    def stride(self):
        """Strides from the wrapper; when absent (None), synthesize compact
        row-major strides from the shape via a reversed running product."""
        strides = self._dltensor_wrapper.stride
        if strides is None:
            strides = itertools.accumulate(
                reversed(self.shape), func=operator.mul, initial=1
            )
            strides = tuple(reversed(list(strides)[:-1]))

        return strides

    @property
    @lru_cache(maxsize=128, typed=True)
    def leading_dim(self):
        """Get the leading dimension of this Tensor.

        :return: The leading dimension index or indices
        :rtype: int or tuple or None

        The return value depends on the tensor's stride pattern:

        * If a single leading dimension is found, returns an integer index
        * If nested leading dimensions are found, returns a tuple of indices
        * If no leading dimension is found, returns None

        .. NOTE(review): lru_cache on a method keys on ``self`` and keeps up
           to 128 instances alive (flake8-bugbear B019); consider a
           per-instance cache instead.
        """
        return core.leading_dim(self.shape, self.stride)

    def fill(self, value: Numeric):
        raise TypeError(f"fill function is not supported in runtime")

    @property
    @lazily_load_dltensor
    def data_ptr(self):
        """Raw data pointer reported by the DLTensor wrapper."""
        return self._dltensor_wrapper.data_ptr

    @lazily_load_dltensor
    def __c_pointers__(self):
        # Build (and keep alive via self._memref_desc) the memref descriptor
        # capsule, then expose its raw pointer for the JIT calling convention.
        self._memref_desc = self._dltensor_wrapper.build_memref_desc(
            self._assumed_align
        )
        return [_cute_ir.pycapsule_get_pointer(self._memref_desc)]

    def __get_mlir_types__(self):
        return [self.mlir_type]

    def __new_from_mlir_values__(self, values):
        # NOTE(review): self._dtype may still be None here if element_type was
        # never read or set — confirm CoreTensor accepts that.
        assert len(values) == 1
        assert isinstance(values[0], CoreTensor)
        return CoreTensor(values[0].value, self._dtype)
387
+
388
+
389
def from_dlpack(tensor_dlpack, assumed_align=None) -> Tensor:
    """Wrap a DLPack-protocol object as a CuTe Tensor.

    :param tensor_dlpack: Tensor object that supports the DLPack protocol
    :type tensor_dlpack: object
    :param assumed_align: Assumed alignment of the tensor in bytes; when None,
        the element size in bytes is used as the alignment.
    :type assumed_align: int, optional
    :return: A CuTe Tensor object
    :rtype: Tensor

    Examples:
    .. code-block:: python

        import torch
        from cutlass.cute.runtime import from_dlpack
        x = torch.randn(100, 100)
        y = from_dlpack(x)
        y.shape
        # (100, 100)
        type(y)
        # <class 'cutlass.cute.Tensor'>
    """
    return _Tensor(tensor_dlpack, assumed_align=assumed_align)
419
+
420
+
421
def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Create a pointer from a memory address

    :param dtype: Data type of the pointer elements
    :type dtype: Type[Numeric]
    :param value: Memory address as integer or ctypes pointer
    :type value: Union[int, ctypes._Pointer]
    :param mem_space: Memory address space, defaults to AddressSpace.generic
    :type mem_space: AddressSpace, optional
    :param assumed_align: Assumed alignment in bytes, defaults to None
        (the element size in bytes is then used)
    :type assumed_align: int, optional
    :return: A pointer object
    :rtype: Pointer

    .. code-block:: python

        import numpy as np
        import ctypes

        from cutlass import Float32
        from cutlass.cute.runtime import make_ptr

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)

        # Get pointer address as integer
        ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

        # Create pointer from address
        y = make_ptr(Float32, ptr_address)

        # Check properties
        print(y.element_type)
        print(type(y))  # <class 'cutlass.cute.Pointer'>
    """
    # check if value is int or ctypes.POINTER
    if isinstance(value, int):
        address_value = value
    elif isinstance(value, ctypes._Pointer):
        # get address value
        address_value = ctypes.cast(value, ctypes.c_void_p).value
        assert address_value is not None, "Pointer address is None"
    else:
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )

    return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)
474
+
475
+
476
class TensorAdapter:
    """
    Convert a DLPack protocol supported tensor/array to a cute tensor.

    Registered as a JIT-argument adapter: wraps the framework tensor once and
    delegates the JIT argument protocol to the converted CuTe tensor.
    """

    def __init__(self, arg):
        # Convert eagerly and mark the layout dynamic so the compiled function
        # is not specialized to this particular tensor's strides.
        self._arg = from_dlpack(arg).mark_layout_dynamic()

    def __new_from_mlir_values__(self, values):
        # Delegate reconstruction to the wrapped tensor.
        return self._arg.__new_from_mlir_values__(values)

    def __c_pointers__(self):
        # Delegate the C calling-convention pointers to the wrapped tensor.
        return self._arg.__c_pointers__()

    def __get_mlir_types__(self):
        # Delegate MLIR type computation to the wrapped tensor.
        return self._arg.__get_mlir_types__()
492
+
493
+
494
# -------------------------------------------------------------------------
# Try to register_jit_arg_adapter for TensorAdapter
# -------------------------------------------------------------------------

# Register TensorAdapter for each optional framework tensor type that is
# importable. Frameworks that are not installed are skipped silently so that
# neither numpy nor torch becomes a hard dependency of this module.
for _module_name, _attr_name in (("numpy", "ndarray"), ("torch", "Tensor")):
    try:
        _module = __import__(_module_name)
    except ImportError:
        continue  # silent attempt, suppress error
    JitArgAdapterRegistry.register_jit_arg_adapter(getattr(_module, _attr_name))(
        TensorAdapter
    )
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ import functools
13
+ import inspect
14
+ import logging
15
+ import os
16
+ from enum import Enum
17
+ from inspect import isclass
18
+ from itertools import product
19
+ from time import time
20
+ from typing import Any, Callable, Dict, List, Optional, Type, Union
21
+
22
+ import cuda.bindings.driver as cuda_driver
23
+ import cuda.bindings.runtime as cuda_runtime
24
+ import numpy as np
25
+
26
+ import cutlass._mlir.ir as ir
27
+ import cutlass.base_dsl.jit_executor
28
+ import cutlass.cute as cute
29
+ from cutlass._mlir.dialects import builtin, cf, nvvm, vector
30
+ from cutlass.cute import core, nvgpu
31
+ from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op
32
+
33
+
34
@dsl_user_op
def assert_(cond, msg=None, *, loc=None, ip=None):
    """Emit an assertion into the generated IR via the ``cf`` dialect.

    :param cond: Condition to check; converted to a Boolean IR value
    :param msg: Optional message for the assertion (empty string when None)
    :param loc: Optional MLIR location
    :param ip: Optional MLIR insertion point
    """
    cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip)
37
+
38
+
39
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
    """Reinterpret a 4-bit tensor as Int8, adjusting its TV layout to match.

    If ``src`` holds 4-bit elements, the TV layout is recast accordingly and
    the tensor is recast to Int8 storage; otherwise both are returned
    unchanged.

    :param src: Tensor to (possibly) recast
    :param tv_layout: Thread-value layout paired with ``src``
    :return: ``(src, tv_layout)``, recast when the element width is 4 bits
    """
    if src.element_type.width == 4:
        tv_layout = core.recast_layout(8, 4, tv_layout)
        src = core.recast_tensor(src, dtype=t.Int8)
    return src, tv_layout
44
+
45
+
46
def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]):
    """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit.

    :param input: The input tensor to recast.
    :param dtype: The target numeric type to potentially recast to.
    :raises TypeError: If dtype is not a subclass of Numeric.
    :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged.
    """
    if not isclass(dtype) or not issubclass(dtype, core.Numeric):
        raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}")

    if dtype.width == 4:
        recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape
        # Bitcast the vector to twice as many 4-bit lanes, then cast the i4
        # lanes to the destination element type via an unrealized conversion.
        i4_vec = vector.bitcast(
            T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast()
        )
        res_vect = builtin.unrealized_conversion_cast(
            [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec]
        )
        return core.TensorSSA(res_vect, recast_shape, dtype)
    return input
67
+
68
+
69
def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]):
    """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit.

    :param input: The input tensor to recast.
    :param src_dtype: The source numeric type to potentially recast from.
    :raises TypeError: If src_dtype is not a subclass of Numeric.
    :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged.
    """
    if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric):
        raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}")

    if src_dtype.width == 4:
        recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape
        # Reinterpret the lanes as i4, then pack pairs of i4 into i8 lanes;
        # the result is an Int8-typed TensorSSA (inverse of _maybe_recast_to_f4).
        i4_vec = builtin.unrealized_conversion_cast(
            [T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()]
        )
        res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec)
        return core.TensorSSA(res_vect, recast_shape, core.Int8)
    return input
88
+
89
+
90
@CuTeDSL.kernel
def _convert_kernel(
    gSrc: core.Tensor,
    gDst: core.Tensor,
    cSrc: core.Tensor,
    src_tv_layout: core.Layout,
    dst_tv_layout: core.Layout,
    src_shape: core.Shape,
    src_ty,
    dst_ty,
):
    """Elementwise dtype-conversion device kernel.

    Each CTA handles one tile (selected by block index) of the zipped-divided
    tensors; each thread copies its fragment from global memory to registers,
    converts to the destination type (with f4 recast fixups on either side),
    and copies the result back. ``cSrc`` is an identity (coordinate) tensor
    used to predicate out-of-bounds tiles against ``src_shape``.
    """
    tidx = nvvm.read_ptx_sreg_tid_x(T.i32())
    bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32())

    # Select this CTA's tile in the second (rest) mode.
    cta_coord = (None, bidx)
    # logical idx -> address
    ctaSrc = gSrc[cta_coord]  # (...,TileV,...)
    ctaDst = gDst[cta_coord]  # (...,TileV,...)
    ctaCSrc = cSrc[cta_coord]  # (...,TileV,...)

    # compose with CTA TV layout
    # tid, vid -> address
    tidfrgSrc = core.composition(ctaSrc, src_tv_layout)  # (T,V)
    tidfrgDst = core.composition(ctaDst, dst_tv_layout)  # (T,V)
    tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout)  # (T,V)

    # slice for threads
    thr_coord = (tidx, None)
    thrSrc = tidfrgSrc[thr_coord]  # (V)
    thrDst = tidfrgDst[thr_coord]  # (V)
    thrCSrc = tidfrgCSrc[thr_coord]  # (V)

    # predicate: only threads whose first coordinate is in bounds participate
    if core.elem_less(thrCSrc[0], src_shape):
        # allocate fragments for gmem->rmem
        frgSrc = core.make_fragment(
            core.get(src_tv_layout, mode=[1]), gSrc.element_type
        )  # (V)
        frgDst = core.make_fragment(
            core.get(dst_tv_layout, mode=[1]), gDst.element_type
        )  # (V)

        # Move data to reg address space
        copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type)
        core.copy(copy_atom_load, thrSrc, frgSrc)

        # Convert in registers; f4 types are carried in Int8 storage, so
        # recast before and after the element conversion when needed.
        vec_src = frgSrc.load()
        vec_src = _maybe_recast_to_f4(vec_src, src_ty)
        vec_dst = vec_src.to(dst_ty)
        vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty)
        frgDst.store(vec_dst)

        # Copy the results back to c
        copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type)
        core.copy(copy_atom_stg, frgDst, thrDst)
149
+
150
+
151
@CuTeDSL.jit(preprocess=False)
def _convert(
    src: core.Tensor,
    dst: core.Tensor,
    leading_mode: Constexpr,
    elem_per_copy: Constexpr,
):
    """Host-side JIT driver for elementwise conversion.

    Tiles ``src``/``dst`` along ``leading_mode`` into chunks of
    ``elem_per_copy`` elements per thread across 128 threads, then launches
    ``_convert_kernel`` with one CTA per tile.

    :param src: Source tensor
    :param dst: Destination tensor (same logical shape as ``src``)
    :param leading_mode: Index of the stride-1 mode used for tiling
    :param elem_per_copy: Elements converted per thread per copy
    """

    # Step 1. figure proper tv_layout
    src_ty = src.element_type
    dst_ty = dst.element_type

    # 128 threads, each owning elem_per_copy contiguous elements.
    tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1))

    # Step 2. maybe recast from f4 tensor
    src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout)
    dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout)
    src_shape = src.shape
    # predicate tensor (coordinates used for bounds checks in the kernel)
    idA = core.make_identity_tensor(src.shape)

    # Step 3. select a proper tiling pattern as (...,TileV, ...)
    src_cta_tiler = [
        1,
    ] * core.rank(src.layout)
    src_cta_tiler[leading_mode] = core.size(src_tv_layout)  # (...,TileV,...)
    dst_cta_tiler = [
        1,
    ] * core.rank(dst.layout)
    dst_cta_tiler[leading_mode] = core.size(dst_tv_layout)  # (...,TileV,...)

    # Step 4. partition input and output tensor by cta tiler.
    gS = core.zipped_divide(
        src, tuple(src_cta_tiler)
    )  # ((...,TileV,...),(...,RestV,...))
    cS = core.zipped_divide(
        idA, tuple(src_cta_tiler)
    )  # ((...,TileV,...),(...,RestV,...))
    gD = core.zipped_divide(
        dst, tuple(dst_cta_tiler)
    )  # ((...,TileV,...),(...,RestV,...))

    # One CTA per rest-mode tile; 128 threads per CTA (mode 0 of the TV layout).
    _convert_kernel(
        gS,
        gD,
        cS,
        src_tv_layout,
        dst_tv_layout,
        src_shape,
        src_ty,
        dst_ty,
    ).launch(
        grid=[core.size(gS, mode=[1]), 1, 1],
        block=[core.size(src_tv_layout, mode=[0]), 1, 1],
    )
207
+
208
+
209
def convert(src: core.Tensor, dst: core.Tensor):
    """Convert ``src`` into ``dst`` elementwise, possibly changing dtype.

    The logical shapes of ``src`` and ``dst`` are required to be the same.
    When either dtype is narrow precision (Float4E2M1FN/Float8E8M0FNU/
    Float8E4M3FN), the extent of the leading dimension must be a multiple of
    4 (fp8) / 8 (fp4) elements, because ``nvgpu.cvt_fptrunc``/``cvt_fpext``
    need 32-bit aligned input/output.

    :param src: Source tensor
    :param dst: Destination tensor, same rank and shape as ``src``
    :raises ValueError: If the leading mode (extent > 1 and stride == 1)
        is not unique
    """
    assert len(src.shape) == len(
        dst.shape
    ), "Shape of src and dst tensors should be the same rank."
    # find leading mode: the unique mode with extent > 1 and stride 1
    leading_modes = [
        idx
        for idx, (shape, stride) in enumerate(zip(src.shape, src.stride))
        if shape > 1 and stride == 1
    ]
    if len(leading_modes) != 1:
        raise ValueError(f"Leading mode should be unique, but got {leading_modes}")
    leading_mode = leading_modes[0]

    # Elements per copy, sized so each copy is at least 32 bits wide for
    # narrow-precision element types.
    elem_per_copy = 2
    if src.element_type.width == 4 or dst.element_type.width == 4:
        elem_per_copy = 8
    elif src.element_type.width == 8 or dst.element_type.width == 8:
        elem_per_copy = 4
    assert (
        src.shape[leading_mode] % elem_per_copy == 0
        and dst.shape[leading_mode] % elem_per_copy == 0
    ), (
        f"leading-dim extents must be divisible by {elem_per_copy}, "
        f"but got {src.shape[leading_mode]} and {dst.shape[leading_mode]}"
    )
    _convert(src, dst, leading_mode, elem_per_copy)
238
+
239
+
240
+ #########################################
241
+ # Testing utilities
242
+ #########################################
243
+
244
+
245
def sample_pytest(rand_cfg=None):
    """
    Decorator to randomly sample pytest parametrized tests.

    :param rand_cfg: ``(random_seed, sample_ratio)`` tuple, or None to
        disable sampling entirely
    Sampling is disabled when:
    - ``rand_cfg`` is None
    - A specific test is selected (via -k or direct test path)
    - Not running under pytest

    Bug fix: the original unconditionally unpacked ``rand_cfg`` and raised
    TypeError when called with the default ``rand_cfg=None``; the unpacking
    and seeding are now guarded.
    """
    import functools
    import os
    import random
    import sys

    import pytest

    if rand_cfg is not None:
        seed, sample_ratio = rand_cfg
        random.seed(seed)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ:
                # Check if test was explicitly selected like ::test_name[param1-param2-...]
                if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv):
                    # Test was explicitly selected, don't skip
                    return func(*args, **kwargs)

                if random.uniform(0.0, 1.0) > sample_ratio:
                    pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})")
            return func(*args, **kwargs)

        return wrapper

    return decorator
279
+
280
+
281
+ #########################################
282
+ # Benchmarking utilities
283
+ #########################################
284
+
285
+
286
class JitArguments:
    """
    Bundle of positional and keyword arguments for a kernel invocation,
    used when passing call arguments around during benchmarking.
    """

    def __init__(self, *args, **kwargs):
        # Store both argument groups untouched; consumers unpack them when
        # the kernel is actually invoked.
        self.args, self.kwargs = args, kwargs
294
+
295
+
296
def _cuda_success(
    err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str
):
    """
    Helper function to check CUDA API errors.

    Accepts a raw ``cudaError_t`` / ``CUresult`` or the ``(err, ...)`` tuple
    returned by the cuda-python bindings (recurses on the first element).

    :param err: Error code or result tuple to check
    :param message: Context message prefixed to the raised error
    :raises RuntimeError: If ``err`` indicates failure
    :raises TypeError: If ``err`` is not a recognized error type
    """
    if isinstance(err, tuple):
        _cuda_success(err[0], message)
    elif isinstance(err, cuda_runtime.cudaError_t):
        # Fetch the error string only on failure (the original fetched it
        # unconditionally, unlike the CUresult branch below).
        if err != cuda_runtime.cudaError_t.cudaSuccess:
            error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8")
            raise RuntimeError(f"{message} : {error_message}")
    elif isinstance(err, cuda_driver.CUresult):
        if err != cuda_driver.CUresult.CUDA_SUCCESS:
            error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8")
            raise RuntimeError(f"{message} : {error_message}")
    else:
        raise TypeError(
            f"{err} is an unexpected type : it should be a cudaError_t or CUresult"
        )
316
+
317
+
318
def _does_kernel_use_stream(
    kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs
):
    """
    This function checks if the kernel uses the provided non-default stream.
    It does this by capturing the stream and then checking if any kernels were launched.
    :param kernel: The kernel to check
    :type kernel: Callable
    :param stream: The stream to check
    :type stream: cuda_driver.CUstream
    :return: True if the kernel uses the stream, False otherwise
    :rtype: bool
    """

    # Stream capture is only legal/meaningful on a non-default stream.
    # NOTE(review): `assert` is stripped under `python -O`.
    assert int(stream) != int(
        cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
    ), "Stream must be a non-default stream"

    # Begin capturing work submitted to `stream` into a CUDA graph
    # (thread-local mode: only this thread's captures are affected).
    err = cuda_runtime.cudaStreamBeginCapture(
        stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
    )
    _cuda_success(err, "Error on stream capture")

    # Run the kernel while capture is active; any launch on `stream`
    # becomes a node in the captured graph instead of executing.
    kernel(*args, **kwargs)

    err, graph = cuda_runtime.cudaStreamEndCapture(stream)
    _cuda_success(err, "Error on stream capture")

    # Get number of nodes in warmup graph to check it matches what is expected
    err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph)
    _cuda_success(err, "Error on querying graph")
    # Any captured node means the kernel submitted work to `stream`.
    return num_nodes > 0
350
+
351
+
352
def benchmark(
    callable: Callable,
    *,
    warmup_iterations: int = 10,
    iterations: int = 100,
    stream: Optional[cuda_driver.CUstream] = None,
    kernel_arguments: Optional[JitArguments] = None,
    workspace_generator: Optional[Callable[[], JitArguments]] = None,
    workspace_count: int = 1,
    use_cuda_graphs: bool = False,
) -> float:
    """Benchmarks a callable function with the specified parameters.

    For example,

    .. code-block:: python

        from cutlass.cute.testing import benchmark

        @cute.jit
        def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream):
            # contents of the function
            pass

        time_us = benchmark(user_function,
                            kernel_arguments=JitArguments(a, b, c, stream),
                            warmup_iterations=10,
                            iterations=100,
                            stream=stream)

    To prevent skewing results by repeatedly accessing the L2 cache, use the
    workspace_count and workspace_generator parameters to cycle through a number
    of different workspaces.

    .. code-block:: python

        from cutlass.cute.testing import benchmark

        @cute.jit
        def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
            # contents of the function
            pass

        def workspace_generator():
            # create a, b, and c
            return JitArguments(a, b, c)

        time_us = benchmark(user_function,
                            workspace_generator=workspace_generator,
                            workspace_count=10,
                            warmup_iterations=10000,
                            iterations=1000)

    To benchmark you may always configure the function being profiled (callable),
    the warmup iterations, and the number of profiling iterations.

    Whenever the kernel being benchmarked runs in a non-default stream, the stream
    must be provided through the stream parameter.

    To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
    When using CUDA graphs, the kernel must be launched in a non-default stream.

    :param callable: The function to benchmark
    :type callable: Callable
    :param warmup_iterations: Number of warmup iterations, defaults to 10
    :type warmup_iterations: int, optional
    :param iterations: Number of benchmark iterations, defaults to 100
    :type iterations: int, optional
    :param stream: Stream kernel is launched in, defaults to CUDA stream default
    :type stream: CUstream, None
    :param kernel_arguments: Kernel arguments to launch callable with, defaults to None
    :type kernel_arguments: JitArguments, None
    :param workspace_generator: Function that returns kernel arguments, defaults to None
    :type workspace_generator: Callable
    :param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold
    :type workspace_count: int, optional
    :param use_cuda_graphs: Whether to use cuda graphs, defaults to False
    :type use_cuda_graphs: bool, optional

    :return: The benchmark time in microseconds
    :rtype: float
    """

    if stream is None:
        stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT)

    if workspace_count < 1:
        raise ValueError("workspace_count must be at least 1")

    if workspace_generator is None:
        # If no workspace generator is provided, we need a single workspace
        if workspace_count != 1:
            raise ValueError("Need a single workspace if not providing a generator")

        # If no workspace generator is provided, we need a kernel_argument
        if kernel_arguments is None:
            raise ValueError(
                "Please pass a kernel argument if not providing a generator"
            )
        workspace_generator = lambda: kernel_arguments

    workspaces = [workspace_generator() for _ in range(workspace_count)]

    for workspace in workspaces:
        if not isinstance(workspace, JitArguments):
            raise TypeError(
                "workspace_generator and/or kernel_arguments should use JitArguments type"
            )

    def _loop_and_call_kernel(iterations: int, workspace_index: int = 0):
        # Launch `iterations` calls, cycling through the workspaces so the
        # same buffers are not reused back-to-back (keeps L2 cold).
        for _ in range(iterations):
            current_workspace = workspaces[workspace_index]
            callable(*current_workspace.args, **current_workspace.kwargs)
            workspace_index = (workspace_index + 1) % workspace_count
        return workspace_index

    # Create CUDA events for timing; destroyed in the finally block below so
    # they are not leaked if anything raises during benchmarking.
    err, start_event = cuda_driver.cuEventCreate(
        cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
    )
    _cuda_success(err, "Error on creating event")
    err, end_event = cuda_driver.cuEventCreate(
        cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
    )
    _cuda_success(err, "Error on creating event")

    elapsed_time = float("nan")

    try:
        if use_cuda_graphs:
            # CUDA graphs require a precompiled JIT executor so that stream
            # capture does not pick up host-side compilation work.
            if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor):
                raise TypeError(
                    "Function must be precompiled to be used with CUDA Graphs"
                )

            # Stream capture is not legal on the default stream.
            if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
                raise ValueError(
                    "Measuring with CUDA Graphs requires executing in a non-default stream"
                )

            workspace_index = 0

            # Capture warmup graph
            err = cuda_runtime.cudaStreamBeginCapture(
                stream,
                cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal,
            )
            _cuda_success(err, "Error on stream capture")

            workspace_index = _loop_and_call_kernel(warmup_iterations)
            err, gwarm = cuda_runtime.cudaStreamEndCapture(stream)
            _cuda_success(err, "Error on stream capture")

            # Get number of nodes in warmup graph to check it matches what is expected
            err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm)
            _cuda_success(err, "Error on querying graph")
            # Check is >= since we may launch multiple kernels in one host function
            if num_nodes < warmup_iterations:
                raise ValueError(
                    "CUDA stream passed to benchmark does not match the stream the kernel was launched in"
                )

            # Capture profiling graph
            err = cuda_runtime.cudaStreamBeginCapture(
                stream,
                cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal,
            )
            _cuda_success(err, "Error on stream capture")
            _loop_and_call_kernel(iterations, workspace_index)
            err, gprofile = cuda_runtime.cudaStreamEndCapture(stream)
            _cuda_success(err, "Error on stream capture")

            # Instantiate executable graphs, then free the capture graphs.
            # (The original code rebound gwarm/gprofile to the exec handles,
            # leaking the captured graphs.)
            err, gwarm_exec = cuda_runtime.cudaGraphInstantiate(gwarm, 0)
            _cuda_success(err, "Error on graph instantiation")
            err, gprofile_exec = cuda_runtime.cudaGraphInstantiate(gprofile, 0)
            _cuda_success(err, "Error on graph instantiation")
            err = cuda_runtime.cudaGraphDestroy(gwarm)
            _cuda_success(err, "Error on destroying graph")
            err = cuda_runtime.cudaGraphDestroy(gprofile)
            _cuda_success(err, "Error on destroying graph")

            # Launch warmup graph
            err = cuda_runtime.cudaGraphLaunch(gwarm_exec, stream)
            _cuda_success(err, "Error on graph launch")

            # Record start time
            err = cuda_driver.cuEventRecord(start_event, stream)
            _cuda_success(err, "Error on recording event")

            # Launch profiling graph
            err = cuda_runtime.cudaGraphLaunch(gprofile_exec, stream)
            _cuda_success(err, "Error on graph launch")

            # Record end time and wait for completion
            err = cuda_driver.cuEventRecord(end_event, stream)
            _cuda_success(err, "Error on recording event")
            err = cuda_driver.cuEventSynchronize(end_event)
            _cuda_success(err, "Error on synchronizing event")

            # Get elapsed time (milliseconds)
            err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
            _cuda_success(err, "Error on querying event")

            # Destroy executable graphs
            err = cuda_runtime.cudaGraphExecDestroy(gwarm_exec)
            _cuda_success(err, "Error on destroying graph")
            err = cuda_runtime.cudaGraphExecDestroy(gprofile_exec)
            _cuda_success(err, "Error on destroying graph")

        else:
            # Sanity check: the kernel must actually run in the stream we time.
            if int(stream) != int(
                cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
            ) and not _does_kernel_use_stream(
                callable, stream, *workspaces[0].args, **workspaces[0].kwargs
            ):
                raise ValueError(
                    "CUDA stream passed to benchmark does not match the stream the kernel was launched in"
                )

            # Not using graphs
            # Warmup
            workspace_index = _loop_and_call_kernel(warmup_iterations)
            # Record start event
            err = cuda_driver.cuEventRecord(start_event, stream)
            _cuda_success(err, "Error on recording event")
            _loop_and_call_kernel(iterations, workspace_index)
            # Record end event
            err = cuda_driver.cuEventRecord(end_event, stream)
            _cuda_success(err, "Error on recording event")
            # Synchronize end event
            err = cuda_driver.cuEventSynchronize(end_event)
            _cuda_success(err, "Error on synchronizing event")
            err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
            _cuda_success(err, "Error on querying event")
    finally:
        # Destroy events
        err = cuda_driver.cuEventDestroy(start_event)
        _cuda_success(err, "Error on destroying event")
        err = cuda_driver.cuEventDestroy(end_event)
        _cuda_success(err, "Error on destroying event")

    # cuEventElapsedTime reports milliseconds; convert to microseconds per call.
    return elapsed_time / iterations * 1e3
585
+
586
+
587
def get_workspace_count(
    one_workspace_bytes: int, warmup_iterations: int, iterations: int
) -> int:
    """Calculate the number of workspaces needed to fill L2 cache.

    Cycling through at least this many distinct workspaces keeps the L2
    cache cold between benchmark iterations.

    :param one_workspace_bytes: Size of one workspace in bytes (must be positive)
    :type one_workspace_bytes: int
    :param warmup_iterations: Number of warmup iterations
    :type warmup_iterations: int
    :param iterations: Number of iterations
    :type iterations: int
    :return: Number of workspaces needed (at least 1, at most
        ``warmup_iterations + iterations``)
    :rtype: int
    :raises ValueError: If ``one_workspace_bytes`` is not positive.
    """
    # Guard against ZeroDivisionError (and nonsensical negative sizes) in the
    # ceiling division below.
    if one_workspace_bytes <= 0:
        raise ValueError("one_workspace_bytes must be positive")
    num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes()
    # Ceiling division: how many workspaces are needed to cover the L2 cache.
    needed_to_fill_l2 = (
        num_l2_cache_bytes + one_workspace_bytes - 1
    ) // one_workspace_bytes
    return max(
        1,
        min(
            warmup_iterations + iterations,  # Don't create more workspaces than needed
            needed_to_fill_l2,
        ),
    )
610
+
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
3
+ #
4
+ # Use of this software is governed by the terms and conditions of the
5
+ # NVIDIA End User License Agreement (EULA), available at:
6
+ # https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
7
+ #
8
+ # Any use, reproduction, disclosure, or distribution of this software
9
+ # and related documentation outside the scope permitted by the EULA
10
+ # is strictly prohibited.
11
+
12
+ from abc import ABC, abstractmethod
13
+ from typing import ForwardRef, Tuple, Union, Any, Type, List
14
+
15
+ from cutlass.base_dsl.typing import *
16
+
17
+ from cutlass._mlir import ir
18
+ import cutlass._mlir.extras.types as T
19
+ from cutlass._mlir.dialects.cute import AddressSpace
20
+
21
+
22
# Scalar integer: either a plain Python int or a DSL Integer value.
Int = Union[int, Integer]


# Forward reference only; the ScaledBasis class is defined elsewhere in the DSL.
ScaledBasis = ForwardRef("ScaledBasis")


# Recursive (arbitrarily nested) tuple types used throughout CuTe.
IntTuple = Union[Int, Tuple["IntTuple", ...]]
Shape = Union[Int, Tuple["Shape", ...]]
# Strides may additionally contain ScaledBasis entries.
Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]]
# Coordinates may contain None entries.
Coord = Union[Int, None, Tuple["Coord", ...]]
32
+
33
+
34
class Layout(ir.Value):
    """An MLIR value representing a CuTe layout (a shape/stride pair).

    Method bodies here are stubs (``...``); concrete behavior is supplied by
    the runtime implementation elsewhere in the DSL.
    """

    def __init__(self, op_result):
        super().__init__(op_result)

    def __str__(self): ...

    def get_hier_coord(self, idx) -> Coord:
        """Return the (hierarchical) ND logical coordinate corresponding to the linear index"""
        ...

    # NOTE(review): properties with extra keyword-only parameters (loc/ip)
    # cannot receive them through attribute access; presumably these mirror a
    # functional API elsewhere — confirm against the implementation.
    @property
    def shape(self, *, loc=None, ip=None) -> Shape: ...

    @property
    def stride(self, *, loc=None, ip=None) -> Stride: ...
49
+
50
+
51
+ Tile = Union[Int, None, Layout, Tuple["Tile", ...]]
52
+
53
+ # XTuple is super set of above types
54
+ XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]
55
+
56
+ Tiler = Union[Shape, Layout, Tile]
57
+
58
+
59
class Pointer(ABC):
    """
    Abstract base class for CuTe jit function and runtime _Pointer
    """

    @property
    def value_type(self) -> Type[Numeric]:
        # Alias of `dtype`.
        return self.dtype

    @property
    def dtype(self) -> Type[Numeric]:
        """Element type this pointer points to (stub; implemented by subclasses)."""
        ...

    def align(self, min_align: int) -> "Pointer":
        """Return a Pointer from this one given a minimum alignment (stub)."""
        ...

    # The three dunder methods below form the value-marshalling protocol used
    # by the DSL to convert Python objects to/from MLIR types and values.
    def __get_mlir_types__(self) -> List[ir.Type]: ...

    def __extract_mlir_values__(self) -> List[ir.Value]: ...

    def __new_from_mlir_values__(self, values) -> "Pointer": ...
78
+
79
+
80
class Tensor(ABC):
    """
    Abstract base class for CuTe jit function and runtime _Tensor

    A CuTe Tensor is iterator with layout

    :Examples:

    Create tensor from torch.tensor with Host Runtime:

    .. code-block:: python

        >>> import torch
        >>> from cutlass.cute.runtime import from_dlpack
        >>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32))
        >>> mA.shape
        (3,)
        >>> mA.stride
        (1,)
        >>> mA.layout
        (3,):(1,)

    Define JIT function:

    .. code-block:: python

        @cute.jit
        def add(a: Tensor, b: Tensor, res: Tensor): ...

    Call JIT function from python:

    .. code-block:: python

        >>> import torch
        >>> a = torch.tensor([1, 3, 5], dtype=torch.int32)
        >>> b = torch.tensor([2, 4, 6], dtype=torch.int32)
        >>> c = torch.zeros([3], dtype=torch.int32)
        >>> mA = from_dlpack(a)
        >>> mB = from_dlpack(b)
        >>> mC = from_dlpack(c)
        >>> add(mA, mB, mC)
        >>> c
        tensor([3, 7, 11], dtype=torch.int32)
    """

    def __str__(self): ...

    # Element access; concrete subclasses must implement indexing.
    @abstractmethod
    def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ...

    @abstractmethod
    def __setitem__(self, idx, value): ...

    # Element type of the tensor; settable (e.g. for reinterpretation).
    @property
    @abstractmethod
    def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ...

    @element_type.setter
    def element_type(self, new_type): ...

    # Address space the tensor's data lives in.
    @property
    @abstractmethod
    def memspace(self) -> AddressSpace: ...

    # The underlying iterator (pointer-like engine) of the tensor.
    @property
    @abstractmethod
    def iterator(self): ...

    @property
    def layout(self) -> Union[Layout, "ComposedLayout"]: ...

    @property
    def shape(self) -> Shape: ...

    # Bulk load/store of the tensor's contents as a TensorSSA value (stubs).
    def load(self, *, loc=None, ip=None) -> "TensorSSA": ...

    def store(self, data: "TensorSSA", *, loc=None, ip=None): ...

    # Mark the layout (or one compact-shape mode) as dynamic for JIT
    # specialization purposes (stubs; implemented by concrete subclasses).
    def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ...

    def mark_compact_shape_dynamic(
        self,
        mode: int,
        stride_order: tuple[int, ...] | None = None,
        divisibility: int = 1,
    ) -> "Tensor": ...

    @abstractmethod
    def fill(self, value: Numeric) -> None: ...
169
+
170
+
171
# Public API of this module: the CuTe algebra types defined here plus the
# numeric scalar types re-exported from cutlass.base_dsl.typing. Grouping is
# cosmetic only; `from ... import *` semantics are unchanged.
__all__ = [
    # CuTe algebra types
    "Coord",
    "IntTuple",
    "Layout",
    "Pointer",
    "Shape",
    "Stride",
    "Tensor",
    "Tile",
    "Tiler",
    "XTuple",
    # Numeric scalar types
    "Numeric",
    "Integer",
    "Boolean",
    "Int8",
    "Int16",
    "Int32",
    "Int64",
    "Uint8",
    "Uint16",
    "Uint32",
    "Uint64",
    "Float",
    "Float16",
    "BFloat16",
    "TFloat32",
    "Float32",
    "Float64",
    "Float8E5M2",
    "Float8E4M3FN",
    "Float8E4M3B11FNUZ",
    "Float8E4M3",
    "Float8E8M0FNU",
    "Float4E2M1FN",
    "Float6E2M3FN",
]