Build uploaded using `kernels` (batch 2/10).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py +34 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py +616 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py +1958 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py +153 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py +268 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py +288 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py +1686 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py +320 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py +357 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py +25 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py +476 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py +121 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py +76 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py +188 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py +201 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py +1962 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py +19 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py +81 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py +165 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py +56 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py +59 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py +319 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py +101 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py +84 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py +349 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +681 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py +108 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py +142 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py +0 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py +445 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py +26 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py +189 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +39 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +471 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +341 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +249 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +62 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +663 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +328 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +1041 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +25 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py +189 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +83 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py +29 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py +109 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +405 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py +510 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py +610 -0
- build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py +207 -0
.gitattributes
CHANGED
|
@@ -12,3 +12,4 @@ build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf
|
|
| 12 |
build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 13 |
build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 14 |
build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 12 |
build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 13 |
build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 14 |
build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/op.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides MLIR's OP helper functions
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
from functools import wraps
|
| 19 |
+
|
| 20 |
+
from ..._mlir import ir
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def dsl_user_op(opFunc):
|
| 24 |
+
@wraps(opFunc)
|
| 25 |
+
def wrapper(*args, **kwargs):
|
| 26 |
+
loc = kwargs.pop("loc", None)
|
| 27 |
+
if loc is None:
|
| 28 |
+
frame = inspect.currentframe().f_back
|
| 29 |
+
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
|
| 30 |
+
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
|
| 31 |
+
res_or_list = opFunc(*args, **kwargs, loc=loc)
|
| 32 |
+
return res_or_list
|
| 33 |
+
|
| 34 |
+
return wrapper
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_helpers.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides helper functions that are generated by the preprocessor.
|
| 14 |
+
The preprocessor read through python's ast and changes the input code.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from typing import Callable, Iterator, Optional, overload
|
| 18 |
+
from typing_extensions import deprecated
|
| 19 |
+
import warnings
|
| 20 |
+
import inspect
|
| 21 |
+
from types import BuiltinFunctionType
|
| 22 |
+
from functools import lru_cache
|
| 23 |
+
from inspect import getmembers
|
| 24 |
+
|
| 25 |
+
from .utils.logger import log
|
| 26 |
+
from .common import *
|
| 27 |
+
|
| 28 |
+
from ._mlir_helpers.arith import ArithValue
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Executor:
|
| 32 |
+
"""
|
| 33 |
+
The Executor class handles dynamic and compile-time (constexpr) execution
|
| 34 |
+
of "for" loops and "if-else-elif" statements.
|
| 35 |
+
|
| 36 |
+
Methods:
|
| 37 |
+
set_functions: Assigns the functions for checking loop bounds and
|
| 38 |
+
conditional evaluation.
|
| 39 |
+
|
| 40 |
+
for_execute: Generates MLIR for OP
|
| 41 |
+
while_execute: Generates MLIR while OP
|
| 42 |
+
if_execute: generate MLIR if OP
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self):
|
| 46 |
+
self._is_dynamic_expression = None
|
| 47 |
+
self._loop_execute_range_dynamic = None
|
| 48 |
+
self._if_dynamic = None
|
| 49 |
+
self._while_dynamic = None
|
| 50 |
+
self._compare_executor = None
|
| 51 |
+
self._any_executor = None
|
| 52 |
+
self._all_executor = None
|
| 53 |
+
self._builtin_redirector = None
|
| 54 |
+
|
| 55 |
+
def set_functions(
|
| 56 |
+
self,
|
| 57 |
+
*,
|
| 58 |
+
is_dynamic_expression: Callable,
|
| 59 |
+
loop_execute_range_dynamic: Callable,
|
| 60 |
+
if_dynamic: Callable,
|
| 61 |
+
while_dynamic: Callable,
|
| 62 |
+
compare_executor: Callable,
|
| 63 |
+
any_executor: Callable = None,
|
| 64 |
+
all_executor: Callable = None,
|
| 65 |
+
builtin_redirector: Callable = None,
|
| 66 |
+
):
|
| 67 |
+
self._is_dynamic_expression = is_dynamic_expression
|
| 68 |
+
self._loop_execute_range_dynamic = loop_execute_range_dynamic
|
| 69 |
+
self._if_dynamic = if_dynamic
|
| 70 |
+
self._while_dynamic = while_dynamic
|
| 71 |
+
self._compare_executor = compare_executor
|
| 72 |
+
self._any_executor = any_executor
|
| 73 |
+
self._all_executor = all_executor
|
| 74 |
+
self._builtin_redirector = builtin_redirector
|
| 75 |
+
|
| 76 |
+
@staticmethod
|
| 77 |
+
def convert_to_list(x):
|
| 78 |
+
"""This function is used to convert x to a list.
|
| 79 |
+
If x is None, return an empty list.
|
| 80 |
+
If x is not a list, return a list containing x.
|
| 81 |
+
Otherwise, return x itself.
|
| 82 |
+
"""
|
| 83 |
+
if x is None:
|
| 84 |
+
return []
|
| 85 |
+
if not isinstance(x, list):
|
| 86 |
+
return [x]
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def converge_ret_val(res):
|
| 91 |
+
"""This function is used to converge res (the return value) of the function.
|
| 92 |
+
If res is None, return None.
|
| 93 |
+
If res is a list and has only one element, return the element.
|
| 94 |
+
Otherwise, return res itself.
|
| 95 |
+
"""
|
| 96 |
+
if res is None:
|
| 97 |
+
return res
|
| 98 |
+
elif isinstance(res, list) and len(res) == 1:
|
| 99 |
+
return res[0]
|
| 100 |
+
return res
|
| 101 |
+
|
| 102 |
+
def for_execute(
|
| 103 |
+
self,
|
| 104 |
+
func,
|
| 105 |
+
start,
|
| 106 |
+
stop,
|
| 107 |
+
step,
|
| 108 |
+
write_args=[],
|
| 109 |
+
full_write_args_count=0,
|
| 110 |
+
write_args_names=[],
|
| 111 |
+
unroll=-1,
|
| 112 |
+
unroll_full=False,
|
| 113 |
+
prefetch_stages=None,
|
| 114 |
+
):
|
| 115 |
+
assert (
|
| 116 |
+
self._loop_execute_range_dynamic
|
| 117 |
+
), "Functions must be set before execution."
|
| 118 |
+
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
|
| 119 |
+
|
| 120 |
+
return self._loop_execute_range_dynamic(
|
| 121 |
+
func,
|
| 122 |
+
start,
|
| 123 |
+
stop,
|
| 124 |
+
step,
|
| 125 |
+
write_args,
|
| 126 |
+
full_write_args_count,
|
| 127 |
+
write_args_names,
|
| 128 |
+
unroll,
|
| 129 |
+
unroll_full,
|
| 130 |
+
prefetch_stages,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def if_execute(
|
| 134 |
+
self,
|
| 135 |
+
pred,
|
| 136 |
+
then_block: Callable,
|
| 137 |
+
else_block: Optional[Callable] = None,
|
| 138 |
+
write_args=[],
|
| 139 |
+
full_write_args_count=0,
|
| 140 |
+
write_args_names=[],
|
| 141 |
+
):
|
| 142 |
+
assert self._if_dynamic, "Functions must be set before execution."
|
| 143 |
+
|
| 144 |
+
# MLIR generation
|
| 145 |
+
return self._if_dynamic(
|
| 146 |
+
pred,
|
| 147 |
+
then_block,
|
| 148 |
+
else_block,
|
| 149 |
+
write_args,
|
| 150 |
+
full_write_args_count,
|
| 151 |
+
write_args_names,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def while_execute(
|
| 155 |
+
self,
|
| 156 |
+
pred,
|
| 157 |
+
while_before_block: Callable,
|
| 158 |
+
while_after_block: Callable,
|
| 159 |
+
write_args=[],
|
| 160 |
+
full_write_args_count=0,
|
| 161 |
+
write_args_names=[],
|
| 162 |
+
):
|
| 163 |
+
assert self._while_dynamic, "Functions must be set before execution."
|
| 164 |
+
|
| 165 |
+
# MLIR generation
|
| 166 |
+
return self._while_dynamic(
|
| 167 |
+
while_before_block,
|
| 168 |
+
while_after_block,
|
| 169 |
+
write_args,
|
| 170 |
+
full_write_args_count,
|
| 171 |
+
write_args_names,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# =============================================================================
|
| 176 |
+
# Decorator
|
| 177 |
+
# =============================================================================
|
| 178 |
+
|
| 179 |
+
executor = Executor()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def loop_selector(
|
| 183 |
+
start,
|
| 184 |
+
stop,
|
| 185 |
+
step,
|
| 186 |
+
*,
|
| 187 |
+
write_args=[],
|
| 188 |
+
full_write_args_count=0,
|
| 189 |
+
write_args_names=[],
|
| 190 |
+
unroll=-1,
|
| 191 |
+
unroll_full=False,
|
| 192 |
+
prefetch_stages=None,
|
| 193 |
+
):
|
| 194 |
+
log().debug(
|
| 195 |
+
"start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
|
| 196 |
+
start,
|
| 197 |
+
stop,
|
| 198 |
+
step,
|
| 199 |
+
write_args,
|
| 200 |
+
full_write_args_count,
|
| 201 |
+
write_args_names,
|
| 202 |
+
unroll,
|
| 203 |
+
unroll_full,
|
| 204 |
+
prefetch_stages,
|
| 205 |
+
)
|
| 206 |
+
from .typing import Integer, Numeric
|
| 207 |
+
|
| 208 |
+
def _maybe_upcast(value):
|
| 209 |
+
if isinstance(value, Integer):
|
| 210 |
+
value = value.ir_value()
|
| 211 |
+
|
| 212 |
+
return value
|
| 213 |
+
|
| 214 |
+
start = _maybe_upcast(start)
|
| 215 |
+
stop = _maybe_upcast(stop)
|
| 216 |
+
step = _maybe_upcast(step)
|
| 217 |
+
|
| 218 |
+
def ir_loop(func):
|
| 219 |
+
return executor.for_execute(
|
| 220 |
+
func,
|
| 221 |
+
start,
|
| 222 |
+
stop,
|
| 223 |
+
step,
|
| 224 |
+
write_args,
|
| 225 |
+
full_write_args_count,
|
| 226 |
+
write_args_names,
|
| 227 |
+
unroll,
|
| 228 |
+
unroll_full,
|
| 229 |
+
prefetch_stages,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
return ir_loop
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def if_selector(pred, write_args=[]):
|
| 236 |
+
log().debug("pred [%s] write_args [%s]", pred, write_args)
|
| 237 |
+
# Handle Numeric types here?
|
| 238 |
+
|
| 239 |
+
from .typing import Numeric
|
| 240 |
+
|
| 241 |
+
if isinstance(pred, Numeric):
|
| 242 |
+
pred = pred.value
|
| 243 |
+
|
| 244 |
+
def ir_loop(func):
|
| 245 |
+
return func(pred, *write_args)
|
| 246 |
+
|
| 247 |
+
return ir_loop
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def while_selector(pred, write_args=[]):
|
| 251 |
+
def ir_while_loop(func):
|
| 252 |
+
return func(pred, *write_args)
|
| 253 |
+
|
| 254 |
+
return ir_while_loop
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def while_executor(
|
| 258 |
+
pred,
|
| 259 |
+
while_before_block: Callable,
|
| 260 |
+
while_after_block: Callable,
|
| 261 |
+
write_args=[],
|
| 262 |
+
full_write_args_count=0,
|
| 263 |
+
write_args_names=[],
|
| 264 |
+
):
|
| 265 |
+
return executor.while_execute(
|
| 266 |
+
pred,
|
| 267 |
+
while_before_block,
|
| 268 |
+
while_after_block,
|
| 269 |
+
write_args,
|
| 270 |
+
full_write_args_count,
|
| 271 |
+
write_args_names,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def if_executor(
|
| 276 |
+
pred,
|
| 277 |
+
then_block: Callable,
|
| 278 |
+
else_block: Optional[Callable] = None,
|
| 279 |
+
write_args=[],
|
| 280 |
+
full_write_args_count=0,
|
| 281 |
+
write_args_names=[],
|
| 282 |
+
):
|
| 283 |
+
return executor.if_execute(
|
| 284 |
+
pred,
|
| 285 |
+
then_block,
|
| 286 |
+
else_block,
|
| 287 |
+
write_args,
|
| 288 |
+
full_write_args_count,
|
| 289 |
+
write_args_names,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# =============================================================================
|
| 294 |
+
# Range
|
| 295 |
+
# =============================================================================
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class range:
|
| 299 |
+
"""
|
| 300 |
+
A range-like object for dynamic loop iteration in the DSL.
|
| 301 |
+
|
| 302 |
+
This class provides a range interface similar to Python's built-in range,
|
| 303 |
+
but is designed to be preprocessed into constructs for dynamic
|
| 304 |
+
loop execution.
|
| 305 |
+
|
| 306 |
+
The class supports both single-argument (stop) and three-argument
|
| 307 |
+
(start, stop, step) constructors with additional parameters for loop
|
| 308 |
+
optimization:
|
| 309 |
+
|
| 310 |
+
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
|
| 311 |
+
- unroll_full: Whether to fully unroll the loop
|
| 312 |
+
- prefetch_stages: Number of prefetch stages to generate
|
| 313 |
+
"""
|
| 314 |
+
|
| 315 |
+
@overload
|
| 316 |
+
def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None):
|
| 317 |
+
pass
|
| 318 |
+
|
| 319 |
+
@overload
|
| 320 |
+
def __new__(
|
| 321 |
+
cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
|
| 322 |
+
):
|
| 323 |
+
pass
|
| 324 |
+
|
| 325 |
+
def __new__(cls, *args, **kwargs):
|
| 326 |
+
raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
|
| 327 |
+
|
| 328 |
+
def __iter__(self) -> Iterator[int]:
|
| 329 |
+
raise DSLRuntimeError("dynamic range should be always preprocessed to IR")
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@deprecated(
|
| 333 |
+
"range_dynamic is deprecated and will be removed in the future, please remove it."
|
| 334 |
+
)
|
| 335 |
+
def range_dynamic(*args, **kwargs):
|
| 336 |
+
raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def range_constexpr(*args):
|
| 340 |
+
raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# =============================================================================
|
| 344 |
+
# If expressions
|
| 345 |
+
# =============================================================================
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def const_expr(expression):
|
| 349 |
+
"""
|
| 350 |
+
This function is used to check if the expression is a python value.
|
| 351 |
+
If the expression is a python value, return the boolean value of the expression.
|
| 352 |
+
If the expression is a dynamic expression, raise an error.
|
| 353 |
+
"""
|
| 354 |
+
from .typing import Numeric
|
| 355 |
+
|
| 356 |
+
failed = False
|
| 357 |
+
|
| 358 |
+
if isinstance(expression, Numeric):
|
| 359 |
+
if isinstance(expression.value, (int, float, bool)):
|
| 360 |
+
return expression.value
|
| 361 |
+
else:
|
| 362 |
+
failed = True
|
| 363 |
+
elif executor._is_dynamic_expression(expression):
|
| 364 |
+
failed = True
|
| 365 |
+
|
| 366 |
+
if failed:
|
| 367 |
+
raise DSLRuntimeError(
|
| 368 |
+
f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
|
| 369 |
+
context={
|
| 370 |
+
"If your expression depends on dynamic values": "Remove `const_expr()`",
|
| 371 |
+
},
|
| 372 |
+
)
|
| 373 |
+
return expression
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@deprecated(
|
| 377 |
+
"dynamic_expr is deprecated and will be removed in the future, please remove it."
|
| 378 |
+
)
|
| 379 |
+
def dynamic_expr(expression):
|
| 380 |
+
return expression
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# =============================================================================
|
| 384 |
+
# Assertion & casting
|
| 385 |
+
# =============================================================================
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def assert_executor(test, msg=None):
|
| 389 |
+
from .typing import Numeric
|
| 390 |
+
|
| 391 |
+
fail = False
|
| 392 |
+
# Implicit convert dynamic expression to bool is not allowed
|
| 393 |
+
# So here explicitly do a None check
|
| 394 |
+
if test is not None and executor._is_dynamic_expression(test):
|
| 395 |
+
if isinstance(test, Numeric):
|
| 396 |
+
try:
|
| 397 |
+
test = test.to(bool)
|
| 398 |
+
except:
|
| 399 |
+
fail = True
|
| 400 |
+
else:
|
| 401 |
+
fail = True
|
| 402 |
+
|
| 403 |
+
if not fail:
|
| 404 |
+
assert test, msg
|
| 405 |
+
else:
|
| 406 |
+
raise DSLRuntimeError(
|
| 407 |
+
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
|
| 408 |
+
suggestion="Please replace with runtime assert.",
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def bool_cast(value):
|
| 413 |
+
if executor._is_dynamic_expression(value):
|
| 414 |
+
raise DSLRuntimeError(
|
| 415 |
+
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
|
| 416 |
+
suggestion="Please explicitly convert to boolean with expressions like comparision.",
|
| 417 |
+
)
|
| 418 |
+
return bool(value)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def compare_executor(left, comparators, ops):
|
| 422 |
+
"""
|
| 423 |
+
Executes comparison operations with a left operand and a list of comparators.
|
| 424 |
+
|
| 425 |
+
Args:
|
| 426 |
+
left: The leftmost value in the comparison chain
|
| 427 |
+
comparators: A list of values to compare against
|
| 428 |
+
ops: A list of comparison operators to apply
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
The result of the comparison chain
|
| 432 |
+
|
| 433 |
+
Raises:
|
| 434 |
+
AssertionError: If the executor function is not set before execution
|
| 435 |
+
"""
|
| 436 |
+
assert (
|
| 437 |
+
executor._compare_executor is not None
|
| 438 |
+
), "Function must be set before execution."
|
| 439 |
+
return executor._compare_executor(left, comparators, ops)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def any_executor(iterable):
|
| 443 |
+
"""Executes the 'any' operation on an iterable, handling both dynamic and static expressions.
|
| 444 |
+
|
| 445 |
+
:param iterable: An iterable to check if any elements evaluate to True
|
| 446 |
+
:type iterable: Iterable
|
| 447 |
+
:return: boolean of Python value or IR value
|
| 448 |
+
:rtype: bool or cutlass.Boolean
|
| 449 |
+
|
| 450 |
+
"""
|
| 451 |
+
if executor._any_executor and executor._is_dynamic_expression(iterable):
|
| 452 |
+
return executor._any_executor(iterable)
|
| 453 |
+
else:
|
| 454 |
+
return any(iterable)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def all_executor(iterable):
|
| 458 |
+
"""Executes the 'all' operation on an iterable, handling both dynamic and static expressions.
|
| 459 |
+
|
| 460 |
+
:param iterable: An iterable to check if all elements evaluate to True
|
| 461 |
+
:type iterable: Iterable
|
| 462 |
+
:return: boolean of Python value or IR value
|
| 463 |
+
:rtype: bool or cutlass.Boolean
|
| 464 |
+
"""
|
| 465 |
+
if executor._all_executor and executor._is_dynamic_expression(iterable):
|
| 466 |
+
return executor._all_executor(iterable)
|
| 467 |
+
else:
|
| 468 |
+
return all(iterable)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# =============================================================================
|
| 472 |
+
# Control flow checks
|
| 473 |
+
# =============================================================================
|
| 474 |
+
class DSLOptimizationWarning(Warning):
|
| 475 |
+
"""
|
| 476 |
+
This warning is used to warn the user about the optimization related issues in DSL.
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
def __init__(self, message):
|
| 480 |
+
self.message = message
|
| 481 |
+
super().__init__()
|
| 482 |
+
|
| 483 |
+
def __str__(self):
|
| 484 |
+
return self.message
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def range_value_check(*args):
|
| 488 |
+
"""
|
| 489 |
+
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
|
| 490 |
+
"""
|
| 491 |
+
try:
|
| 492 |
+
args = tuple(arg.__index__() for arg in args)
|
| 493 |
+
|
| 494 |
+
# Compute range size and warn if it's too large
|
| 495 |
+
start = 0
|
| 496 |
+
end = 0
|
| 497 |
+
step = 1
|
| 498 |
+
if len(args) == 1:
|
| 499 |
+
end = args[0]
|
| 500 |
+
elif len(args) == 2:
|
| 501 |
+
start = args[0]
|
| 502 |
+
end = args[1]
|
| 503 |
+
elif len(args) == 3:
|
| 504 |
+
start = args[0]
|
| 505 |
+
end = args[1]
|
| 506 |
+
step = args[2]
|
| 507 |
+
|
| 508 |
+
range_length = (abs(end - start) - 1) // abs(step) + 1
|
| 509 |
+
if range_length >= 64:
|
| 510 |
+
warnings.warn(
|
| 511 |
+
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
|
| 512 |
+
category=DSLOptimizationWarning,
|
| 513 |
+
stacklevel=2,
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
return (start, end, step)
|
| 517 |
+
except:
|
| 518 |
+
raise DSLRuntimeError(
|
| 519 |
+
"`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
|
| 520 |
+
suggestion="Use `range` instead of `range_constexpr`.",
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def range_perf_warning(filename, lineno, *args):
|
| 525 |
+
has_dynamic_expr = False
|
| 526 |
+
for arg in args:
|
| 527 |
+
if executor._is_dynamic_expression(arg):
|
| 528 |
+
has_dynamic_expr = True
|
| 529 |
+
break
|
| 530 |
+
if not has_dynamic_expr:
|
| 531 |
+
warnings.warn_explicit(
|
| 532 |
+
(
|
| 533 |
+
"This loop is no longer unrolled and may cause performance regression. "
|
| 534 |
+
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
|
| 535 |
+
),
|
| 536 |
+
category=DSLOptimizationWarning,
|
| 537 |
+
filename=filename,
|
| 538 |
+
lineno=lineno,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
@lru_cache(maxsize=1)
|
| 543 |
+
def _get_self_module():
|
| 544 |
+
"""
|
| 545 |
+
This function is used to get the owning module of this function.
|
| 546 |
+
"""
|
| 547 |
+
return inspect.getmodule(_get_self_module)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def cf_symbol_check(symbol):
    """
    Check if the symbol is control flow symbol from current module.

    Raises DSLRuntimeError when `symbol` does not originate from this DSL
    module, i.e. the user shadowed a DSL control-flow name (`range`, etc.)
    with their own definition.

    :param symbol: a function, or a module standing in for `range`
    :raises DSLRuntimeError: when the symbol comes from somewhere else
    """
    # Fix: the original assigned a `name` local in two places but never read
    # it; the dead assignments are removed.
    failed = False
    self_module = _get_self_module()
    if inspect.ismodule(symbol):
        # A module symbol stands in for `range`; it must be this module's
        # package (or an ancestor of it).
        if not self_module.__name__.startswith(symbol.__name__):
            failed = True
    else:
        owning_module = inspect.getmodule(symbol)
        if owning_module != self_module:
            failed = True

    if failed:
        raise DSLRuntimeError(
            f"Incorrect {symbol.__name__} is used.",
            suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
        )
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def redirect_builtin_function(fcn):
    """
    This function is used to redirect built-in function call
    to the function defined in DSL package.

    :param fcn: candidate callable
    :return: the redirected callable when `fcn` is a built-in and a
        redirector is registered on `executor`; otherwise `fcn` unchanged
    """
    # Only redirect if it's a built-in
    if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
        return executor._builtin_redirector(fcn)
    return fcn
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def copy_members(dest, src):
    """
    Copies all non-callable, non-dunder members from src to dest if they exist in src.
    Skips members that are callables or have names starting with double underscores.
    """
    # Copying an object onto itself is a no-op.
    if dest is src:
        return

    for name, value in getmembers(dest):
        skip = (
            name.startswith("__")
            or isinstance(value, Callable)
            or not hasattr(src, name)
        )
        if not skip:
            setattr(dest, name, getattr(src, name))
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def get_locals_or_none(locals, symbols):
    """
    Given a locals() dictionary and a list of symbol names, return a list of their values
    in the same order as the symbols list. If a symbol is not present in locals, None is returned
    for that symbol.
    """
    # dict.get defaults to None for missing keys, which is exactly the
    # contract described above.
    return [locals.get(symbol) for symbol in symbols]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/ast_preprocessor.py
ADDED
|
@@ -0,0 +1,1958 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module defines the `DSLPreprocessor` class, which acts as a Python preprocessor.
|
| 14 |
+
It uses Python's AST and rewrites specific Python statements such as `for` and `if-else`.
|
| 15 |
+
|
| 16 |
+
The preprocessor operates on the following constructs:
|
| 17 |
+
- `for` loops:
|
| 18 |
+
- Rewrites `for` loops with the `@loop_selector` decorator.
|
| 19 |
+
- Supports `range`, `range_dynamic` for loop iteration.
|
| 20 |
+
- `if-elif-else` statements:
|
| 21 |
+
- Rewrites conditional statements with the `@if_selector` decorator.
|
| 22 |
+
- Supports `dynamic_expr` and `const_expr` in the condition expressions.
|
| 23 |
+
|
| 24 |
+
Additionally, both `for` loops and `if-else` statements require `yield`
|
| 25 |
+
operation generation. The preprocessor handles this by:
|
| 26 |
+
- Using a `ScopeManager` to track symbols across different scopes during AST traversal.
|
| 27 |
+
- Identifying read-only, read-write, and active variables for DSL constructs.
|
| 28 |
+
- Generating `yield` operations for symbols that are classified as read-write or write.
|
| 29 |
+
|
| 30 |
+
It is designed to be generic and can handle `for` and `if` constructs from other dialects.
|
| 31 |
+
In such cases, the user's DSL should implement `@loop_selector` and `@if_selector`
|
| 32 |
+
to generate dialect-specific operations for `for` and `if` statements.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import ast
|
| 36 |
+
import importlib
|
| 37 |
+
import inspect
|
| 38 |
+
import textwrap
|
| 39 |
+
import warnings
|
| 40 |
+
from dataclasses import dataclass
|
| 41 |
+
from typing import List, Set, Dict, Any, Callable, Optional
|
| 42 |
+
from types import ModuleType
|
| 43 |
+
from collections import OrderedDict
|
| 44 |
+
from copy import deepcopy
|
| 45 |
+
|
| 46 |
+
from .common import *
|
| 47 |
+
from .utils.logger import log
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class OrderedSet:
    """
    A deterministic set implementation for ordered operations.

    Backed by a dict (whose keys preserve insertion order) so that iteration
    order — and therefore any IR generated from it — is reproducible.
    """

    def __init__(self, iterable=None):
        # Keys are the set elements; the values are always None.
        self._dict = dict.fromkeys(iterable or [])

    def add(self, item):
        self._dict[item] = None

    def __contains__(self, item):
        # Fix: without __contains__, `x in ordered_set` fell back to a linear
        # scan over __iter__; delegate to the dict for O(1) membership.
        return item in self._dict

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __and__(self, other):
        # Intersection, keeping this set's insertion order.
        return OrderedSet(key for key in self._dict if key in other)

    def __or__(self, other):
        # Union: this set's elements first, then new elements from `other`.
        new_dict = self._dict.copy()
        new_dict.update(dict.fromkeys(other))
        return OrderedSet(new_dict)

    def __sub__(self, other):
        # Difference, keeping this set's insertion order.
        return OrderedSet(key for key in self._dict if key not in other)

    def intersections(self, others):
        """Compute the intersection of this set with multiple other sets.

        :param others: A list of sets to compute intersections with
        :type others: List[Set[str]]
        :return: A new ordered set containing elements that appear in this set
            and at least one of the other sets
        """
        result = OrderedSet()
        for key in self._dict:
            for other in reversed(others):
                if key in other:
                    result.add(key)
                    break
        return result
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
class ImportInfo:
    """
    Information about an import expression.
    """
    # Dotted path of the imported module, e.g. "pkg.sub".
    module_path: str
    # Name pulled from the module ("from pkg import attr"), "*" for
    # star-imports, or None for a plain "import pkg".
    attr_name: Optional[str]
    # Name the import is bound to in the importing namespace (asname if
    # given, otherwise the original name).
    alias_name: str
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@dataclass
class ScopeManager:
    """
    Manages symbol scopes during AST traversal.
    Manage nested scopes during transformations.

    Used as a context manager: entering pushes a fresh innermost scope,
    exiting pops it.
    """

    scopes: List[Set[str]]

    @classmethod
    def create(cls) -> "ScopeManager":
        # Start with an empty scope stack; callers push scopes via `with`.
        return cls(scopes=[])

    def add_to_scope(self, name: str) -> None:
        # "_" is the conventional throwaway name and is never tracked.
        if name != "_":
            self.scopes[-1].add(name)

    def get_active_symbols(self) -> List[Set[str]]:
        # Shallow copy so callers can't accidentally mutate the stack itself.
        return list(self.scopes)

    def __enter__(self) -> "ScopeManager":
        self.scopes.append(set())
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.scopes.pop()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class DSLPreprocessor(ast.NodeTransformer):
|
| 132 |
+
"""
|
| 133 |
+
A preprocessor for transforming Python ASTs. It supports:
|
| 134 |
+
|
| 135 |
+
- Rewriting `for` loops with the `@loop_selector` decorator.
|
| 136 |
+
- Rewriting `if-elif-else` statements with the `@if_selector` decorator.
|
| 137 |
+
- Generating `yield` operations for read-write or write symbols.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
    # Names of the selector decorators the preprocessor wraps around
    # rewritten `for` / `if` / `while` statements.
    DECORATOR_FOR_STATEMENT = "loop_selector"
    DECORATOR_IF_STATEMENT = "if_selector"
    DECORATOR_WHILE_STATEMENT = "while_selector"
    # Runtime helper names the rewritten AST calls into.
    IF_EXECUTOR = "if_executor"
    WHILE_EXECUTOR = "while_executor"
    ASSERT_EXECUTOR = "assert_executor"
    BOOL_CAST = "bool_cast"
    IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType"
    # Range builtins whose loops the preprocessor recognizes and may rewrite.
    SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"}
    COMPARE_EXECUTOR = "compare_executor"
    ANY_EXECUTOR = "any_executor"
    ALL_EXECUTOR = "all_executor"
|
| 152 |
+
|
| 153 |
+
    def __init__(self, client_module_name):
        """Create a preprocessor for functions from the given client module.

        :param client_module_name: components of the client's top-level module
            name; joined with "." when emitting import statements.
        """
        super().__init__()
        self.counter = 0  # Unique function names for multiple loops
        self.scope_manager = ScopeManager.create()
        # Functions already transformed, to avoid re-processing them.
        self.processed_functions = set()
        self.function_counter = 0
        # Current location info, used in error messages.
        self.function_name = "<unknown function>"
        self.class_name = None
        self.file_name = "<unknown filename>"
        self.function_depth = 0
        # Names of closures defined locally inside the visited function.
        self.local_closures = set()
        # Globals of the function being transformed; set only during transform().
        self.function_globals = None
        self.client_module_name = client_module_name
        # Whether transformed code must also import the top-level client module.
        self.import_top_module = False
|
| 167 |
+
|
| 168 |
+
def _create_module_attribute(
|
| 169 |
+
self,
|
| 170 |
+
func_name,
|
| 171 |
+
*,
|
| 172 |
+
top_module_name="_dsl_",
|
| 173 |
+
submodule_name="ast_helpers",
|
| 174 |
+
lineno=None,
|
| 175 |
+
col_offset=None,
|
| 176 |
+
):
|
| 177 |
+
# If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong.
|
| 178 |
+
def set_location(node, lineno, col_offset):
|
| 179 |
+
if lineno and col_offset:
|
| 180 |
+
node.lineno = lineno
|
| 181 |
+
node.end_lineno = lineno
|
| 182 |
+
node.col_offset = col_offset
|
| 183 |
+
node.end_col_offset = col_offset
|
| 184 |
+
|
| 185 |
+
base = ast.Name(id=top_module_name, ctx=ast.Load())
|
| 186 |
+
set_location(base, lineno, col_offset)
|
| 187 |
+
if submodule_name:
|
| 188 |
+
base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load())
|
| 189 |
+
set_location(base, lineno, col_offset)
|
| 190 |
+
node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load())
|
| 191 |
+
set_location(node, lineno, col_offset)
|
| 192 |
+
return node
|
| 193 |
+
|
| 194 |
+
    def _get_module_imports(self, decorated_func):
        """Extract imports from the module containing the decorated function.

        Parses the module's source and collects every `import` /
        `from ... import ...` as an ImportInfo, resolving explicit-relative
        imports against the module's package.

        :param decorated_func: function whose home module is inspected
        :return: list of ImportInfo; empty when the module (or its source)
            cannot be obtained
        """
        imports = []

        # Get the module containing the decorated function
        if module := inspect.getmodule(decorated_func):
            try:
                # Get the module source code
                source = inspect.getsource(module)
                module_ast = ast.parse(source)

                # Extract imports from the full module
                # Bound name in the importing namespace: asname wins.
                alias = lambda n: n.asname if n.asname else n.name
                for node in ast.walk(module_ast):
                    if isinstance(node, ast.Import):
                        for name in node.names:
                            imports.append(
                                ImportInfo(
                                    module_path=name.name,
                                    attr_name=None,
                                    alias_name=alias(name),
                                )
                            )
                    elif isinstance(node, ast.ImportFrom):
                        module_name = node.module
                        if node.level > 0:
                            # Handle relative imports: climb `level - 1`
                            # components up from the module's package, then
                            # re-prefix the relative module path.
                            package_name = module.__package__.rsplit(
                                ".", node.level - 1
                            )[0]
                            module_name = f"{package_name}.{module_name}"
                        for name in node.names:
                            imports.append(
                                ImportInfo(
                                    module_path=module_name,
                                    attr_name=name.name,
                                    alias_name=alias(name),
                                )
                            )
            except (IOError, TypeError):
                # Source unavailable (e.g. built-in/frozen module): best
                # effort — return whatever was collected so far.
                pass

        return imports
|
| 237 |
+
|
| 238 |
+
    def exec(self, function_name, original_function, code_object, exec_globals):
        """Execute transformed code and return the rebuilt function object.

        First re-creates the original module's import bindings inside
        `exec_globals`, so that names referenced by the transformed body
        resolve the same way they did in the function's home module, then
        exec's `code_object` and pulls `function_name` out of the globals.

        :param function_name: name of the function to retrieve after exec
        :param original_function: the untransformed function (its module's
            imports are replayed)
        :param code_object: compiled transformed code
        :param exec_globals: globals dict to execute in (mutated in place)
        :raises ImportError: when replaying one of the module's imports fails
        :return: the function bound to `function_name`, or None if absent
        """
        # Get imports from the original module
        module_imports = self._get_module_imports(original_function)

        # Import all required modules
        for import_info in module_imports:
            module_path, attr_name, alias_name = (
                import_info.module_path,
                import_info.attr_name,
                import_info.alias_name,
            )
            try:
                module = importlib.import_module(module_path)
                if attr_name:
                    if attr_name == "*":
                        # Mirror star-import semantics: honor __all__ when
                        # present, otherwise take all public names.
                        if hasattr(module, "__all__"):
                            attrs = module.__all__
                        else:
                            attrs = [
                                name for name in dir(module) if not name.startswith("_")
                            ]
                    else:
                        attrs = [attr_name]

                    for attr in attrs:
                        # Star-imported names keep their own name; a single
                        # attr uses its (possibly aliased) binding.
                        alias = attr if attr_name == "*" else alias_name
                        exec_globals[alias] = getattr(module, attr)
                else:
                    exec_globals[alias_name] = module
            except (ImportError, AttributeError) as e:
                raise ImportError(f"Failed to import {module_path}: {str(e)}")

        # Execute the transformed code
        log().info(
            "ASTPreprocessor Executing transformed code for function [%s]",
            function_name,
        )
        exec(code_object, exec_globals)
        return exec_globals.get(function_name)
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def print_ast(transformed_tree=None):
|
| 280 |
+
print("#", "-" * 40, "Transformed AST", "-" * 40)
|
| 281 |
+
unparsed_code = ast.unparse(transformed_tree)
|
| 282 |
+
print(unparsed_code)
|
| 283 |
+
print("#", "-" * 40, "End Transformed AST", "-" * 40)
|
| 284 |
+
|
| 285 |
+
def make_func_param_name(self, base_name, used_names):
|
| 286 |
+
"""Generate a unique parameter name that doesn't collide with existing names."""
|
| 287 |
+
if base_name not in used_names:
|
| 288 |
+
return base_name
|
| 289 |
+
|
| 290 |
+
i = 0
|
| 291 |
+
while f"{base_name}_{i}" in used_names:
|
| 292 |
+
i += 1
|
| 293 |
+
return f"{base_name}_{i}"
|
| 294 |
+
|
| 295 |
+
    def transform_function(self, func_name, function_pointer):
        """
        Transforms a function.

        Parses the function's source, rewrites it via the visitor, prepends
        the DSL import statements, and returns the statements making up the
        transformed module body. Returns [] when the function was already
        processed or does not carry the expected decorator.

        :param func_name: name used for logging
        :param function_pointer: the function object to transform
        :return: list of transformed top-level AST statements (possibly empty)
        """
        # Skip if the function has already been processed
        if function_pointer in self.processed_functions:
            log().info(
                "ASTPreprocessor Skipping already processed function [%s]", func_name
            )
            return []

        # Step 1. Parse the given function
        file_name = inspect.getsourcefile(function_pointer)
        lines, start_line = inspect.getsourcelines(function_pointer)
        # Dedent so a method/nested function parses as a top-level def.
        dedented_source = textwrap.dedent("".join(lines))
        tree = ast.parse(dedented_source, filename=file_name)
        # Bump the line numbers so they match the real source file
        ast.increment_lineno(tree, start_line - 1)

        # Step 1.2 Check the decorator
        if not self.check_decorator(tree.body[0]):
            log().info(
                "[%s] - Skipping function due to missing decorator",
                func_name,
            )
            return []

        self.processed_functions.add(function_pointer)
        log().info("ASTPreprocessor Transforming function [%s]", func_name)

        # Step 2. Transform the function
        transformed_tree = self.visit(tree)

        # Step 3. Prepend imports of the client module and base_dsl (the
        # latter always aliased to `_dsl_`, which generated code refers to).
        top_module_name = ".".join(self.client_module_name)
        import_stmts = []
        if self.import_top_module:
            import_stmts.append(ast.Import(names=[ast.alias(name=top_module_name)]))
        import_stmts.append(
            ast.Import(
                names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")]
            )
        )
        transformed_tree.body = import_stmts + transformed_tree.body

        # Step 4. Fill in source locations on the synthesized nodes
        ast.fix_missing_locations(transformed_tree)
        combined_body = transformed_tree.body

        # Step 5. Return the transformed tree
        return combined_body
|
| 346 |
+
|
| 347 |
+
    def check_early_exit(self, tree, kind):
        """
        Checks if a given region or scope in the provided Python code has early exits.

        Raises DSLAstPreprocessorError when the region contains a `return` or
        `raise` anywhere, or a `break`/`continue` directly in the region
        (nested loops keep their own break/continue) when `kind` is not "if".

        :param tree: AST node whose children are scanned (the node itself is
            not visited — see the generic_visit call below)
        :param kind: which construct is being checked ("if", "for", ...)
        :raises DSLAstPreprocessorError: when an early exit is found
        """

        class EarlyExitChecker(ast.NodeVisitor):
            def __init__(self, kind):
                self.has_early_exit = False
                self.early_exit_node = None
                self.early_exit_type = None
                self.kind = kind
                # Depth of nested loops below the region being checked.
                self.loop_nest_level = 0

            # Early exit is not allowed in any level of dynamic control flow
            def visit_Return(self, node):
                self.has_early_exit = True
                self.early_exit_node = node
                self.early_exit_type = "return"

            def visit_Raise(self, node):
                self.has_early_exit = True
                self.early_exit_node = node
                self.early_exit_type = "raise"

            def visit_Break(self, node):
                # For break/continue in inner loops, we don't consider it as early exit
                if self.loop_nest_level == 0 and self.kind != "if":
                    self.has_early_exit = True
                    self.early_exit_node = node
                    self.early_exit_type = "break"

            def visit_Continue(self, node):
                if self.loop_nest_level == 0 and self.kind != "if":
                    self.has_early_exit = True
                    self.early_exit_node = node
                    self.early_exit_type = "continue"

            def visit_For(self, node):
                self.loop_nest_level += 1
                self.generic_visit(node)
                self.loop_nest_level -= 1

            def visit_While(self, node):
                self.loop_nest_level += 1
                self.generic_visit(node)
                self.loop_nest_level -= 1

        checker = EarlyExitChecker(kind)
        # generic_visit (not visit) so the region's own For/While node does
        # not bump the loop-nest counter; only its children are scanned.
        checker.generic_visit(tree)
        if not checker.has_early_exit:
            return
        raise DSLAstPreprocessorError(
            message=f"Early exit ({checker.early_exit_type}) is not allowed in `{self.function_name}`"
            + (f" in `{self.class_name}`" if self.class_name else ""),
            filename=self.file_name,
            snippet=ast.unparse(tree),
            suggestion=(
                "If predicates are constant expression, write like "
                "`if const_expr(...)` or `for ... in range_constexpr(...)`. "
                "In that case, early exit will be executed by Python "
                "interpreter, so it's supported."
            ),
        )
|
| 410 |
+
|
| 411 |
+
def is_node_constexpr(self, node) -> bool:
|
| 412 |
+
"""
|
| 413 |
+
Determines if the node is a constexpr.
|
| 414 |
+
Supported nodes are if, while statements.
|
| 415 |
+
"""
|
| 416 |
+
if isinstance(node, ast.If) or isinstance(node, ast.While):
|
| 417 |
+
if isinstance(node.test, ast.Call):
|
| 418 |
+
func = node.test.func
|
| 419 |
+
|
| 420 |
+
if isinstance(func, ast.Attribute) and func.attr == "const_expr":
|
| 421 |
+
return True
|
| 422 |
+
|
| 423 |
+
elif isinstance(func, ast.Name) and func.id == "const_expr":
|
| 424 |
+
return True
|
| 425 |
+
return False
|
| 426 |
+
|
| 427 |
+
def _get_range_kind(self, iter_node):
|
| 428 |
+
"""
|
| 429 |
+
Return "range", "range_dynamic", "range_constexpr" or None for the iterable
|
| 430 |
+
"""
|
| 431 |
+
if isinstance(iter_node, ast.Call):
|
| 432 |
+
func = iter_node.func
|
| 433 |
+
if (
|
| 434 |
+
isinstance(func, ast.Name)
|
| 435 |
+
and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
|
| 436 |
+
):
|
| 437 |
+
return func.id, True, len(iter_node.keywords) != 0
|
| 438 |
+
if (
|
| 439 |
+
isinstance(func, ast.Attribute)
|
| 440 |
+
and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
|
| 441 |
+
):
|
| 442 |
+
return func.attr, False, len(iter_node.keywords) != 0
|
| 443 |
+
return None, None, None
|
| 444 |
+
|
| 445 |
+
def transform(self, original_function, exec_globals):
    """
    Transforms the provided function using the preprocessor.

    Runs the AST rewrite over *original_function* with *exec_globals* as the
    globals visible during the rewrite, and returns a location-fixed
    ``ast.Module`` containing the transformed definition.
    """
    # Record where the function lives so diagnostics can point at it.
    self.file_name = inspect.getsourcefile(original_function)
    self.function_globals = exec_globals
    rewritten_body = self.transform_function(
        original_function.__name__, original_function
    )
    # Drop the globals reference once rewriting is done.
    self.function_globals = None
    module = ast.Module(body=rewritten_body, type_ignores=[])
    # Fill in lineno/col_offset for synthesized nodes so compile() accepts it.
    return ast.fix_missing_locations(module)
|
| 459 |
+
|
| 460 |
+
def analyze_region_variables(
    self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]]
):
    """
    Analyze variables in different code regions to identify read-only, write-only,
    and active variables for DSL constructs.

    Returns a tuple ``(args, n_written)`` where *args* is the ordered list of
    written variables followed by variables whose methods were invoked (both
    restricted to names already active in the enclosing scopes), and
    *n_written* is the count of written variables within *args*.
    """

    # we need orderedset to keep the insertion order the same. otherwise generated IR is different each time
    write_args = OrderedSet()
    invoked_args = OrderedSet()
    local_closure = self.local_closures
    file_name = self.file_name
    region_node = node

    class RegionAnalyzer(ast.NodeVisitor):
        # When True, Name nodes are recorded as writes even if their ctx is Load
        # (used for assignment targets whose inner Names are loads, e.g. a[i] = x).
        force_store = False

        def visit_Name(self, node):
            """
            Mark every store as write.
            """
            if isinstance(node.ctx, ast.Store) or self.force_store:
                write_args.add(node.id)

        def visit_Subscript(self, node):
            # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`.
            # We need to force the store for the `Name` to be marked as write.
            if isinstance(node.ctx, ast.Store):
                self.force_store = True
                self.visit(node.value)
                self.force_store = False
                self.visit(node.slice)
            else:
                self.generic_visit(node)

        def visit_Assign(self, node):
            self.force_store = True
            [self.visit(target) for target in node.targets]
            self.force_store = False
            self.visit(node.value)

        def visit_AugAssign(self, node):
            self.force_store = True
            self.visit(node.target)
            self.force_store = False
            self.visit(node.value)

        @staticmethod
        def get_call_base(func_node):
            # Walk an attribute chain (a.b.c.f) down to the base Name `a`.
            if isinstance(func_node, ast.Attribute):
                # If the .value is another Attribute, keep digging
                if isinstance(func_node.value, ast.Attribute):
                    return RegionAnalyzer.get_call_base(func_node.value)
                # If the .value is a Name, that's our base
                elif isinstance(func_node.value, ast.Name):
                    return func_node.value.id
                else:
                    # Could be something else (lambda, call, etc.)
                    return None
            elif isinstance(func_node, ast.Name):
                return None
            return None

        @staticmethod
        def get_function_name(func_node: ast.Call):
            if isinstance(func_node.func, ast.Name):
                function_name = func_node.func.id
            # Check if it's a method or attribute call
            elif isinstance(func_node.func, ast.Attribute):
                function_name = func_node.func.attr
            else:
                function_name = None
            return function_name

        def visit_Call(self, node):
            base_name = RegionAnalyzer.get_call_base(node.func)

            if isinstance(node.func, ast.Name):
                func_name = node.func.id
                if func_name in local_closure:
                    raise DSLAstPreprocessorError(
                        f"Function `{func_name}` is a closure and is not supported in for/if statements",
                        filename=file_name,
                        snippet=ast.unparse(region_node),
                    )

            # Classes are mutable by default. Mark them as write. If they are
            # dataclass(frozen=True), treat them as read in runtime.
            # Bug fix: the original `base_name not in ("self")` tested substring
            # membership in the STRING "self" (excluding e.g. `s`, `el`, `elf`);
            # the intent is to exclude exactly the name "self".
            if base_name is not None and base_name != "self":
                invoked_args.add(base_name)

            self.generic_visit(node)

    analyzer = RegionAnalyzer()
    # Wrapping in a Module lets generic_visit walk `node` via the body field.
    analyzer.visit(ast.Module(body=node))

    # If arg is both write and invoke, remove from invoked_args
    invoked_args = invoked_args - write_args

    write_args = list(write_args.intersections(active_symbols))
    invoked_args = list(invoked_args.intersections(active_symbols))

    return write_args + invoked_args, len(write_args)
|
| 564 |
+
|
| 565 |
+
def extract_range_args(self, iter_node):
    """
    Normalize the positional arguments of a ``range(...)`` call into
    ``(start, stop, step, has_step)`` AST nodes, each passed through
    ``self.visit``. Missing start defaults to 0 and missing step to 1;
    *has_step* is True only when all three arguments were explicit.
    Raises DSLAstPreprocessorError for 0 or more than 3 arguments.
    """
    args = iter_node.args
    count = len(args)
    if count == 1:
        # range(stop) -> (0, stop, 1)
        start = self.visit(ast.Constant(value=0))
        stop = self.visit(args[0])
        step = self.visit(ast.Constant(value=1))
        return start, stop, step, False
    if count == 2:
        # range(start, stop) -> (start, stop, 1)
        start = self.visit(args[0])
        stop = self.visit(args[1])
        step = self.visit(ast.Constant(value=1))
        return start, stop, step, False
    if count == 3:
        return self.visit(args[0]), self.visit(args[1]), self.visit(args[2]), True
    raise DSLAstPreprocessorError(
        "Unsupported number of arguments in range", filename=self.file_name
    )
|
| 587 |
+
|
| 588 |
+
def extract_unroll_args(self, iter_node):
    """
    Return the ``unroll`` and ``unroll_full`` keyword value nodes of the
    range call, defaulting to ``Constant(-1)`` and ``Constant(False)``
    respectively when absent.
    """
    kw_map = {kw.arg: kw.value for kw in iter_node.keywords}
    unroll = kw_map.get("unroll", ast.Constant(value=-1))
    unroll_full = kw_map.get("unroll_full", ast.Constant(value=False))
    return unroll, unroll_full
|
| 594 |
+
|
| 595 |
+
def issue_deprecation_warning(self, *, message, category, filename, lineno):
    """
    Emit *message* as an explicit warning of *category* attributed to
    *filename*:*lineno*, bypassing dedup filtering so it always shows.
    """
    # Temporarily force the warning through regardless of prior filters.
    warnings.simplefilter("always", category)
    warnings.warn_explicit(
        message, category=category, filename=filename, lineno=lineno
    )
    # Restore the default filtering behavior for this category.
    warnings.simplefilter("default", category)
|
| 601 |
+
|
| 602 |
+
def extract_prefetch_stages_args(self, iter_node):
    """
    Return the AST node giving the loop prefetch depth.

    Prefers the ``prefetch_stages`` keyword; the legacy ``pipelining``
    keyword is still honored but triggers a deprecation warning. Returns
    ``Constant(None)`` when neither keyword is present.
    """
    kw_map = {kw.arg: kw.value for kw in iter_node.keywords}
    if "pipelining" in kw_map:
        # Legacy spelling: warn, then honor it.
        self.issue_deprecation_warning(
            message="pipelining is deprecated, use prefetch_stages instead",
            category=DeprecationWarning,
            filename=self.file_name,
            lineno=iter_node.lineno,
        )
        return kw_map.get("pipelining", ast.Constant(value=None))
    return kw_map.get("prefetch_stages", ast.Constant(value=None))
|
| 613 |
+
|
| 614 |
+
def create_loop_function(
    self,
    func_name,
    node,
    start,
    stop,
    step,
    unroll,
    unroll_full,
    prefetch_stages,
    write_args,
    full_write_args_count,
):
    """
    Creates a loop body function with the `loop_selector` decorator.

    The original `for` body becomes the body of a synthesized function named
    *func_name* whose first parameter is the loop induction variable and whose
    remaining parameters are the loop-carried *write_args*. The function
    returns the (possibly updated) write_args so the caller can thread them
    through iterations. Loop bounds and tuning knobs (unroll/unroll_full/
    prefetch_stages) are passed to the decorator as arguments/keywords.
    """

    # Parameters: induction variable first, then each loop-carried variable.
    func_args = [ast.arg(arg=node.target.id, annotation=None)]
    func_args += [ast.arg(arg=var, annotation=None) for var in write_args]

    # Create the loop body
    transformed_body = []
    for stmt in node.body:
        transformed_stmt = self.visit(stmt)  # Recursively visit inner statements
        # A visitor may expand one statement into several; flatten in order.
        if isinstance(transformed_stmt, list):
            transformed_body.extend(transformed_stmt)
        else:
            transformed_body.append(transformed_stmt)

    # Handle the return for a single iterated argument correctly
    if len(write_args) == 0:
        transformed_body.append(ast.Return())
    else:
        # Return the carried variables as a list so even a single variable
        # round-trips with a stable structure.
        transformed_body.append(
            ast.Return(
                value=ast.List(
                    elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
                    ctx=ast.Load(),
                )
            )
        )

    # Define the decorator with parameters
    decorator = ast.copy_location(
        ast.Call(
            func=self._create_module_attribute(
                self.DECORATOR_FOR_STATEMENT,
                lineno=node.lineno,
                col_offset=node.col_offset,
            ),
            args=[start, stop, step],
            keywords=[
                ast.keyword(arg="unroll", value=unroll),
                ast.keyword(arg="unroll_full", value=unroll_full),
                ast.keyword(arg="prefetch_stages", value=prefetch_stages),
                ast.keyword(
                    arg="write_args",
                    # Evaluated at runtime: snapshot the current local values
                    # of the carried variables (or None when undefined).
                    value=self.generate_get_locals_or_none_call(write_args),
                ),
                ast.keyword(
                    arg="full_write_args_count",
                    value=ast.Constant(value=full_write_args_count),
                ),
                ast.keyword(
                    arg="write_args_names",
                    value=ast.List(
                        elts=[ast.Constant(value=arg) for arg in write_args],
                        ctx=ast.Load(),
                    ),
                ),
            ],
        ),
        node,
    )

    # Synthesize `@<decorator>\ndef <func_name>(i, *write_args): ...`
    # carrying the original loop's source location.
    return ast.copy_location(
        ast.FunctionDef(
            name=func_name,
            args=ast.arguments(
                posonlyargs=[],
                args=func_args,
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=transformed_body,
            decorator_list=[decorator],
        ),
        node,
    )
|
| 704 |
+
|
| 705 |
+
def visit_BoolOp(self, node):
    """
    Rewrite `and`/`or` chains into calls to the DSL helpers `and_`/`or_`,
    wrapped in `IfExp`s that reproduce Python's short-circuit semantics for
    plain-bool operands.
    """
    # Visit child nodes first
    self.generic_visit(node)

    # It is necessary to expand short circuit evaluation explicit here
    # Although we do not support inline if-else for IR generation, this is actually evaluated in Python
    # So it's fine here
    # Transform "and" to "and_"
    if isinstance(node.op, ast.And):
        # Create an if-else statement in AST form
        # if type(lhs) == bool and lhs == False:
        #     return lhs
        # else
        #     return and_(lhs, rhs)
        short_circuit_value = ast.Constant(value=False)
        helper_func = self._create_module_attribute(
            "and_",
            top_module_name="cutlass",
            submodule_name=None,
            lineno=node.lineno,
            col_offset=node.col_offset,
        )
        self.import_top_module = True
    # Transform "or" to "or_"
    elif isinstance(node.op, ast.Or):
        # Create an if-else statement in AST form
        # if type(lhs) == bool and lhs == True:
        #     return lhs
        # else
        #     return or_(lhs, rhs)
        short_circuit_value = ast.Constant(value=True)
        helper_func = self._create_module_attribute(
            "or_",
            top_module_name="cutlass",
            submodule_name=None,
            lineno=node.lineno,
            col_offset=node.col_offset,
        )
        self.import_top_module = True
    else:
        # BoolOp should be either And or Or
        raise DSLAstPreprocessorError(
            f"Unsupported boolean operation: {node.op}",
            filename=self.file_name,
            snippet=ast.unparse(node),
        )

    def short_circuit_eval(value, short_circuit_value):
        # Build `type(value) == bool and value == <short_circuit_value>` —
        # true exactly when Python itself would short-circuit on `value`.
        return ast.BoolOp(
            op=ast.And(),
            values=[
                ast.Compare(
                    left=ast.Call(
                        func=ast.Name(id="type", ctx=ast.Load()),
                        args=[value],
                        keywords=[],
                    ),
                    ops=[ast.Eq()],
                    comparators=[ast.Name(id="bool", ctx=ast.Load())],
                ),
                ast.Compare(
                    left=value,
                    ops=[ast.Eq()],
                    comparators=[short_circuit_value],
                ),
            ],
        )

    lhs = node.values[0]

    # Fold the operand chain left-to-right:
    # lhs = (lhs if <short-circuits> else helper(lhs, rhs_i))
    for i in range(1, len(node.values)):
        test = short_circuit_eval(lhs, short_circuit_value)
        lhs = ast.IfExp(
            test=test,
            body=lhs,
            orelse=ast.Call(
                func=helper_func,
                args=[lhs, node.values[i]],
                keywords=[],
            ),
        )

    return ast.copy_location(lhs, node)
|
| 788 |
+
|
| 789 |
+
def visit_UnaryOp(self, node):
    """
    Rewrite ``not x`` into a call to the DSL helper ``cutlass.not_(x)``;
    all other unary operators are left untouched.
    """
    # Transform children before deciding whether to rewrite this node.
    self.generic_visit(node)

    if not isinstance(node.op, ast.Not):
        return node

    # Transform "not" to "not_" as we overload __invert__
    helper = self._create_module_attribute(
        "not_",
        top_module_name="cutlass",
        submodule_name=None,
        lineno=node.lineno,
        col_offset=node.col_offset,
    )
    self.import_top_module = True
    call = ast.Call(func=helper, args=[node.operand], keywords=[])
    return ast.copy_location(call, node)
|
| 808 |
+
|
| 809 |
+
def _insert_range_value_check(self, node):
    """
    Wrap the range arguments of *node* (a for-loop) in a runtime
    `range_value_check` call, i.e. rewrite the iterable into
    ``range(*range_value_check(<original args>))``.
    """
    checker = self._create_module_attribute(
        "range_value_check", lineno=node.lineno, col_offset=node.col_offset
    )
    check_call = ast.copy_location(
        ast.Call(func=checker, args=node.iter.args, keywords=[]),
        node.iter,
    )
    # Replace the iterable in place; the checker returns the validated args,
    # which are star-expanded back into the builtin range().
    wrapped = ast.Call(
        func=ast.Name(id="range", ctx=ast.Load()),
        args=[ast.Starred(value=check_call, ctx=ast.Load())],
        keywords=[],
    )
    node.iter = ast.copy_location(wrapped, node.iter)
|
| 832 |
+
|
| 833 |
+
def _insert_cf_symbol_check(self, func):
    """
    Build an expression statement calling ``cf_symbol_check`` on a deep copy
    of *func*, used to verify the control-flow symbol before the loop runs.
    """
    checker = self._create_module_attribute(
        "cf_symbol_check", lineno=func.lineno, col_offset=func.col_offset
    )
    # Deep-copy so the check sees the symbol as written, even if the
    # original node is rewritten afterwards.
    call = ast.copy_location(
        ast.Call(func=checker, args=[deepcopy(func)], keywords=[]),
        func,
    )
    return ast.Expr(call)
|
| 848 |
+
|
| 849 |
+
def visit_For(self, node):
    """
    Dispatch a `for` statement: constexpr/plain-Python loops are kept as
    real Python loops, while DSL `range` loops are lowered via
    `transform_for_loop` into a decorated loop-body function.
    """
    # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop.
    range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
    # Fix: compare to None with `is` (identity), not `==` (PEP 8 E711).
    if range_kind == "range_constexpr" or range_kind is None:
        self.generic_visit(node)
        if range_kind == "range_constexpr":
            check_call = self._insert_cf_symbol_check(node.iter.func)
            # Rewrite range_constexpr to range
            node.iter.func = ast.Name(id="range", ctx=ast.Load())
            self._insert_range_value_check(node)
            return [check_call, node]
        return node

    # Snapshot symbols visible before the loop so carried variables can be
    # identified by transform_for_loop.
    active_symbols = self.scope_manager.get_active_symbols()

    with self.scope_manager:
        if isinstance(node.target, ast.Name):
            self.scope_manager.add_to_scope(node.target.id)

        if range_kind == "range_dynamic":
            # Generate a warning
            self.issue_deprecation_warning(
                message="range_dynamic is deprecated and will be removed in the future, please remove it.",
                category=DeprecationWarning,
                filename=self.file_name,
                lineno=node.iter.lineno,
            )

        warning_call = None
        if range_kind == "range" and is_builtin_range and not has_keyword:
            # Warn about possible performance regression due to behavior change
            warning_call = ast.Expr(
                ast.Call(
                    func=self._create_module_attribute(
                        "range_perf_warning",
                        lineno=node.lineno,
                        col_offset=node.col_offset,
                    ),
                    args=[
                        ast.Constant(value=self.file_name),
                        ast.Constant(value=node.iter.lineno),
                    ]
                    + node.iter.args,
                    keywords=[],
                )
            )
            ast.copy_location(warning_call, node.iter)

        is_prefixed_range = range_kind == "range" and not is_builtin_range
        check_call = None
        if range_kind == "range_dynamic" or is_prefixed_range:
            # Insert a check for range symbol
            if not is_prefixed_range:
                check_call = self._insert_cf_symbol_check(node.iter.func)
            else:
                # Get toplevel module
                check_call = self._insert_cf_symbol_check(node.iter.func.value)

        new_for_node = self.transform_for_loop(node, active_symbols)
        if check_call is not None:
            new_for_node = [check_call] + new_for_node

    return new_for_node if warning_call is None else [warning_call] + new_for_node
|
| 912 |
+
|
| 913 |
+
@staticmethod
|
| 914 |
+
def _hoist_expr_to_assignments(expr, name):
|
| 915 |
+
return ast.copy_location(
|
| 916 |
+
ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
def _build_select_and_assign(self, *, name, test, body, orelse, location):
|
| 920 |
+
node = ast.copy_location(
|
| 921 |
+
ast.Assign(
|
| 922 |
+
targets=[ast.Name(id=name, ctx=ast.Store())],
|
| 923 |
+
value=ast.IfExp(
|
| 924 |
+
test=test,
|
| 925 |
+
body=body,
|
| 926 |
+
orelse=orelse,
|
| 927 |
+
),
|
| 928 |
+
),
|
| 929 |
+
location,
|
| 930 |
+
)
|
| 931 |
+
self.generic_visit(node)
|
| 932 |
+
return node
|
| 933 |
+
|
| 934 |
+
def _handle_negative_step(self, node, start_expr, stop_expr, step_expr):
    """
    Normalize a range loop so its step is non-negative.

    Emits Python assignments (returned in `extra_exprs`) that, at runtime,
    swap start/stop, negate the step, and compute an offset when the original
    step is negative; a compensating `i = offset - i if isNegative else i`
    is prepended to the loop body so iteration order is preserved.

    Returns ``(start_name, stop_name, step_name, extra_exprs)`` where the
    first three are Name nodes referring to the normalized bounds.
    """
    # hoist start, stop, step to assignments
    start_ori_name = f"start_ori_{self.counter}"
    start = self._hoist_expr_to_assignments(start_expr, start_ori_name)
    stop_ori_name = f"stop_ori_{self.counter}"
    stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name)
    step_ori_name = f"step_ori_{self.counter}"
    step = self._hoist_expr_to_assignments(step_expr, step_ori_name)

    extra_exprs = [start, stop, step]

    # Handle possible negative step, generates the following code in Python:
    # isNegative = step < 0
    isNegative_name = f"isNegative_{self.counter}"
    isNegative = ast.copy_location(
        ast.Assign(
            targets=[ast.Name(id=isNegative_name, ctx=ast.Store())],
            value=ast.Compare(
                left=ast.Name(id=step_ori_name, ctx=ast.Load()),
                ops=[ast.Lt()],
                comparators=[ast.Constant(value=0)],
            ),
        ),
        step,
    )

    # start = stop if isNegative else start
    start_name = f"start_{self.counter}"
    start = self._build_select_and_assign(
        name=start_name,
        test=ast.Name(id=isNegative_name, ctx=ast.Load()),
        body=ast.Name(id=stop_ori_name, ctx=ast.Load()),
        orelse=ast.Name(id=start_ori_name, ctx=ast.Load()),
        location=start,
    )

    # stop = start if isNegative else stop
    stop_name = f"stop_{self.counter}"
    stop = self._build_select_and_assign(
        name=stop_name,
        test=ast.Name(id=isNegative_name, ctx=ast.Load()),
        body=ast.Name(id=start_ori_name, ctx=ast.Load()),
        orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()),
        location=stop,
    )

    # step = -step if isNegative else step
    step_name = f"step_{self.counter}"
    step = self._build_select_and_assign(
        name=step_name,
        test=ast.Name(id=isNegative_name, ctx=ast.Load()),
        body=ast.UnaryOp(
            op=ast.USub(), operand=ast.Name(id=step_ori_name, ctx=ast.Load())
        ),
        orelse=ast.Name(id=step_ori_name, ctx=ast.Load()),
        location=step,
    )

    # offset = start + stop if isNegative else 0
    offset_name = f"offset_{self.counter}"
    offset = self._build_select_and_assign(
        name=offset_name,
        test=ast.Name(id=isNegative_name, ctx=ast.Load()),
        body=ast.BinOp(
            op=ast.Add(),
            left=ast.Name(id=start_name, ctx=ast.Load()),
            right=ast.Name(id=stop_name, ctx=ast.Load()),
        ),
        orelse=ast.Constant(value=0),
        location=node,
    )

    extra_exprs.append(isNegative)
    extra_exprs.append(start)
    extra_exprs.append(stop)
    extra_exprs.append(step)
    extra_exprs.append(offset)

    # Add this to beginning of loop body
    # for i in range(start, stop, step):
    #     i = offset - i if isNegative else i
    assert isinstance(node.target, ast.Name)

    target_name = node.target.id
    target = self._build_select_and_assign(
        name=target_name,
        test=ast.Name(id=isNegative_name, ctx=ast.Load()),
        body=ast.BinOp(
            op=ast.Sub(),
            left=ast.Name(id=offset_name, ctx=ast.Load()),
            right=ast.Name(id=target_name, ctx=ast.Load()),
        ),
        orelse=ast.Name(id=target_name, ctx=ast.Load()),
        location=node.target,
    )

    node.body.insert(0, target)

    return (
        ast.Name(id=start_name, ctx=ast.Load()),
        ast.Name(id=stop_name, ctx=ast.Load()),
        ast.Name(id=step_name, ctx=ast.Load()),
        extra_exprs,
    )
|
| 1038 |
+
|
| 1039 |
+
def transform_for_loop(self, node, active_symbols):
    """
    Lower a dynamic range-based `for` loop into: optional bound-normalizing
    assignments, a decorated loop-body function (see create_loop_function),
    and a call that threads the written variables through the loop.
    Returns the list of replacement statements.
    """
    # Check for early exit and raise exception
    self.check_early_exit(node, "for")
    if node.orelse:
        raise DSLAstPreprocessorError(
            "dynamic for loop with else is not supported",
            filename=self.file_name,
            snippet=ast.unparse(node),
        )

    # Get loop target variable name
    target_var_name = None
    target_var_is_active_before_loop = False
    if isinstance(node.target, ast.Name):
        target_var_name = node.target.id
        # If the target shadows a pre-existing symbol, remember it and drop
        # the symbol set from active_symbols (break immediately after remove,
        # so mutating during iteration is safe).
        for active_symbol in active_symbols:
            if target_var_name in active_symbol:
                target_var_is_active_before_loop = True
                active_symbols.remove(active_symbol)
                break

    # Add necessary exprs to handle this
    if target_var_is_active_before_loop:
        # Initialize an extra loop carried variable
        loop_carried_var_name = f"loop_carried_var_{self.counter}"
        pre_loop_expr = ast.copy_location(
            ast.Assign(
                targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
                value=ast.Name(id=target_var_name, ctx=ast.Load()),
            ),
            node,
        )
        # append an extra assignment to the loop carried variable
        node.body.append(
            ast.copy_location(
                ast.Assign(
                    targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
                    value=ast.Name(id=target_var_name, ctx=ast.Load()),
                ),
                node,
            )
        )
        active_symbols.append({loop_carried_var_name})

    start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter)
    unroll, unroll_full = self.extract_unroll_args(node.iter)
    prefetch_stages = self.extract_prefetch_stages_args(node.iter)
    write_args, full_write_args_count = self.analyze_region_variables(
        node, active_symbols
    )

    # Negative-step normalization is only applied for the cutlass frontend.
    if has_step and self.client_module_name[0] == "cutlass":
        start, stop, step, exprs = self._handle_negative_step(
            node, start_expr, stop_expr, step_expr
        )
    else:
        start, stop, step, exprs = start_expr, stop_expr, step_expr, []

    if target_var_is_active_before_loop:
        exprs.append(pre_loop_expr)

    func_name = f"loop_body_{self.counter}"
    self.counter += 1

    func_def = self.create_loop_function(
        func_name,
        node,
        start,
        stop,
        step,
        unroll,
        unroll_full,
        prefetch_stages,
        write_args,
        full_write_args_count,
    )

    # Call the synthesized loop function and re-bind the written variables.
    assign = self.create_cf_call(func_name, write_args, node)

    # This should work fine as it modifies the AST structure
    exprs = exprs + [func_def] + assign

    if target_var_is_active_before_loop:
        # Create a new assignment to the target variable
        exprs.append(
            ast.copy_location(
                ast.Assign(
                    targets=[ast.Name(id=target_var_name, ctx=ast.Store())],
                    value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()),
                ),
                node,
            )
        )

    return exprs
|
| 1134 |
+
|
| 1135 |
+
def visit_Assert(self, node):
    """
    Rewrite ``assert test[, msg]`` into a call to the DSL assert executor,
    ``assert_executor(test=..., msg=...)``, preserving the original source
    location.
    """
    keywords = [ast.keyword(arg="test", value=self.visit(node.test))]
    if node.msg:
        keywords.append(ast.keyword(arg="msg", value=self.visit(node.msg)))

    # Rewrite to assert_executor(test, msg)
    executor = self._create_module_attribute(
        self.ASSERT_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
    )
    replacement = ast.Expr(ast.Call(func=executor, args=[], keywords=keywords))

    # Propagate line number from original node to new node
    ast.copy_location(replacement, node)
    return replacement
|
| 1157 |
+
|
| 1158 |
+
def visit_Call(self, node):
    """
    Rewrite selected calls: `bool` -> DSL bool cast, `any`/`all` -> DSL
    executors, `min`/`max` -> cutlass helpers, and calls into MLIR dialect
    modules get their Numeric arguments implicitly downcast.
    """
    func = node.func
    # Visit args and kwargs
    node.args = [self.visit(arg) for arg in node.args]
    node.keywords = [self.visit(kwarg) for kwarg in node.keywords]

    # Rewrite call to some built-in functions
    if isinstance(func, ast.Name):
        # Check if the function is 'bool'
        if func.id == "bool":
            return ast.copy_location(
                ast.Call(
                    func=self._create_module_attribute(
                        self.BOOL_CAST,
                        lineno=node.lineno,
                        col_offset=node.col_offset,
                    ),
                    args=[node.args[0]],
                    keywords=[],
                ),
                node,
            )
        elif func.id in ["any", "all"]:
            helper_func = (
                self.ANY_EXECUTOR if func.id == "any" else self.ALL_EXECUTOR
            )
            return ast.copy_location(
                ast.Call(
                    func=self._create_module_attribute(
                        helper_func, lineno=node.lineno, col_offset=node.col_offset
                    ),
                    args=[node.args[0]],
                    keywords=[],
                ),
                node,
            )
        elif func.id in ["min", "max"]:
            # NOTE(review): only the two-argument form is handled here
            # (node.args[0], node.args[1]); calls with other arities would
            # raise IndexError or drop arguments — confirm intended.
            return ast.copy_location(
                ast.Call(
                    func=self._create_module_attribute(
                        func.id,
                        top_module_name="cutlass",
                        submodule_name=None,
                        lineno=node.lineno,
                        col_offset=node.col_offset,
                    ),
                    args=[node.args[0], node.args[1]],
                    keywords=[],
                ),
                node,
            )
    elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
        def create_downcast_call(arg):
            # Wrap `arg` in the implicit numeric-downcast helper from the
            # typing submodule.
            return ast.copy_location(
                ast.Call(
                    func=self._create_module_attribute(
                        self.IMPLICIT_DOWNCAST_NUMERIC_TYPE,
                        submodule_name="typing",
                        lineno=node.lineno,
                        col_offset=node.col_offset,
                    ),
                    args=[arg],
                    keywords=[],
                ),
                arg,
            )

        module = self.function_globals.get(func.value.id)
        # NOTE(review): module.__package__ may be None for some modules, in
        # which case .endswith would raise AttributeError — presumably the
        # looked-up modules always define it; confirm.
        if isinstance(module, ModuleType) and module.__package__.endswith(
            "._mlir.dialects"
        ):
            # Check if argument is Numeric, if so, call ir_value()
            args = []
            for arg in node.args:
                args.append(create_downcast_call(arg))
            kwargs = []
            for kwarg in node.keywords:
                kwargs.append(
                    ast.copy_location(
                        ast.keyword(
                            arg=kwarg.arg,
                            value=create_downcast_call(kwarg.value),
                        ),
                        kwarg,
                    )
                )
            return ast.copy_location(
                ast.Call(func=func, args=args, keywords=kwargs), node
            )
    else:
        node.func = self.visit(node.func)

    return node
|
| 1250 |
+
|
| 1251 |
+
def visit_ClassDef(self, node):
    """
    Track the name of the class currently being transformed (used for error
    messages), transform its children, then clear the tracker.
    """
    self.class_name = node.name
    self.generic_visit(node)
    self.class_name = None
    return node
|
| 1256 |
+
|
| 1257 |
+
def _visit_target(self, target):
|
| 1258 |
+
if isinstance(target, ast.Name):
|
| 1259 |
+
self.scope_manager.add_to_scope(target.id)
|
| 1260 |
+
elif isinstance(target, ast.Tuple):
|
| 1261 |
+
for t in target.elts:
|
| 1262 |
+
if isinstance(t, ast.Name):
|
| 1263 |
+
self.scope_manager.add_to_scope(t.id)
|
| 1264 |
+
|
| 1265 |
+
def visit_Assign(self, node):
    """
    Record every assignment target in the active scope, then transform the
    statement's children.
    """
    for tgt in node.targets:
        self._visit_target(tgt)
    self.generic_visit(node)
    return node
|
| 1270 |
+
|
| 1271 |
+
def visit_AugAssign(self, node):
    """
    Record the augmented-assignment target in the active scope, then
    transform the statement's children.
    """
    self._visit_target(node.target)
    self.generic_visit(node)
    return node
|
| 1275 |
+
|
| 1276 |
+
def visit_Name(self, node):
    """
    Rewrite loads of the builtins ``max``/``min``/``any``/``all`` into a
    ``redirect_builtin_function(name)`` call, and reject loads of ``_``.
    """
    is_load = isinstance(node.ctx, ast.Load)
    if is_load and node.id in ("max", "min", "any", "all"):
        redirect = self._create_module_attribute(
            "redirect_builtin_function",
            lineno=node.lineno,
            col_offset=node.col_offset,
        )
        call = ast.Call(func=redirect, args=[node], keywords=[])
        return ast.copy_location(call, node)
    if is_load and node.id == "_":
        # `_` is write-only in the DSL.
        raise DSLAstPreprocessorError("Read '_' is not allowed")
    self.generic_visit(node)
    return node
|
| 1296 |
+
|
| 1297 |
+
def check_decorator(self, node: ast.AST) -> bool:
    """
    Check if the function has the correct decorator for preprocessing.

    Returns True for a bare ``@<mod>.jit`` / ``@<mod>.kernel`` decorator, or
    for the call form with no keyword arguments. For a call form carrying a
    ``preprocess=<literal>`` keyword, returns that literal's value. Returns
    False for non-functions, undecorated functions, and anything else.
    """
    if not isinstance(node, ast.FunctionDef):
        return False
    if not node.decorator_list:
        return False

    for d in node.decorator_list:
        if isinstance(d, ast.Call) and isinstance(d.func, ast.Attribute):
            if d.func.attr not in ("jit", "kernel"):
                continue
            # `@mod.jit()` with no keywords enables preprocessing.
            if d.keywords == []:
                return True
            for keyword in d.keywords:
                if keyword.arg == "preprocess":
                    try:
                        if isinstance(keyword.value, ast.Constant):
                            return keyword.value.value
                        return ast.literal_eval(keyword.value)
                    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                        # literal_eval rejects non-literal expressions; fall
                        # through and keep scanning other decorators instead
                        # of swallowing every exception with a bare `except`.
                        pass
        elif isinstance(d, ast.Attribute):
            # Bare `@mod.jit` / `@mod.kernel` form.
            if d.attr in ("jit", "kernel"):
                return True

    return False
|
| 1328 |
+
|
| 1329 |
+
def remove_dsl_decorator(self, decorator_list):
    """
    Remove .jit and .kernel decorators.

    The decorator can be in two forms:
    - @jit(...)
    - @jit

    Returns a new list keeping every other decorator, in order.
    """
    dsl_names = ("jit", "kernel")

    def _is_dsl_decorator(dec):
        # Call form: `@mod.jit(...)` — inspect the attribute being called.
        if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute):
            return dec.func.attr in dsl_names
        # Bare attribute form: `@mod.jit`.
        if isinstance(dec, ast.Attribute):
            return dec.attr in dsl_names
        return False

    return [d for d in decorator_list if not _is_dsl_decorator(d)]
|
| 1351 |
+
|
| 1352 |
+
def visit_FunctionDef(self, node):
    """Transform a function body inside a fresh scope and strip DSL decorators.

    Bookkeeping performed here:
    - counts functions and records the current function name;
    - functions defined inside another function are recorded as local
      closures (``function_depth > 0``);
    - the function's own name and its positional parameters are added to the
      new scope before the body is visited.
    """
    with self.scope_manager:
        self.function_counter += 1
        self.function_name = node.name
        if self.function_depth > 0:
            self.local_closures.add(node.name)

        self.function_depth += 1

        # Add function name and arguments
        self.scope_manager.add_to_scope(node.name)
        for arg in node.args.args:
            self.scope_manager.add_to_scope(arg.arg)

        self.generic_visit(node)

        self.function_depth -= 1

        # Remove .jit and .kernel decorators
        node.decorator_list = self.remove_dsl_decorator(node.decorator_list)
    return node
|
| 1373 |
+
|
| 1374 |
+
def visit_With(self, node):
    """`with ... as name:` introduces `name` into a fresh scope for the body."""
    with self.scope_manager:
        for with_item in node.items:
            bound = with_item.optional_vars
            # Only simple `as name` bindings are tracked; tuple/attribute
            # targets are not registered.
            if isinstance(bound, ast.Name):
                self.scope_manager.add_to_scope(bound.id)
        self.generic_visit(node)

    return node
|
| 1382 |
+
|
| 1383 |
+
def visit_While(self, node):
    """Outline a ``while`` loop into a decorated region function.

    Constexpr loops are left as plain Python (only a control-flow symbol
    check is prepended); otherwise the loop is rewritten into a
    ``while_region_N`` function plus an assignment that rebinds the
    variables the loop writes.
    """
    # Constexpr doesn't get preprocessed
    if self.is_node_constexpr(node):
        self.generic_visit(node)
        # NOTE(review): assumes a constexpr test is a call (e.g.
        # `const_expr(...)`) so `node.test.func` exists — confirm upstream
        # guarantees this.
        check = self._insert_cf_symbol_check(node.test.func)
        return [check, node]

    # Symbols live *before* the loop; writes to these inside the loop must
    # be threaded through the region as yield values.
    active_symbols = self.scope_manager.get_active_symbols()

    with self.scope_manager:
        # Check for early exit and raise exception
        self.check_early_exit(node, "while")

        write_args, full_write_args_count = self.analyze_region_variables(
            node, active_symbols
        )
        func_name = f"while_region_{self.counter}"
        self.counter += 1

        func_def = self.create_while_function(
            func_name, node, write_args, full_write_args_count
        )
        assign = self.create_cf_call(func_name, write_args, node)

        return [func_def] + assign
|
| 1408 |
+
|
| 1409 |
+
def visit_Try(self, node):
    """Visit a ``try`` statement inside its own lexical scope."""
    with self.scope_manager:
        self.generic_visit(node)
    return node
|
| 1413 |
+
|
| 1414 |
+
def visit_ExceptHandler(self, node):
    """Visit an ``except`` handler; ``except E as name`` binds ``name`` in scope."""
    with self.scope_manager:
        if node.name:  # Exception variable
            self.scope_manager.add_to_scope(node.name)
        self.generic_visit(node)
    return node
|
| 1420 |
+
|
| 1421 |
+
def create_cf_call(self, func_name, yield_args, node):
    """Creates the assignment statement for the if function call.

    The region function's decorator replaces ``func_name`` with its computed
    result, so reading the name here yields the region's output values —
    TODO confirm against the region decorators.

    With no yielded values a bare expression statement is emitted. Otherwise
    the results are unpacked back into the written variables. A written
    ``self`` is special-cased: it is renamed to ``yield_self`` (NOTE: this
    mutates the caller's ``yield_args`` list in place) and its members are
    copied back onto the real ``self`` via the ``copy_members`` helper,
    since ``self`` itself cannot be rebound.
    """
    if not yield_args:
        return [
            ast.copy_location(
                ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node
            )
        ]
    has_self = False
    for i, arg in enumerate(yield_args):
        if arg == "self":
            has_self = True
            yield_args[i] = "yield_self"
            break
    if len(yield_args) == 1:
        assign = ast.Assign(
            targets=[ast.Name(id=yield_args[0], ctx=ast.Store())],
            value=ast.Name(id=func_name, ctx=ast.Load()),
        )
    else:
        assign = ast.Assign(
            targets=[
                ast.Tuple(
                    elts=[ast.Name(id=var, ctx=ast.Store()) for var in yield_args],
                    ctx=ast.Store(),
                )
            ],
            value=ast.Name(id=func_name, ctx=ast.Load()),
        )

    if has_self:
        fix_self = ast.Expr(
            value=ast.Call(
                func=self._create_module_attribute(
                    "copy_members", lineno=node.lineno, col_offset=node.col_offset
                ),
                args=[
                    ast.Name(id="self", ctx=ast.Load()),
                    ast.Name(id="yield_self", ctx=ast.Load()),
                ],
                keywords=[],
            )
        )
        return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)]
    else:
        return [ast.copy_location(assign, node)]
|
| 1467 |
+
|
| 1468 |
+
def visit_IfExp(self, node):
    """
    Visits an inline if-else expression (ternary operator).
    This is the Python equivalent of `x if condition else y`.

    The expression is wrapped so a Python-bool predicate keeps Python's
    short-circuit semantics while an IR predicate routes through the DSL's
    ``select_`` helper.
    """
    self.generic_visit(node)
    # Emit
    # node if type(pred) == bool else select_(pred, body, orelse)
    # so if pred is a python bool, use python to short-circuit and avoid emit arith.select
    # Flag that the emitted code references the top-level `cutlass` module.
    self.import_top_module = True
    return ast.copy_location(
        ast.IfExp(
            test=ast.Compare(
                left=ast.Call(
                    func=ast.Name(id="type", ctx=ast.Load()),
                    args=[node.test],
                    keywords=[],
                ),
                ops=[ast.Eq()],
                comparators=[ast.Name(id="bool", ctx=ast.Load())],
            ),
            body=node,  # Original ternary expression
            orelse=ast.Call(
                func=self._create_module_attribute(
                    "select_", top_module_name="cutlass", submodule_name=None
                ),
                args=[
                    node.test,
                    node.body,
                    node.orelse,
                ],
                keywords=[],
            ),
        ),
        node,
    )
|
| 1504 |
+
|
| 1505 |
+
# Maps each ast comparison-operator class name (e.g. "Lt") to its Python
# source spelling; consumed by compare_ops_to_str so the runtime compare
# executor knows which operators appeared in the expression.
cmpops = {
    "Eq": "==",
    "NotEq": "!=",
    "Lt": "<",
    "LtE": "<=",
    "Gt": ">",
    "GtE": ">=",
    "Is": "is",
    "IsNot": "is not",
    "In": "in",
    "NotIn": "not in",
}
|
| 1517 |
+
def compare_ops_to_str(self, node):
    """Return an ``ast.List`` of string constants naming each comparison
    operator of *node*, in source order, using the ``cmpops`` table."""
    op_names = []
    for op in node.ops:
        op_names.append(ast.Constant(value=self.cmpops[type(op).__name__]))
    return ast.List(elts=op_names, ctx=ast.Load())
|
| 1522 |
+
|
| 1523 |
+
def visit_Compare(self, node):
    """Rewrite a comparison chain into a call to the DSL compare executor.

    ``a < b == c`` becomes
    ``<COMPARE_EXECUTOR>(left=a, comparators=[b, c], ops=["<", "=="])`` so
    the runtime can evaluate comparisons over IR values.
    """
    self.generic_visit(node)

    comparator_strs = self.compare_ops_to_str(node)

    keywords = [
        ast.keyword(arg="left", value=node.left),
        ast.keyword(
            arg="comparators", value=ast.List(elts=node.comparators, ctx=ast.Load())
        ),
        ast.keyword(arg="ops", value=comparator_strs),
    ]

    call = ast.copy_location(
        ast.Call(
            func=self._create_module_attribute(self.COMPARE_EXECUTOR),
            args=[],
            keywords=keywords,
        ),
        node,
    )

    return call
|
| 1546 |
+
|
| 1547 |
+
def visit_If(self, node):
    """Outline an ``if`` statement into a decorated region function.

    Constexpr conditionals stay as plain Python (only a control-flow symbol
    check is prepended); otherwise the statement becomes an ``if_region_N``
    function plus an assignment rebinding the variables written by either
    branch.
    """
    # const_expr doesn't get preprocessed
    if self.is_node_constexpr(node):
        self.generic_visit(node)
        # NOTE(review): assumes a constexpr test is a call so
        # `node.test.func` exists — confirm upstream guarantees this.
        check = self._insert_cf_symbol_check(node.test.func)
        return [check, node]

    active_symbols = self.scope_manager.get_active_symbols()
    with self.scope_manager:
        # Check for early exit and raise exception
        self.check_early_exit(node, "if")

        yield_args, full_write_args_count = self.analyze_region_variables(
            node, active_symbols
        )
        func_name = f"if_region_{self.counter}"
        self.counter += 1

        func_def = self.create_if_function(
            func_name, node, yield_args, full_write_args_count
        )
        assign = self.create_cf_call(func_name, yield_args, node)

        return [func_def] + assign
|
| 1571 |
+
|
| 1572 |
+
def generate_get_locals_or_none_call(self, write_args):
    """Build the AST for ``<mod>.get_locals_or_none(locals(), [<names>])``.

    Used by region decorators to snapshot the current values of the
    variables a region writes.
    """
    locals_call = ast.Call(
        func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]
    )
    name_constants = ast.List(
        elts=[ast.Constant(value=name) for name in write_args],
        ctx=ast.Load(),
    )
    return ast.Call(
        func=self._create_module_attribute("get_locals_or_none"),
        args=[locals_call, name_constants],
        keywords=[],
    )
|
| 1586 |
+
|
| 1587 |
+
def create_if_function(self, func_name, node, write_args, full_write_args_count):
    """Build the outlined region function for an ``if`` statement.

    The result has the shape::

        @<DECORATOR_IF_STATEMENT>(pred=<test>, write_args=get_locals_or_none(...))
        def <func_name>(<pred>, *write_args):
            def then_block_N(*write_args): ...; return [write_args]
            def else_block_N(*write_args): ...; return [write_args]   # when needed
            return <IF_EXECUTOR>(pred=..., then_block=..., else_block=..., ...)

    ``elif`` chains are handled by recursively outlining the nested ``if``
    inside the else block. ``full_write_args_count`` is forwarded untouched
    to the executor.
    """
    test_expr = self.visit(node.test)
    pred_name = self.make_func_param_name("pred", write_args)
    func_args = [ast.arg(arg=pred_name, annotation=None)]
    func_args += [ast.arg(arg=var, annotation=None) for var in write_args]
    func_args_then_else = [ast.arg(arg=var, annotation=None) for var in write_args]

    then_body = []
    for stmt in node.body:
        transformed_stmt = self.visit(stmt)  # Recursively visit inner statements
        if isinstance(transformed_stmt, list):
            then_body.extend(transformed_stmt)
        else:
            then_body.append(transformed_stmt)

    # Create common return list for all blocks
    return_list = ast.List(
        elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
        ctx=ast.Load(),
    )

    # Create common function arguments
    func_decorator_arguments = ast.arguments(
        posonlyargs=[], args=func_args, kwonlyargs=[], kw_defaults=[], defaults=[]
    )
    func_then_else_arguments = ast.arguments(
        posonlyargs=[],
        args=func_args_then_else,
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )

    then_block_name = f"then_block_{self.counter}"
    else_block_name = f"else_block_{self.counter}"
    elif_region_name = f"elif_region_{self.counter}"
    self.counter += 1

    # Create then block
    then_block = ast.copy_location(
        ast.FunctionDef(
            name=then_block_name,
            args=func_then_else_arguments,
            body=then_body + [ast.Return(value=return_list)],
            decorator_list=[],
        ),
        node,
    )

    # Decorator keywords
    decorator_keywords = [
        ast.keyword(
            arg="pred", value=test_expr
        ),  # ast.Name(id="pred", ctx=ast.Load())
        ast.keyword(
            arg="write_args",
            value=self.generate_get_locals_or_none_call(write_args),
        ),
    ]

    # Create decorator
    decorator = ast.copy_location(
        ast.Call(
            func=self._create_module_attribute(
                self.DECORATOR_IF_STATEMENT,
                lineno=node.lineno,
                col_offset=node.col_offset,
            ),
            args=[],
            keywords=decorator_keywords,
        ),
        node,
    )

    # Executor keywords
    execute_keywords = [
        ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
        ast.keyword(
            arg="write_args",
            value=ast.List(
                elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
                ctx=ast.Load(),
            ),
        ),
        ast.keyword(
            arg="full_write_args_count",
            value=ast.Constant(value=full_write_args_count),
        ),
        ast.keyword(
            arg="write_args_names",
            value=ast.List(
                elts=[ast.Constant(value=arg) for arg in write_args],
                ctx=ast.Load(),
            ),
        ),
        ast.keyword(
            arg="then_block", value=ast.Name(id=then_block_name, ctx=ast.Load())
        ),
    ]

    # Handle different cases
    if not write_args and node.orelse == []:
        # No write_args case - only then_block needed
        execute_call = ast.copy_location(
            ast.Call(
                func=self._create_module_attribute(
                    self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
                ),
                args=[],
                keywords=execute_keywords,
            ),
            node,
        )
        func_body = [then_block, ast.Return(value=execute_call)]
    else:
        # Create else block based on node.orelse
        if node.orelse:
            if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
                # Handle elif case
                elif_node = node.orelse[0]
                nested_if_name = elif_region_name
                # Recursion for nested elif
                nested_if = self.create_if_function(
                    nested_if_name, elif_node, write_args, full_write_args_count
                )
                # The else block defines the nested region and returns it; its
                # decorator evaluates the region when the function is defined.
                else_block = ast.FunctionDef(
                    name=else_block_name,
                    args=func_then_else_arguments,
                    body=[
                        nested_if,
                        ast.Return(
                            value=ast.Name(id=nested_if_name, ctx=ast.Load())
                        ),
                    ],
                    decorator_list=[],
                )
            else:

                else_body = []
                for stmt in node.orelse:
                    transformed_stmt = self.visit(
                        stmt
                    )  # Recursively visit inner statements
                    if isinstance(transformed_stmt, list):
                        else_body.extend(transformed_stmt)
                    else:
                        else_body.append(transformed_stmt)

                # Regular else block
                else_block = ast.FunctionDef(
                    name=else_block_name,
                    args=func_then_else_arguments,
                    body=else_body + [ast.Return(value=return_list)],
                    decorator_list=[],
                )
        else:
            # Default else block
            else_block = ast.FunctionDef(
                name=else_block_name,
                args=func_then_else_arguments,
                body=[ast.Return(value=return_list)],
                decorator_list=[],
            )

        # Add else_block to execute keywords
        execute_keywords.append(
            ast.keyword(
                arg="else_block", value=ast.Name(id=else_block_name, ctx=ast.Load())
            )
        )

        execute_call = ast.copy_location(
            ast.Call(
                func=self._create_module_attribute(
                    self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
                ),
                args=[],
                keywords=execute_keywords,
            ),
            node,
        )
        func_body = [
            then_block,
            ast.copy_location(else_block, node),
            ast.Return(value=execute_call),
        ]

    return ast.copy_location(
        ast.FunctionDef(
            name=func_name,
            args=func_decorator_arguments,
            body=func_body,
            decorator_list=[decorator],
        ),
        node,
    )
|
| 1783 |
+
|
| 1784 |
+
def create_while_function(self, func_name, node, write_args, full_write_args_count):
    """Create a while function that looks like:

    @while_selector(pred, write_args=[])
    def while_region(pred, write_args):
        def while_before_block(*write_args):
            # Note that during eval of pred can possibly alter yield_args
            return *pred, write_args
        def while_after_block(*write_args):
            ...loop_body_transformed...
            return write_args
        return self.while_executor(pred, write_args,
                                   while_before_block, while_after_block, constexpr)
    write_args = while_region(pred, write_args)

    Which will later be executed as pseudo-code:

    # Dynamic mode:
    scf.WhileOp(types(write_args), write_args)
    with InsertionPoint(before_block):
        cond, write_args = while_before_block(*write_args)
        scf.ConditionOp(cond, write_args)
    with InsertionPoint(after_block):
        write_args = while_after_block(write_args)
        scf.YieldOp(write_args)
    return while_op.results_

    # Const mode:
    cond, write_args = while_before_block(write_args)
    while pred:
        write_args = body_block(write_args)
        cond, write_args = while_before_block(write_args)
    return write_args
    """
    test_expr = self.visit(node.test)
    pred_name = self.make_func_param_name("pred", write_args)

    # Section: decorator construction
    decorator_keywords = [
        ast.keyword(arg="pred", value=test_expr),
        ast.keyword(
            arg="write_args",
            value=self.generate_get_locals_or_none_call(write_args),
        ),
    ]
    decorator = ast.copy_location(
        ast.Call(
            func=self._create_module_attribute(
                self.DECORATOR_WHILE_STATEMENT,
                lineno=node.lineno,
                col_offset=node.col_offset,
            ),
            args=[],
            keywords=decorator_keywords,
        ),
        node,
    )

    # Section: Shared initialization for before and after blocks
    while_before_block_name = f"while_before_block_{self.counter}"
    while_after_block_name = f"while_after_block_{self.counter}"
    self.counter += 1
    block_args_args = [ast.arg(arg=var, annotation=None) for var in write_args]
    block_args = ast.arguments(
        posonlyargs=[],
        args=block_args_args,
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )

    yield_args_ast_name_list = ast.List(
        elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
        ctx=ast.Load(),
    )

    # Section: while_before_block FunctionDef, which contains condition
    # Returns [condition, [write_args...]] so the executor can split them.
    while_before_return_list = ast.List(
        elts=[test_expr, yield_args_ast_name_list],
        ctx=ast.Load(),
    )
    while_before_stmts = [ast.Return(value=while_before_return_list)]
    while_before_block = ast.copy_location(
        ast.FunctionDef(
            name=while_before_block_name,
            args=block_args,
            body=while_before_stmts,
            decorator_list=[],
        ),
        test_expr,
    )

    # Section: while_after_block FunctionDef, which contains loop body
    while_after_stmts = []
    for stmt in node.body:
        transformed_stmt = self.visit(stmt)  # Recursively visit inner statements
        if isinstance(transformed_stmt, list):
            while_after_stmts.extend(transformed_stmt)
        else:
            while_after_stmts.append(transformed_stmt)
    while_after_stmts.append(ast.Return(value=yield_args_ast_name_list))

    while_after_block = ast.copy_location(
        ast.FunctionDef(
            name=while_after_block_name,
            args=block_args,
            body=while_after_stmts,
            decorator_list=[],
        ),
        node,
    )

    # Section: Execute via executor
    execute_keywords = [
        ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
        ast.keyword(
            arg="write_args",
            value=ast.List(
                elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
                ctx=ast.Load(),
            ),
        ),
        ast.keyword(
            arg="full_write_args_count",
            value=ast.Constant(value=full_write_args_count),
        ),
        ast.keyword(
            arg="while_before_block",
            value=ast.Name(id=while_before_block_name, ctx=ast.Load()),
        ),
        ast.keyword(
            arg="while_after_block",
            value=ast.Name(id=while_after_block_name, ctx=ast.Load()),
        ),
        ast.keyword(
            arg="write_args_names",
            value=ast.List(
                elts=[ast.Constant(value=arg) for arg in write_args],
                ctx=ast.Load(),
            ),
        ),
    ]

    execute_call = ast.Call(
        func=self._create_module_attribute(
            self.WHILE_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
        ),
        args=[],
        keywords=execute_keywords,
    )

    # Putting everything together, FunctionDef for while_region
    func_args_args = [ast.arg(arg=pred_name, annotation=None)]
    func_args_args += [ast.arg(arg=var, annotation=None) for var in write_args]
    func_args = ast.arguments(
        posonlyargs=[],
        args=func_args_args,
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )

    return ast.copy_location(
        ast.FunctionDef(
            name=func_name,
            args=func_args,
            body=[
                while_before_block,
                while_after_block,
                ast.Return(value=execute_call),
            ],
            decorator_list=[decorator],
        ),
        node,
    )
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/cache_helpers.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides jit cache load/dump helper functions
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import uuid
|
| 18 |
+
import random
|
| 19 |
+
import tempfile
|
| 20 |
+
import pwd
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
import hashlib
|
| 24 |
+
|
| 25 |
+
from .utils.logger import log
|
| 26 |
+
from .jit_executor import JitExecutor
|
| 27 |
+
|
| 28 |
+
from .._mlir import ir
|
| 29 |
+
|
| 30 |
+
# =============================================================================
|
| 31 |
+
# Jit Cache Helper functions
|
| 32 |
+
# =============================================================================
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_current_user():
    """Best-effort lookup of the current username.

    Checks the ``USER``/``USERNAME`` environment variables first, then falls
    back to the passwd database (Unix-only, via ``pwd``).
    """
    for env_var in ("USER", "USERNAME"):
        user = os.getenv(env_var)
        if user:
            return user
    # Fallback for Unix-like systems
    return pwd.getpwuid(os.getuid()).pw_name
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Default on-disk location for the per-user JIT IR cache. Falls back to a
# shared path when the username cannot be determined.
# NOTE(review): the path is hard-coded to /tmp rather than
# tempfile.gettempdir() — confirm non-POSIX platforms are out of scope.
try:
    default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/"
except Exception as e:
    # If all else fails, provide a default fallback path
    default_generated_ir_path = "/tmp/cutlass_python_cache/"
    print(f"Could not determine user, using default path. Error: {e}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_ir(file, asBytecode=False):
    """Load generated IR from a file.

    The function name is recovered from the filename, which is assumed to
    follow the ``...dsl_<func_name>.mlir`` convention produced by save_ir —
    TODO confirm the "dsl_" marker against the writer side.
    Returns a ``(func_name, module)`` tuple; the module is parsed in a fresh
    MLIR context.
    """
    assert "mlir" in file
    func_name = file.split(".mlir")[0].split("dsl_")[-1]
    with ir.Context() as ctx:
        # Binary mode for bytecode, text mode for textual IR.
        with open(file, "rb" if asBytecode else "r") as f:
            module = ir.Module.parse(f.read())

    return func_name, module
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def make_unique_filename(fpath: Path, new_ext: str = None) -> Path:
    """Generate a unique filename with an optional new extension.

    Uniqueness comes from hashing the original path together with the
    current time and a random salt, then appending a 16-character hash to
    the stem. The directory is unchanged; the suffix is replaced only when
    *new_ext* is given.
    """
    salt = f"{fpath}_{time.time()}_{random.randint(0, 999999)}"
    hash_code = hashlib.md5(salt.encode()).hexdigest()[:16]  # Shorter hash for readability
    unique = fpath.with_name(f"{fpath.stem}_{hash_code}")
    return unique.with_suffix(new_ext if new_ext else fpath.suffix)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def save_ir(
    dsl_name: str,
    module: object,
    fname: str,
    isTemp: bool = False,
    asBytecode: bool = False,
) -> Path:
    """Save generated IR to a file atomically.

    The module is first written into a per-process temporary directory and
    then moved into place with ``os.replace`` so readers never observe a
    partially written file.

    :param dsl_name: DSL name used as the (lower-cased) filename prefix.
    :param module: IR module; must support ``operation.write_bytecode`` when
        ``asBytecode`` is True, otherwise anything printable.
    :param fname: function name embedded in the filename.
    :param isTemp: write into the system temp dir instead of the CWD.
    :param asBytecode: write MLIR bytecode instead of textual IR.
    :return: path of the saved file.
    """
    initial_name = f"{dsl_name.lower()}_{fname}.mlir"
    save_path = Path(tempfile.gettempdir() if isTemp else os.getcwd())
    save_fname = save_path / initial_name
    # Random ID + pid to avoid any collisions between concurrent writers.
    rnd_id = str(uuid.uuid4())
    pid = os.getpid()
    # use temp dir to be robust against program interruptions
    temp_dir = os.path.join(save_path, f"tmp.pid_{pid}_{rnd_id}")
    os.makedirs(temp_dir, exist_ok=False)
    temp_fname = os.path.join(temp_dir, initial_name)

    try:
        if asBytecode:
            with open(temp_fname, "wb") as f:
                module.operation.write_bytecode(f)
        else:
            with open(temp_fname, "w") as f:
                print(module, file=f)
        # os.replace is guaranteed to be atomic on POSIX systems if it succeeds
        # so filepath cannot see a partial write
        os.replace(temp_fname, save_fname)
    finally:
        # Best-effort cleanup so an interrupted write no longer leaks the
        # temporary directory (previously it was left behind on failure).
        try:
            if os.path.exists(temp_fname):
                os.remove(temp_fname)
            os.removedirs(temp_dir)
        except OSError:
            pass
    log().debug("Generated IR saved into %s", save_fname)
    return save_fname
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def check_func_name(jit_cache, func_name):
    """Ensure *func_name* has an entry in *jit_cache*.

    Missing names get a placeholder ``JitExecutor`` whose fields (e.g.
    ``ir_module``) are filled in later by the cache loader. Returns the
    cache so callers can use the result directly.
    """
    if func_name not in jit_cache:
        jit_cache[func_name] = JitExecutor(None, None, None, None, None, None)
    return jit_cache
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
    """Load cache from a directory path.

    Scans up to ``cache_limit`` directory entries, loading every
    ``<dsl_name>*...mlir`` file as bytecode into a fresh cache dict keyed by
    function name. Loading is best-effort: any failure discards the whole
    cache and returns an empty dict rather than raising.
    NOTE(review): the limit counts scanned entries, not loaded files, since
    the break happens before the prefix filter — confirm that is intended.
    """
    if not os.path.exists(path):
        return dict()
    files = os.listdir(path)
    jit_cache = dict()
    try:
        for idx, file in enumerate(files):
            if idx >= int(cache_limit):
                break
            # identify dsl prefix
            if not file.startswith(f"{dsl_name.lower()}"):
                continue
            if ".mlir" in file:
                func_name, ir_module = load_ir(
                    os.path.join(path, file), asBytecode=True
                )
                jit_cache = check_func_name(jit_cache, func_name)
                jit_cache[func_name].ir_module = ir_module
    except Exception as e:
        # Deliberate broad catch: a corrupt cache must never break startup.
        print(f"{dsl_name} failed with loading generated IR cache.", e)
        jit_cache = dict()
    return jit_cache
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def dump_cache_to_path(
    dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
):
    """Dump up to ``cache_limit`` cached IR modules as bytecode into *path*.

    save_ir writes relative to the current working directory, so this
    temporarily chdirs into *path* and always restores the original CWD.
    NOTE(review): chdir is process-wide and not thread-safe — confirm dumps
    never run concurrently with other filesystem-relative work.
    Dumping is best-effort; failures are printed, not raised.
    """
    log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
    os.makedirs(path, exist_ok=True)
    original_path = os.getcwd()
    try:
        os.chdir(path)
        for idx, [key, value] in enumerate(jit_cache.items()):
            if idx >= int(cache_limit):
                break
            save_ir(dsl_name, value.ir_module, key, asBytecode=True)
    except Exception as e:
        print(f"{dsl_name} failed with caching generated IR", e)
    finally:
        os.chdir(original_path)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/common.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from typing import Any, Dict, Iterable, Optional, Union
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
This module provides a Exception classes DSL class for any Dialect.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Add color codes at the top of the file after imports
|
| 21 |
+
class Colors:
    """ANSI terminal escape codes used to colorize DSL error messages."""

    RED = "\033[91m"      # errors
    YELLOW = "\033[93m"   # warnings
    BLUE = "\033[94m"     # informational headers
    GREEN = "\033[92m"    # suggestions / success
    BOLD = "\033[1m"      # emphasis
    RESET = "\033[0m"     # restore default attributes
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# =============================================================================
|
| 33 |
+
# DSL Exceptions
|
| 34 |
+
# =============================================================================
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DSLBaseError(Exception):
    """
    Base exception for DSL-related errors.
    Provides optional contextual metadata to aid in debugging.
    """

    def __init__(
        self,
        message: str,
        line: Optional[int] = None,
        snippet: Optional[str] = None,
        filename: Optional[str] = None,
        error_code: Optional[Union[str, int]] = None,
        context: Optional[Union[Dict[str, Any], str]] = None,
        suggestion: Optional[str] = None,
        cause: Optional[BaseException] = None,
    ) -> None:
        # Human-readable description of the failure.
        self.message = message
        # Optional source location of the offending construct.
        self.line = line
        self.filename = filename
        # Optional source-code excerpt shown verbatim in the message.
        self.snippet = snippet
        self.error_code = error_code
        # Either a dict of key/value details or a pre-formatted string.
        self.context = context
        # A single suggestion string, or a list/tuple of them
        # (both are handled by _format_message below).
        self.suggestion = suggestion
        # Underlying exception, if this error wraps another one.
        self.cause = cause

        # Build the full colored message once and hand it to Exception so
        # str(exc) shows all metadata.
        super().__init__(self._format_message())

    def _format_message(self):
        """
        Formats the complete error message with available metadata.
        Override this in subclasses if you want to change formatting logic.
        """
        # Each appended part becomes one line of the final message.
        parts = [f"{self.__class__.__name__}: {self.message}"]

        if self.error_code is not None:
            parts.append(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n")

        if self.line is not None:
            parts.append(f" Line: {self.line}")

        if self.filename is not None:
            parts.append(f" File: {self.filename}")

        if self.snippet:
            # Optionally truncate long snippets for readability
            parts.append(f" Snippet: \n {self.snippet}")

        if self.cause:
            parts.append(f" Caused exception: {self.cause}")

        if self.context:
            # Dict contexts are rendered one key/value pair per line;
            # anything else is shown inline as-is.
            if isinstance(self.context, dict):
                parts.append(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n")
                for key, value in self.context.items():
                    parts.append(f" {key}: {value}")
            else:
                parts.append(
                    f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {self.context}"
                )

        if self.suggestion:
            parts.append(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}")
            # A sequence of suggestions is rendered one per line (colored);
            # a single string is rendered plainly.
            if isinstance(self.suggestion, (list, tuple)):
                for suggestion in self.suggestion:
                    parts.append(f" {Colors.GREEN}{suggestion}{Colors.RESET}")
            else:
                parts.append(f" {self.suggestion}")

        return "\n".join(parts)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class DSLRuntimeError(DSLBaseError):
    """
    Raised when an error occurs during JIT-time code generation in the DSL.
    """

    # Inherits all logic from DSLBaseError; override methods if you need
    # specialized behavior or formatting for runtime errors.
    pass
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _get_friendly_cuda_error_message(error_code, error_name):
    """Get a user-friendly error message for common CUDA errors.

    :param error_code: raw CUDA error code (unused in the message lookup,
        kept for interface compatibility with callers)
    :param error_name: CUDA error name, as ``bytes``, a ``str``, or a
        stringified byte literal like ``"b'CUDA_ERROR_...'"``
    :return: ``(message, debug_info, suggestions)`` where ``suggestions`` is
        a tuple of strings or ``""`` when the error is unknown
    """
    # NOTE: the original placed the docstring after the import below, which
    # made it a plain expression statement rather than the function __doc__;
    # it is now the first statement.
    # Avoid circular dependency
    from .runtime.cuda import get_device_info

    # Strip the byte string markers if present
    if isinstance(error_name, bytes):
        error_name = error_name.decode("utf-8")
    elif (
        isinstance(error_name, str)
        and error_name.startswith("b'")
        and error_name.endswith("'")
    ):
        # Handles names that were stringified from bytes, e.g. "b'...'".
        error_name = error_name[2:-1]

    # Add target architecture info (used by the compatibility check below).
    target_arch = os.getenv("CUTE_DSL_ARCH", "unknown")

    # Per-error headline messages shown to the user.
    error_messages = {
        "CUDA_ERROR_INVALID_SOURCE": (
            f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_NO_BINARY_FOR_GPU": (
            f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_OUT_OF_MEMORY": (
            f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_INVALID_DEVICE": (
            f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_NOT_INITIALIZED": (
            f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_INVALID_VALUE": (
            f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n"
            f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}"
        ),
    }

    # Per-error actionable suggestions (tuples of plain strings; the useless
    # f-prefixes on constant strings have been removed).
    error_suggestions = {
        "CUDA_ERROR_INVALID_SOURCE": (
            "1. Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            "2. Clear the compilation cache and regenerate the kernel",
            "3. Check CUDA toolkit installation",
        ),
        "CUDA_ERROR_NO_BINARY_FOR_GPU": (
            "Set env CUTE_DSL_ARCH to match your GPU architecture",
        ),
        "CUDA_ERROR_OUT_OF_MEMORY": (
            "1. Reduce batch size",
            "2. Reduce model size",
            "3. Free unused GPU memory",
        ),
        "CUDA_ERROR_INVALID_DEVICE": (
            "1. Check if CUDA device is properly initialized",
            "2. Verify GPU is detected: nvidia-smi",
            "3. Check CUDA_VISIBLE_DEVICES environment variable",
        ),
        "CUDA_ERROR_NOT_INITIALIZED": (
            "1. Check CUDA driver installation",
            "2. call `cuda.cuInit(0)` before any other CUDA operation",
            "3. Run nvidia-smi to confirm GPU status",
        ),
        "CUDA_ERROR_INVALID_VALUE": (
            "1. Your GPU model",
            "2. SM ARCH setting",
            "3. Steps to reproduce",
        ),
    }

    message = error_messages.get(
        error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}"
    )

    # Add debug information (environment settings relevant to kernel loading).
    debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n"
    debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n"
    debug_info += (
        f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n"
    )

    try:
        # Get GPU information using CUDA Python API
        debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n"
        gpu_info = get_device_info()
        debug_info += gpu_info.pretty_str()

        if target_arch and gpu_info.compatible_archs:
            debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n"

            if target_arch not in gpu_info.compatible_archs:
                # Hard mismatch: the chosen architecture cannot run here.
                debug_info += (
                    f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n"
                    f"💡 Please use one of SM ARCHs: "
                    f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n"
                )
            elif target_arch != gpu_info.sm_arch:
                # Works, but not the GPU's native architecture.
                debug_info += (
                    f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n"
                    f"• Current: {target_arch}\n"
                    f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n"
                )
            else:
                debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n"

    except Exception as e:
        # GPU introspection is best-effort; never let it mask the real error.
        debug_info += (
            f"\n{Colors.YELLOW}ℹ️ Could not retrieve GPU info: {str(e)}{Colors.RESET}"
        )

    return message, debug_info, error_suggestions.get(error_name, "")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class DSLCudaRuntimeError(DSLBaseError):
    """
    Raised when an error occurs during CUDA runtime code generation in the DSL.
    """

    # Inherits all logic from DSLBaseError; override methods if you need
    # specialized behavior or formatting for runtime errors.
    def __init__(self, error_code, error_name) -> None:
        # Raw CUDA driver error code and name as reported by the runtime.
        self._error_code = error_code
        self._error_name = error_name
        # Translate the raw CUDA error into a user-friendly message plus
        # environment/GPU debug context and actionable suggestions.
        message, debug_info, suggestion = _get_friendly_cuda_error_message(
            error_code, error_name
        )

        super().__init__(
            message, error_code=error_code, context=debug_info, suggestion=suggestion
        )
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class DSLAstPreprocessorError(DSLBaseError):
    """
    Raised when an error occurs during AST preprocessing or visiting in the DSL.
    """

    # Same approach: You could override _format_message if you want
    # to emphasize AST node details or anything specific to preprocessing.
    pass
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class DSLNotImplemented(DSLBaseError):
    """
    Raised when a feature of the DSL is not implemented yet.
    """

    # Useful for stubs in your DSL that you plan to implement in the future.
    pass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/compiler.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides a class that compiles generated IR using MLIR's PassManager
|
| 14 |
+
and executes it using MLIR's ExecutionEngine.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from typing import Sequence, Optional, Tuple
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import inspect
|
| 22 |
+
import argparse
|
| 23 |
+
from .common import DSLRuntimeError
|
| 24 |
+
from .utils.logger import log
|
| 25 |
+
|
| 26 |
+
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
|
| 27 |
+
sys.path.append(_SCRIPT_PATH)
|
| 28 |
+
|
| 29 |
+
from .._mlir import ir
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# =============================================================================
|
| 33 |
+
# Compiler Class
|
| 34 |
+
# =============================================================================
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CompilationError(RuntimeError):
    """Custom error class for compilation failures.

    Formats NVVM back-end failures with settings, IR context, and
    suggestions; plain failures render as the bare message (no class name).
    """

    # ANSI color codes used to highlight sections of the formatted error.
    RED = "\033[91m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    BOLD = "\033[1m"
    RESET = "\033[0m"

    def __init__(
        self,
        message: str,
        nvvm_error: Optional[str] = None,
        ir_context: Optional[str] = None,
        cuda_toolkit: Optional[str] = None,
        arch: Optional[str] = None,
    ):
        """
        :param message: raw compiler error text
        :param nvvm_error: extracted NVVM-specific error, if any
        :param ir_context: truncated IR surrounding the failing operation
        :param cuda_toolkit: CUDA toolkit path used for the build
        :param arch: target SM architecture string
        """
        # BUG FIX: the original never stored `message`; since the parent is
        # initialized with "", `str(self.args[0])` in _format_error always
        # produced an empty string for non-NVVM errors, losing the text.
        self.message = message
        self.nvvm_error = nvvm_error
        self.ir_context = ir_context
        self.cuda_toolkit = cuda_toolkit
        self.arch = arch
        # Call parent with formatted error to avoid showing class name
        super().__init__("")  # Empty string to avoid class name
        # Store formatted error for str() representation
        self._formatted_error = self._format_error()

    def __str__(self) -> str:
        """Override string representation to avoid showing class name"""
        return self._formatted_error

    def __repr__(self) -> str:
        """Override repr representation to avoid showing class name"""
        return self._formatted_error

    def _format_error(self) -> str:
        """Return the full user-facing error text."""
        if not self.nvvm_error:
            # Plain failure: show the stored message verbatim.
            return self.message

        return f"""NVVM Compilation Error:
----------------------

{self.BLUE}⚙️ Current Settings:{self.RESET}
{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
- Target Architecture: {self.arch}{self.RESET}

IR Context (truncated):
{self.ir_context}

{self.YELLOW}💡 Possible Solutions:{self.RESET}
{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class Compiler:
    """Compiler class for compiling and building MLIR modules."""

    def __init__(self, passmanager, execution_engine):
        # MLIR passmanager / execution_engine modules, injected by the caller
        # so this class stays decoupled from a specific MLIR build.
        self.passmanager = passmanager
        self.execution_engine = execution_engine

    def __call__(self, module):
        """Convenience application method."""
        # NOTE(review): Compiler.compile requires a `pipeline` argument that
        # is not supplied here, so calling an instance directly raises
        # TypeError. Left unchanged because the intended default pipeline is
        # not visible in this module — TODO confirm and fix upstream.
        self.compile(module)

    def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
        """Process error message to extract NVVM error and IR context.

        :param error_msg: full diagnostic text from the pass pipeline
        :return: ``(nvvm_error, ir_context)``; ``nvvm_error`` is ``None``
            when the message is not an NVVM failure
        """
        nvvm_error = None
        ir_msg = ""

        if "NVVM_ERROR" in error_msg:
            # Extract the specific NVVM error; prefer the libNVVM log tail.
            nvvm_error = (
                error_msg.split("libNVVM extra log:")[1].strip()
                if "libNVVM extra log:" in error_msg
                else error_msg
            )

        # Extract IR context surrounding the failing operation.
        if "see current operation:" in error_msg:
            # Get the IR section
            ir_section = error_msg.split("see current operation:")[1].strip()
            # Remove duplicate IR section
            ir_section = ir_section.split("error: unknown: Failed translating")[
                0
            ].strip()

            # Keep only the head and tail of long IR dumps for readability.
            ir_lines = ir_section.split("\n")
            if len(ir_lines) > 10:
                ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:])
            else:
                ir_msg = ir_section

        return nvvm_error, ir_msg

    def compile(
        self,
        module,
        pipeline: str,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compiles the module by invoking the pipeline.

        :raises CompilationError: when the failure is NVVM-related, with
            extracted context attached
        """
        try:
            pm = self.passmanager.PassManager.parse(pipeline)
            pm.enable_verifier(enable_verifier)
            pm.run(module.operation)
        except Exception as e:
            error_msg = str(e)
            nvvm_error, ir_msg = self._process_error(error_msg)

            if nvvm_error:
                raise CompilationError(
                    error_msg,
                    nvvm_error=nvvm_error,
                    ir_context=ir_msg,
                    cuda_toolkit=cuda_toolkit,
                    arch=arch,
                ) from e
            # Bare `raise` re-raises with the original traceback intact
            # (the original used `raise e`, which adds a new raise site).
            raise

    def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
        """Wraps the module in a JIT execution engine."""
        return self.execution_engine.ExecutionEngine(
            module, opt_level=opt_level, shared_libs=shared_libs
        )

    def compile_and_jit(
        self,
        module,
        pipeline: str,
        shared_libs: Sequence[str] = (),
        opt_level: int = 2,
        cuda_toolkit: str = "",
        arch: str = "",
    ):
        """Compiles and jits the module."""
        self.compile(
            module,
            pipeline,
            cuda_toolkit,
            arch,
        )
        return self.jit(module, opt_level, shared_libs)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class CompileOptions:
    def __init__(self, options: str = ""):
        """
        This class encapsulates all compilation options relevant to function compilation.
        It provides a convenient way to manage and pass compilation options,
        particularly for controlling compilation settings.
        By centralizing these options, it ensures consistent and flexible configuration of
        compilation parameters such as optimization level, debugging control, etc.

        :param options: The options for the function. Will be parsed by argparse.
        :type options: str
        :raises DSLRuntimeError: if *options* is not a string or fails to parse
        """
        if not isinstance(options, str):
            raise DSLRuntimeError(
                f"Invalid compilation `options`: {options}, it should be a string"
            )
        self._parser = argparse.ArgumentParser()
        self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
        self._parser.add_argument(
            "--enable-device-assertions", action="store_true", default=False
        )
        self._parser.add_argument("--link-libraries", type=str, default="")

        try:
            self._options = self._parser.parse_args(options.split())
        except SystemExit as e:
            # argparse calls sys.exit on bad input; surface it as a DSL error.
            # BUG FIX: chain the original exception (`from e`) so the argparse
            # failure is preserved for debugging instead of being dropped.
            raise DSLRuntimeError(
                f"Invalid compile options: '{options}'. Please check the option values and format."
            ) from e
        # Lazy %-style args avoid string concatenation when INFO is disabled.
        log().info("`cute.compile` CompileOptions: options=%s", options)

    def to_str(self):
        """
        Generate a string representation of all compilation options
        which will be used in pipeline options.
        """
        option_strings = []
        for key, value in vars(self._options).items():
            # argparse stores dests with underscores; pipelines use hyphens.
            hyphen_key = key.replace("_", "-")
            if isinstance(value, bool):
                formatted_value = "true" if value else "false"
            else:
                formatted_value = str(value)
            option_strings.append(f"{hyphen_key}={formatted_value}")

        return " ".join(option_strings)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def compile(func, *args, **kwargs):
    """
    This function is used to compile a `cute.jit` decorated function.
    It will process the compile options and input parameters, do explicit compilation and return the jit executor.

    :param func: The function to compile. It can be a regular function, a method or a class instance.
    :param args: The arguments to pass to the function.
    :param kwargs: The keyword arguments to pass to the function. It can contain `options` like
        `opt_level` to control the compilation flags.

    :return: The jit executor.

    :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
    """
    if func is None:
        raise DSLRuntimeError("Function is not set or invalid.")

    if not callable(func):
        raise DSLRuntimeError("Object is not callable.")

    # Explicit compilation: build the executor without running the kernel,
    # and skip the JIT cache so fresh options always take effect.
    kwargs["compile_only"] = True
    kwargs["no_cache"] = True

    if inspect.isfunction(func):
        # regular function: nothing to unwrap
        pass
    elif inspect.ismethod(func):
        # Bound method: pass the bound instance explicitly as the first
        # argument and compile the underlying function.
        args = [func.__self__] + list(args)
        func = func.__func__
    elif inspect.isclass(type(func)) and hasattr(func, "__call__"):
        # Callable class instance: compile the class's __call__ with the
        # instance prepended as `self`.
        args = [func] + list(args)
        # Get the actual function from the class definition
        func = func.__call__.__func__
    else:
        # BUG FIX: the original passed `func` as the second positional
        # argument of DSLRuntimeError, which is the `line` parameter of the
        # exception; embed it in the message instead.
        raise DSLRuntimeError(
            f"Invalid function type, only function, method and module are supported, but got {func}"
        )

    # If it's a wrapped function created by jit decorator, get the original function
    if hasattr(func, "__wrapped__"):
        func = func.__wrapped__

    if not hasattr(func, "_dsl_object"):
        raise DSLRuntimeError("Function is not decorated with jit decorator.")

    # process compile options, extract the options and remove them from the kwargs
    options = kwargs.pop("options", "")
    func._dsl_object.compile_options = CompileOptions(options)
    fcn_ptr = func._dsl_object._preprocess_and_execute(func)
    return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/dsl.py
ADDED
|
@@ -0,0 +1,1686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides a main DSL class for any Dialect.
|
| 14 |
+
The DSL should be inherited as a new class, and its initialization requires dialects.
|
| 15 |
+
It handles most of the mechanics for the DSL in an agnostic way,
|
| 16 |
+
for example, it can handle various dialect-specific tasks.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Standard library imports
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
import atexit
|
| 23 |
+
import os
|
| 24 |
+
import io
|
| 25 |
+
import sys
|
| 26 |
+
import errno
|
| 27 |
+
import ctypes
|
| 28 |
+
import re
|
| 29 |
+
import inspect
|
| 30 |
+
import argparse
|
| 31 |
+
import hashlib
|
| 32 |
+
from functools import lru_cache, wraps
|
| 33 |
+
from collections import namedtuple
|
| 34 |
+
from abc import ABC, abstractmethod
|
| 35 |
+
from typing import Any, Union, Tuple, get_origin, get_args, List
|
| 36 |
+
from types import FunctionType, SimpleNamespace
|
| 37 |
+
import warnings
|
| 38 |
+
|
| 39 |
+
from . import typing as t
|
| 40 |
+
from .env_manager import EnvironmentVarManager
|
| 41 |
+
from .compiler import CompileOptions
|
| 42 |
+
from .ast_helpers import DSLOptimizationWarning
|
| 43 |
+
|
| 44 |
+
# =============================================================================
|
| 45 |
+
# CUDA Python
|
| 46 |
+
# =============================================================================
|
| 47 |
+
|
| 48 |
+
from ..base_dsl._mlir_helpers.arith import const
|
| 49 |
+
|
| 50 |
+
# =============================================================================
|
| 51 |
+
# Local module imports
|
| 52 |
+
# =============================================================================
|
| 53 |
+
|
| 54 |
+
from .cache_helpers import *
|
| 55 |
+
from .jit_executor import JitExecutor
|
| 56 |
+
from .utils.timer import timer
|
| 57 |
+
from .utils.logger import setup_log, log
|
| 58 |
+
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
|
| 59 |
+
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
|
| 60 |
+
|
| 61 |
+
from .ast_preprocessor import DSLPreprocessor
|
| 62 |
+
from .common import *
|
| 63 |
+
from .typing import (
|
| 64 |
+
get_c_pointers,
|
| 65 |
+
get_mlir_types,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# =============================================================================
|
| 69 |
+
# MLIR modules
|
| 70 |
+
# =============================================================================
|
| 71 |
+
|
| 72 |
+
from .._mlir import ir
|
| 73 |
+
from .._mlir import runtime as rt
|
| 74 |
+
from .._mlir.extras import types as T
|
| 75 |
+
from .._mlir.dialects import arith, math, func
|
| 76 |
+
|
| 77 |
+
# =============================================================================
|
| 78 |
+
# Global Variables
|
| 79 |
+
# =============================================================================
|
| 80 |
+
|
| 81 |
+
# Sentinel for a dynamic (unknown-at-compile-time) dimension/stride value.
# Equal to INT64_MIN (-2**63); presumably mirrors MLIR's ShapedType::kDynamic
# sentinel — TODO confirm against the bundled MLIR bindings.
MLIR_DYNAMIC = -9223372036854775808
|
| 82 |
+
|
| 83 |
+
# =============================================================================
|
| 84 |
+
# Codegen Utils
|
| 85 |
+
# =============================================================================
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _numpy_type_to_mlir_type(dtype):
    """Map a NumPy scalar type (or DSL narrow-float tag) to the matching MLIR type.

    Comparisons use ``==`` on purpose: NumPy dtype equality also accepts
    ``np.dtype`` instances, not just the scalar type classes.

    Raises:
        AssertionError: if ``dtype`` has no known MLIR counterpart.
    """
    if dtype == np.float64:
        return T.f64()
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.int64:
        return T.i64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int16:
        return T.i16()
    if dtype == np.int8:
        return T.i8()
    if dtype == np.uint64:
        return T.ui64()
    if dtype == np.uint32:
        return T.ui32()
    if dtype == np.uint16:
        return T.ui16()
    if dtype == np.uint8:
        return T.ui8()
    if dtype == np.bool_:
        return T.bool()
    if dtype == f8E5M2:
        return T.f8E5M2()
    if dtype == f8E4M3FN:
        return T.f8E4M3FN()
    if dtype == f8E8M0FNU:
        return T.f8E8M0FNU()
    if dtype == f6E3M2FN:
        return T.f6E3M2FN()
    if dtype == f6E2M3FN:
        return T.f6E2M3FN()
    if dtype == f4E2M1FN:
        return T.f4E2M1FN()
    # Bug fix: the message previously interpolated the *builtin* `type`
    # (always "<class 'type'>") instead of the offending `dtype` argument.
    assert False, f"Unknown type {dtype}"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _mlir_type_to_numpy_type(type):
    """Map an MLIR scalar type back to the corresponding NumPy scalar type."""
    # Candidate pairs are probed with == in the exact order the original
    # if-chain tested them; each factory builds a fresh MLIR type to compare.
    conversions = (
        (T.f64, np.float64),
        (T.f16, np.float16),
        (T.f32, np.float32),
        (T.i64, np.int64),
        (T.i32, np.int32),
        (T.i16, np.int16),
        (T.i8, np.int8),
        (T.ui64, np.uint64),
        (T.ui32, np.uint32),
        (T.ui16, np.uint16),
        (T.ui8, np.uint8),
        (T.bool, np.bool_),
    )
    for make_mlir_ty, numpy_ty in conversions:
        if type == make_mlir_ty():
            return numpy_ty
    assert False, f"Unknown type {type}"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# =============================================================================
|
| 157 |
+
# Main DSL Class
|
| 158 |
+
# =============================================================================
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def is_dynamic_expression(value):
    """Return True if *value* is an MLIR IR value, or recursively contains one."""
    # A container is dynamic iff any of its elements is; recursion covers nesting.
    if isinstance(value, (tuple, list)):
        return any(is_dynamic_expression(item) for item in value)
    # A scalar counts as dynamic when it is an IR value / block-argument list
    # itself, or when it exposes IR values through the __extract_mlir_values__ hook.
    return isinstance(value, (ir.Value, ir.BlockArgumentList)) or hasattr(
        value, "__extract_mlir_values__"
    )
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def extract_mlir_values(obj):
    """
    Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values

    Traversal order is significant: the flattened list must line up
    positionally with what `new_from_mlir_values` later consumes.
    Returns an empty list for plain (non-IR) Python objects.
    """
    res = []
    # Objects may opt in to extraction via a dunder hook; this takes priority
    # over the generic container handling below.
    if hasattr(obj, "__extract_mlir_values__"):
        res = obj.__extract_mlir_values__()
    elif isinstance(obj, (tuple, list)):
        # Flatten element-wise, preserving element order.
        res = sum((extract_mlir_values(x) for x in obj), [])
    elif isinstance(obj, SimpleNamespace):
        res = []
        # Relies on dict insertion order so values stay positionally stable.
        for k, v in obj.__dict__.items():
            res.extend(extract_mlir_values(v))
    # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values
    elif isinstance(obj, set):
        # Sets have no deterministic iteration order, which would break the
        # positional pairing with new_from_mlir_values — reject them outright.
        raise DSLRuntimeError(
            "Sets are not supported in extract_mlir_values to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    elif isinstance(obj, ir.Value):
        res = [obj]
    elif isinstance(obj, ir.BlockArgumentList):
        res = list(obj)  # type: ignore

    return res
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def new_from_mlir_values(obj, values):
    """
    Create a new python object by populating containing MLIR values with list of new values

    `values` is consumed positionally: each sub-object takes as many leading
    entries as `get_mlir_types` reports for it, mirroring the flattening
    order produced by `extract_mlir_values`.
    """
    if hasattr(obj, "__new_from_mlir_values__"):
        # Object knows how to rebuild itself from its own slice of values.
        return obj.__new_from_mlir_values__(values)
    elif isinstance(obj, (tuple, list)):
        res = []
        for x in obj:
            # Carve off exactly the slice of `values` this element owns.
            n_items = len(get_mlir_types(x))
            res.append(new_from_mlir_values(x, values[:n_items]))
            values = values[n_items:]
        # NOTE(review): rebuilds via type(obj)(res); assumes the constructor
        # accepts a single iterable — a namedtuple subclass would not. Confirm
        # callers only pass plain list/tuple here.
        obj_ty = type(obj)
        return obj_ty(res)
    elif isinstance(obj, SimpleNamespace):
        res = SimpleNamespace()
        # Same slice-and-advance scheme, keyed by insertion order of __dict__.
        for k, v in obj.__dict__.items():
            n_items = len(get_mlir_types(v))
            res.__dict__[k] = new_from_mlir_values(v, values[:n_items])
            values = values[n_items:]
        return res
    elif isinstance(obj, set):
        # Unordered containers cannot be rebuilt positionally — reject.
        raise DSLRuntimeError(
            "Sets are not supported in new_from_mlir_values to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    elif is_dynamic_expression(obj):

        # A dynamic leaf with no replacement value is passed through as-is.
        if len(values) == 0:
            return obj

        assert len(values) == 1
        return values[0]
    else:
        # Static (constexpr) leaf: must not have consumed any IR values.
        assert len(values) == 0, f"{obj} expects 0 values, but got {values}"
        return obj
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class DSLCallable:
    """A single-shot wrapper around a callable used by the DSL.

    The wrapped function may be invoked exactly once; after a successful
    call the reference is dropped (set to ``None``) so any further use
    trips an assertion. Introspection helpers (`__signature__`,
    `__name__`) delegate to the wrapped function while it is still held.
    """

    def __init__(self, func):
        # Cleared to None after the first successful call.
        self.func = func

    @property
    def __func__(self):
        assert self.func is not None, "DSLCallable is already called"
        return self.func

    @property
    def __name__(self):
        return self.__func__.__name__

    @property
    def __signature__(self):
        return inspect.signature(self.__func__)

    def __call__(self, *args, **kwargs):
        # Fetch first (asserts if already consumed), invoke, then drop the
        # reference only after the call succeeded.
        target = self.__func__
        result = target(*args, **kwargs)
        self.func = None
        return result
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class BaseDSL:
|
| 284 |
+
gpu_module = None
|
| 285 |
+
|
| 286 |
+
    def __init__(
        self,
        *,
        name: str,
        dsl_package_name: List[str],
        compiler_provider: Any,
        pass_sm_arch_name: str,
        device_compilation_only=False,
        preprocess=False,
    ):
        """
        Constructor for initializing the class with required providers and environment settings.

        Parameters:
        - name (str): Name of DSL, used for environment variables and logging.
        - dsl_package_name (List[str]): Name of the package, used for the preprocessor.
        - compiler_provider (MLIR dialect): Provider for compiler.
        - pass_sm_arch_name (str): The keyword name of the SM.
        - device_compilation_only (bool) : Only device code, and call it via cuda driver
        - preprocess (bool): Enable AST transformation.

        This constructs a DSL instance and sets up environment management,
        warning configurations, and logging functionalities. It reads
        environment variables using `EnvironmentVarManager` and configures
        a logger with settings from the environment. If environment warnings
        are detected, they are escalated to errors to ensure strict handling.

        Side effects: mutates process-global state (warnings filters,
        ``sys.excepthook``) and registers atexit handlers.
        """
        # Enforcing initialization of instance variables
        if not all([name, compiler_provider, pass_sm_arch_name]):
            raise DSLRuntimeError(
                "All required parameters must be provided and non-empty"
            )

        self.name = name
        self.compiler_provider = compiler_provider
        self.pass_sm_arch_name = pass_sm_arch_name
        # Caller frame captured by the jit/kernel decorators; only valid while
        # the preprocessor materializes a function.
        self.frame = None
        self.no_cache = False
        self.device_compilation_only = device_compilation_only
        self.num_kernels = 0
        # Read environment variables
        self.envar = EnvironmentVarManager(self.name)
        self.enable_preprocessor = preprocess
        # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default
        self.jit_cache = (
            dict()
            if self.envar.disable_file_caching
            else load_cache_from_path(self.name, self.envar.file_caching_capacity)
        )
        self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}"
        self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"

        # set warning
        # NOTE: these filters are process-wide, not instance-local.
        if not self.envar.enable_optimization_warnings:
            # By default, optimization warnings are disabled
            warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
        if self.envar.warnings_as_errors:
            warnings.filterwarnings("error")
        if self.envar.warnings_ignore:
            warnings.filterwarnings("ignore")

        # Initialize logger
        # JIT time profiling needs console logging at info level, so force it on.
        if self.envar.log_to_console == False and self.envar.jitTimeProfiling:
            self.envar.log_to_console = True
            self.envar.log_level = 20  # info level
        setup_log(
            self.name,
            self.envar.log_to_console,
            self.envar.log_to_file,
            f"{self.name}.log",
            self.envar.log_level,
        )

        # kernel symbols are temporary symbol string variables, their values are valid until the compilation is done.
        self.kernel_symbols = []
        # used to generate unique name for gpu.launch
        self.launch_inner_count = 0
        # initialize default compile options
        self.compile_options = CompileOptions()

        if preprocess:
            self.preprocessor = DSLPreprocessor(dsl_package_name)
        log().info(f"Initializing {name} DSL")
        log().debug(f"Logger initialized for {self.name}")

        # Hook excepthook so DSL-internal frames can be filtered out of
        # user-facing tracebacks.
        if self.envar.filterStacktrace:
            origin_excepthook = sys.excepthook
            module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__)))

            def excepthook(excep_type, value, traceback):
                filter_exception(value, module_dir)
                if hasattr(value, "__traceback__"):
                    origin_excepthook(excep_type, value, value.__traceback__)
                else:
                    origin_excepthook(
                        excep_type, value, filter_stackframe(traceback, module_dir)
                    )

            sys.excepthook = excepthook

            # Restore original excepthook at interpreter shutdown.
            def restore_excepthook(hook):
                sys.excepthook = hook

            atexit.register(restore_excepthook, origin_excepthook)
|
| 392 |
+
|
| 393 |
+
def dump_cache(self):
|
| 394 |
+
if not self.envar.disable_file_caching:
|
| 395 |
+
dump_cache_to_path(
|
| 396 |
+
self.name, self.jit_cache, self.envar.file_caching_capacity
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
@lru_cache(maxsize=1)
|
| 400 |
+
def print_warning_once(self, message):
|
| 401 |
+
log().warning(f"Warning: {message}")
|
| 402 |
+
warnings.warn(message, UserWarning)
|
| 403 |
+
|
| 404 |
+
    def print_warning(self, message):
        """Log *message* and also emit it as a UserWarning (every call)."""
        log().warning(f"Warning: {message}")
        warnings.warn(message, UserWarning)
|
| 407 |
+
|
| 408 |
+
    @classmethod
    @lru_cache(maxsize=1)
    def _get_dsl(cls):
        """Return the lazily-created DSL singleton for this class.

        ``lru_cache`` keys on ``cls``, so each concrete DSL subclass gets its
        own instance. NOTE(review): with ``maxsize=1``, alternating calls from
        two different subclasses would evict and re-instantiate — confirm only
        one DSL subclass is active per process.
        """
        # Instantiate the DSL Class once
        main_dsl = cls()
        if not main_dsl.no_cache:
            # register atexit callback so the JIT cache is flushed at shutdown
            atexit.register(main_dsl.dump_cache)
        return main_dsl
|
| 417 |
+
|
| 418 |
+
@staticmethod
|
| 419 |
+
def _can_preprocess(**dkwargs):
|
| 420 |
+
"""
|
| 421 |
+
Check if AST transformation is enabled or not for `jit` and `kernel` decorators.
|
| 422 |
+
"""
|
| 423 |
+
return dkwargs.pop("preprocess", True)
|
| 424 |
+
|
| 425 |
+
    @staticmethod
    def _get_original_function(fcn_ptr, name):
        """
        Get the original function from the decorated function.

        Unwraps decorator layers until a function named *name* is found.

        Raises:
            DSLRuntimeError: if no layer in the chain exposes the original.
        """
        while fcn_ptr.__name__ != name:
            # If the function is wrapped with functools, get from __wrapped__
            if hasattr(fcn_ptr, "__wrapped__"):
                fcn_ptr = fcn_ptr.__wrapped__
            # If the function is wrapped manually, it's the first in closure.
            # NOTE(review): assumes __closure__ is non-None with at least one
            # cell here — a wrapper without a closure would raise TypeError
            # instead of the DSLRuntimeError below; confirm acceptable.
            elif callable(fcn_ptr.__closure__[0].cell_contents):
                fcn_ptr = fcn_ptr.__closure__[0].cell_contents
            else:
                raise DSLRuntimeError(
                    f"Cannot find the original function {name} in the closure chain"
                )
        return fcn_ptr
|
| 442 |
+
|
| 443 |
+
    @staticmethod
    def _preprocess_and_execute(func):
        """
        Run ast transformation and return the materialized function pointer.

        `_transformed_ast` is attached by `jit_runner` only when preprocessing
        is enabled; its absence means the function is used as-is.
        """
        if hasattr(func, "_transformed_ast"):
            # Expose the decoration-site frame so name resolution during
            # materialization sees the caller's globals/locals.
            func._dsl_object.frame = func._decorator_frame
            if func._transformed_ast is None:
                # First call: run the preprocessor and memoize the result.
                func._transformed_ast = func._dsl_object.run_preprocessor(func)
                if func._transformed_ast is None:
                    # Preprocessing declined/failed: strip the marker so later
                    # calls skip this path entirely and run the original.
                    del func._transformed_ast
                    func._dsl_object.frame = None
                    return func

            fcn_ptr = func._dsl_object.get_function_ptr(func)
            # If the function is decorated, de-decorate it
            fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
            func._dsl_object.frame = None
            # Wrap as single-shot: the materialized pointer is only valid for
            # this invocation.
            return DSLCallable(fcn_ptr)
        return func
|
| 464 |
+
|
| 465 |
+
    def jit_runner(self, executor, frame, *dargs, **dkwargs):
        """
        Decorator to mark a function for JIT compilation.

        Parameters:
        - executor: callable invoked with the (possibly preprocessed) function
          plus the user's call arguments (e.g. `_func` or `_kernel_helper`).
        - frame: the caller's frame captured at decoration time, needed later
          by the AST preprocessor for name resolution.
        - dargs/dkwargs: decorator arguments; supports both bare `@jit` and
          parameterized `@jit(...)` usage.
        """
        log().info("jit_runner")

        def jit_runner_decorator(func):
            func._dsl_object = self
            # Run preprocessor that alters AST
            if self.enable_preprocessor and BaseDSL._can_preprocess(**dkwargs):
                # For an annotated function, add some DSL attributes
                # When materializing the AST, we need decorator's frame
                func._decorator_frame = frame
                # No transformed ast at this point; materialized lazily on
                # first call inside _preprocess_and_execute.
                func._transformed_ast = None

            @wraps(func)
            def jit_wrapper(*args, **kwargs):
                func_ptr = BaseDSL._preprocess_and_execute(func)
                return executor(func_ptr, *args, **kwargs)

            return jit_wrapper

        # Bare-decorator form: @jit with no parentheses passes the function
        # itself as the single positional argument.
        if len(dargs) == 1 and callable(dargs[0]):
            return jit_runner_decorator(dargs[0])
        else:
            return jit_runner_decorator
|
| 492 |
+
|
| 493 |
+
@classmethod
|
| 494 |
+
def jit(cls, *dargs, **dkwargs):
|
| 495 |
+
"""
|
| 496 |
+
Decorator to mark a function for JIT compilation for Host code.
|
| 497 |
+
"""
|
| 498 |
+
frame = inspect.currentframe().f_back
|
| 499 |
+
# Instantiate the DSL Class
|
| 500 |
+
main_dsl = cls._get_dsl()
|
| 501 |
+
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
|
| 502 |
+
|
| 503 |
+
@classmethod
|
| 504 |
+
def kernel(cls, *dargs, **dkwargs):
|
| 505 |
+
"""
|
| 506 |
+
Decorator to mark a function for JIT compilation for GPU.
|
| 507 |
+
"""
|
| 508 |
+
frame = inspect.currentframe().f_back
|
| 509 |
+
# Instantiate the DSL Class
|
| 510 |
+
main_dsl = cls._get_dsl()
|
| 511 |
+
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
|
| 512 |
+
|
| 513 |
+
    @abstractmethod
    def _kernel_helper(self, func, *args, **kwargs):
        """
        Helper function to handle kernel generation logic.

        Subclasses implement the dialect-specific lowering of a decorated
        kernel function; used as the executor for the `kernel` decorator.
        """
        pass

    @abstractmethod
    def _build_gpu_module(self, attrs):
        """
        Build the module op that contains the kernels.

        `attrs` carries module-level attributes the subclass should attach.
        """
        pass
|
| 526 |
+
|
| 527 |
+
@abstractmethod
|
| 528 |
+
def _get_pipeline(self, pipeline):
|
| 529 |
+
"""
|
| 530 |
+
Get the pipeline from the other configuration options.
|
| 531 |
+
"""
|
| 532 |
+
if pipeline != None:
|
| 533 |
+
return pipeline
|
| 534 |
+
return None
|
| 535 |
+
|
| 536 |
+
@staticmethod
|
| 537 |
+
def log_additions(func_type, operands=None, types=None, arg_attrs=None):
|
| 538 |
+
if operands is not None and operands != []:
|
| 539 |
+
log().debug(
|
| 540 |
+
f"Added {func_type} operands: [%s]", ", ".join(map(str, operands))
|
| 541 |
+
)
|
| 542 |
+
if types is not None:
|
| 543 |
+
log().debug(
|
| 544 |
+
f"Added {func_type} arg_types: [%s]", ", ".join(map(str, types))
|
| 545 |
+
)
|
| 546 |
+
if arg_attrs is not None:
|
| 547 |
+
log().debug(
|
| 548 |
+
f"Added {func_type} arg_attrs: [%s]", ", ".join(map(str, arg_attrs))
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
    def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec):
        """Does simple name mangling.

        Appends a stringified form of each *compile-time* (constexpr-like)
        argument to *function_name* so distinct specializations get distinct
        symbols, then scrubs characters that are unsafe in symbol/file names.
        Dynamic IR-backed arguments are skipped — they don't affect the
        generated code's identity.
        """

        for spec_arg, arg in zip(args_spec.args, args):
            spec_ty = args_spec.annotations.get(spec_arg, None)
            if spec_ty != None:
                # Annotated as an IR-level value: runtime-only, skip.
                if issubclass(type(spec_ty), (t.IRValue, t.IRVariadic)):
                    continue
                if isinstance(spec_ty, (ir.Type, ir.Value)):
                    continue
            # The argument itself is an MLIR entity: runtime-only, skip.
            if isinstance(arg, (ir.Type, ir.Value, ir.OpResult)):
                continue
            if isinstance(type(arg), (ir.Type, ir.Value, ir.OpResult)):
                continue
            if self._is_tensor_descriptor(arg):
                continue
            if inspect.isclass(spec_ty):
                # Class-typed argument: embed a compacted class name.
                class_name = str(arg).replace("class", "")
                class_name = class_name.replace(" ", "")
                function_name = f"{function_name}_{class_name}"
            elif isinstance(arg, (list, tuple)):
                function_name = f"{function_name}_{'_'.join(map(str, arg))}"
            else:
                function_name = f"{function_name}_{arg}"
        # we would need a dedicated MR to follow up
        unwanted_chars = r"'-![]#,.<>()\":{}=%?@;"
        translation_table = str.maketrans("", "", unwanted_chars)
        function_name = function_name.translate(translation_table)
        # identify address and drop (object reprs embed id()s, which would
        # defeat caching by making every mangled name unique)
        function_name = re.sub(r"0x[a-f0-9]{8,16}", "", function_name)
        function_name = re.sub(r"\s+", " ", function_name)
        function_name = function_name.replace(" ", "_")
        function_name = function_name.replace("\n", "_")
        # max fname is 256 character, leave space
        function_name = function_name[:180]
        log().info(f"Final mangled function name: {function_name}")
        return function_name
|
| 588 |
+
|
| 589 |
+
    def _generate_execution_arguments_for_known_types(
        self, arg, arg_spec, arg_name, i, fop_args, iv_block_args
    ):
        """
        Generate MLIR arguments for known types.

        Sub-DSLs can override this method to handle types that are not
        natively supported by the Base DSL.

        Returns a (list-of-generated-args, updated-block-arg-index) pair; an
        empty list signals "not a known type" to the caller.
        """
        ir_arg = []
        # NOTE(review): `func` is not a parameter here — it resolves to the
        # module-level `func` *dialect* import. Confirm is_argument_constexpr
        # really expects that object (looks like a latent bug).
        if is_argument_constexpr(arg, arg_spec, arg_name, i, func):
            # Constexpr arguments are passed through as plain Python values.
            ir_arg.append(arg)

        return ir_arg, iv_block_args
|
| 603 |
+
|
| 604 |
+
    def generate_execution_arguments(
        self,
        args,
        kwargs,
        fop,
        args_spec: inspect.FullArgSpec,
    ):
        """Create list of arguments that will be passed to MLIR's func.func op.

        Maps each user-supplied Python argument onto the block arguments of
        *fop*'s entry block: constexpr arguments pass through unchanged, while
        dynamic arguments are rebuilt around consecutive slices of the block
        arguments (via `new_from_mlir_values`).

        Returns (positional ir_args, keyword-only ir_kwargs dict).
        """

        def gen_exec_args(input_args, arg_names, annotations, fop_args):
            assert len(input_args) == len(arg_names)

            ir_args = []
            # Running index into fop_args; each dynamic argument consumes as
            # many block arguments as get_mlir_types reports for it.
            iv_block_args = 0
            for i, arg in enumerate(input_args):
                arg_name = arg_names[i]
                arg_spec = annotations.get(arg_name, None)
                log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec)

                # Implicit cast to NumericMeta
                if isinstance(arg_spec, t.NumericMeta) and not isinstance(
                    arg, arg_spec
                ):
                    arg = t.cast(arg, arg_spec)

                ir_arg, iv_block_args = (
                    self._generate_execution_arguments_for_known_types(
                        arg, arg_spec, arg_name, i, fop_args, iv_block_args
                    )
                )

                if not ir_arg:
                    # If it's not a known type, try JIT argument adapter
                    # to convert the argument if possible
                    adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
                    arg = adapter(arg) if adapter else arg

                    # Rebuild the argument around its slice of block args.
                    n_args = len(get_mlir_types(arg))
                    blk_args = fop_args[iv_block_args : iv_block_args + n_args]
                    ir_arg.append(new_from_mlir_values(arg, blk_args))
                    iv_block_args += n_args

                # NOTE(review): ir_arg is passed as log_additions' *func_type*
                # positional parameter (operands/types stay None), so this call
                # logs nothing — confirm intended.
                self.log_additions(ir_arg)
                ir_args.extend(ir_arg)

            return ir_args, iv_block_args

        fop_args = list(fop.regions[0].blocks[0].arguments)
        ir_args, iv_block_args = gen_exec_args(
            args, args_spec.args, args_spec.annotations, fop_args
        )
        # Keyword-only arguments consume the block arguments that remain after
        # the positional pass.
        ir_kwargs, _ = gen_exec_args(
            [kwargs[arg] for arg in args_spec.kwonlyargs],
            args_spec.kwonlyargs,
            args_spec.annotations,
            fop_args[iv_block_args:],
        )
        ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)}

        log().debug("execution args: %s", ", ".join(map(str, ir_args)))
        log().debug("execution kwargs: %s", ", ".join(map(str, ir_kwargs)))
        return ir_args, ir_kwargs
|
| 666 |
+
|
| 667 |
+
    @abstractmethod
    def _generate_mlir_type_for_tensor_descriptor(self, tensor):
        """
        Generate MLIR type for the tensor descriptor.

        Subclasses map their tensor-descriptor representation to an MLIR type.
        """
        pass

    @abstractmethod
    def _generate_executable_arg_for_tensor_descriptor(
        self, mlir_value=None, ptr_tensor_ty=None, tensor=None
    ):
        """
        Generates executable value for the given tensor descriptor.
        """
        pass
|
| 682 |
+
|
| 683 |
+
def _get_globals(self):
    """Merge the caller frame's globals and locals into one dictionary.

    ``self.frame`` holds the caller's frame (captured elsewhere).  Locals
    shadow globals on key collisions.  The AST preprocessor executes the
    regenerated Python code against this combined namespace.

    Returns an empty dict when no frame was captured.
    """
    if not self.frame:
        return {}
    # Frame locals take precedence over module globals.
    namespace = dict(self.frame.f_globals)
    namespace.update(self.frame.f_locals)
    return namespace
|
| 700 |
+
|
| 701 |
+
@abstractmethod
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
    """Abstract hook: report whether the value is a tensor descriptor."""
    pass
|
| 704 |
+
|
| 705 |
+
@abstractmethod
def _handle_tensor_descriptor(
    self, maybe_tensor, arg_name: str, need_gpu_memory: bool
) -> Any:
    """Abstract hook: process a tensor-descriptor argument for JIT use."""
    pass
|
| 710 |
+
|
| 711 |
+
def _validate_arg(self, arg, arg_index, arg_name, arg_spec):
    """Type-safety hook: check ``arg`` against its annotated type.

    The base implementation performs no validation and returns ``None``
    (success).  Subclasses override this and return an error derived from
    ``DSLBaseError`` when validation fails; callers raise that error.
    """
    pass
|
| 719 |
+
|
| 720 |
+
def _generate_jit_func_args_for_known_types(
    self,
    func,
    arg,
    arg_name,
    arg_spec,
    arg_index,
    *,
    is_host=True,
):
    """
    Generate JIT function arguments for known types.

    Sub-DSLs can override this method to handle types that are not
    natively supported by the Base DSL.

    Returns a ``(jit_exec_arg, jit_arg_type, jit_arg_attr)`` triple:
    empty lists by default (no known type matched here), or
    ``(None, None, None)`` when the argument is a compile-time constant
    (Constexpr) and must not become a runtime JIT argument.
    """
    jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], []

    # FIX: removed `default_attr = ir.DictAttr.get({})` — it was never
    # used in this method (dead local; attribute defaults are handled by
    # the caller, see _generate_jit_func_args).
    if is_argument_constexpr(arg, arg_spec, arg_name, arg_index, func):
        # Constexpr arguments are baked into the generated IR; signal the
        # caller to skip runtime conversion by returning None in all slots.
        jit_exec_arg = jit_arg_type = jit_arg_attr = None

    return jit_exec_arg, jit_arg_type, jit_arg_attr
|
| 744 |
+
|
| 745 |
+
def _generate_jit_func_args(
    self,
    func,
    function_name,
    args,
    kwargs,
    args_spec: inspect.FullArgSpec,
    *,
    is_host=True,
):
    """Generate JIT function arguments.

    Converts each call-site argument into its executable form, MLIR type(s)
    and argument attributes.  ``is_host`` selects the conversion protocol:
    C pointers + MLIR types on the host path, extracted MLIR values on the
    device path.

    Returns a 4-tuple ``(jit_exec_args, jit_arg_types, jit_arg_attrs,
    jit_adapted_args)``; the first three lists stay index-aligned, the last
    collects arguments converted through a registered JIT adapter.

    Raises DSLRuntimeError when an argument cannot be converted to a
    dynamic expression.
    """

    # args/kwargs were canonicalized upstream, so they must line up 1:1
    # with the positional and keyword-only names of the spec.
    assert len(args) == len(args_spec.args) and len(kwargs) == len(
        args_spec.kwonlyargs
    ), (
        f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
        f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
    )

    jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
    jit_adapted_args = []
    # Empty attribute dict used for every generated MLIR argument.
    default_attr = ir.DictAttr.get({})

    # Flatten positional + keyword-only into one ordered stream.
    input_args = [*args, *kwargs.values()]
    input_arg_names = [*args_spec.args, *args_spec.kwonlyargs]
    for i, (arg_name, arg) in enumerate(zip(input_arg_names, input_args)):
        spec_ty = args_spec.annotations.get(arg_name, None)
        log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty)

        # Implicitly convert into Numeric type if possible
        if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty):
            arg = t.cast(arg, spec_ty)

        # Type safety check
        if spec_ty is not None:
            err = self._validate_arg(arg, i, arg_name, spec_ty)
            if err is not None:
                raise err

        # First chance: types the (sub-)DSL knows natively.  Returns
        # (None, None, None) for Constexpr args, empty lists when unknown.
        jit_exec_arg, jit_arg_type, jit_arg_attr = (
            self._generate_jit_func_args_for_known_types(
                func,
                arg,
                arg_name,
                spec_ty,
                i,
                is_host=is_host,
            )
        )

        if jit_arg_type is not None and len(jit_arg_type) == 0:
            # If not any known type, try JIT argument adapter
            # to convert the argument
            adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
            if adapter:
                arg = adapter(arg)
                jit_adapted_args.append(arg)

            if is_host:
                # Host path: pass C pointers, typed by the MLIR signature.
                jit_exec_arg.extend(get_c_pointers(arg))
                jit_arg_type.extend(get_mlir_types(arg))
            else:
                # Device path: pass the argument's dynamic MLIR values.
                dyn_vals = extract_mlir_values(arg)
                jit_exec_arg.extend(dyn_vals)
                jit_arg_type.extend([v.type for v in dyn_vals])

            if not jit_arg_type or not jit_exec_arg:
                # Zero values is legitimate only for types that opt into the
                # conversion protocols explicitly (e.g. fully-static values).
                if (is_host and hasattr(arg, "__c_pointers__")) or (
                    not is_host
                    and hasattr(arg, "__extract_mlir_values__")
                    and hasattr(arg, "__new_from_mlir_values__")
                ):
                    pass
                else:
                    raise DSLRuntimeError(
                        f"failed to generate argument #{i+1} ({arg_name}) for JIT function '{function_name}'.",
                        context={
                            f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.",
                            f"Call-site argument value": arg,
                            f"Call-site argument type": type(arg),
                        },
                        suggestion=f"Consider annotating the argument with `{arg_name} : Constexpr` "
                        "if it's a value known at compile-time. "
                        f"Otherwise, implement the {'`JitArgument`' if is_host else '`DynamicExpression`'} "
                        f"protocol or register a custom JIT argument adapter for type `{type(arg)}` to "
                        "enable dynamic value conversion at runtime.",
                    )

            # One (empty) attribute dict per generated MLIR argument.
            jit_arg_attr.extend([default_attr] * len(jit_arg_type))

        # None means Constexpr: contributes nothing to the JIT signature.
        if jit_arg_type is not None:
            jit_exec_args.extend(jit_exec_arg)
            jit_arg_types.extend(jit_arg_type)
            jit_arg_attrs.extend(jit_arg_attr)

    return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args
|
| 841 |
+
|
| 842 |
+
def generate_mlir_function_types(
    self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec
):
    """Convert input arguments to MLIR function signature also convert numpy arrays to memref."""
    # Host-side conversion: executable pointers + MLIR types per argument.
    exe_args, types, _attrs, adapted_args = self._generate_jit_func_args(
        func, function_name, input_args, kwargs, args_spec, is_host=True
    )

    log().debug("Execution Arguments: %s", ", ".join(map(str, exe_args)))
    log().debug("Types: %s", ", ".join(map(str, types)))

    # The two lists must stay index-aligned for the call to be well-formed.
    assert len(exe_args) == len(types), "expects the same number of arguments and function parameters"

    return exe_args, types, adapted_args
|
| 859 |
+
|
| 860 |
+
@dataclass
class LaunchConfig:
    """Kernel launch configuration: grid/block/cluster shape, shared memory
    size, and async dependencies.  Validated/normalized in __post_init__."""

    cluster: list = None
    grid: list = field(default_factory=lambda: [1, 1, 1])
    block: list = field(default_factory=lambda: [1, 1, 1])
    smem: int = None
    async_deps: list = field(default_factory=list)
    has_cluster: bool = False
    min_blocks_per_mp: int = 0
    auto_smem: bool = False

    def __post_init__(self):
        # Grid and block must always be 3-dimensional.
        if len(self.grid) != 3:
            raise DSLRuntimeError("Expect 3d grid!")
        if len(self.block) != 3:
            raise DSLRuntimeError("Expect 3d block!")

        # Unset smem means "size shared memory automatically".
        if self.smem is None:
            self.smem = 0
            self.auto_smem = True

        # A cluster was requested iff the caller passed one explicitly;
        # normalize the unset case to a 3-element placeholder.
        self.has_cluster = self.cluster is not None
        if self.cluster is None:
            self.cluster = [None, None, None]
        elif len(self.cluster) != 3:
            raise DSLRuntimeError("Expect 3d cluster!")
|
| 886 |
+
|
| 887 |
+
def diagnostic(self):
    """Check command line parameters and enables diagnostic.

    Parses an optional ``-diagnostic`` flag from ``sys.argv`` and, when
    present, turns on MLIR error diagnostics and global debug printing.
    A diagnostic handler that prints to stdout is attached to the current
    MLIR context unconditionally.
    """
    # Check command line arguments "-diagnostic"
    # parse_known_args tolerates unrelated application flags in argv.
    parser = argparse.ArgumentParser(description="Process diagnostic status.")
    parser.add_argument(
        "-diagnostic",
        nargs="?",
        const="all",
        choices=["all", "fail", "success", "info", "suggestion"],
        help="Set diagnostic status (fail, success, info, suggestion).",
    )

    args, _ = parser.parse_known_args()
    ctx = ir.Context.current

    def callback(d):
        # Prints every MLIR diagnostic routed through this context.
        print(f" [{self.name} Diagnostic] : {d.message}")

    # NOTE: the handler is attached even when the flag is absent.
    ctx.attach_diagnostic_handler(callback)

    # Early return, don't enable diagnostics
    if args.diagnostic is None:
        return

    # Enable MLIR Flags (process-global debug state).
    ctx.emit_error_diagnostics = True
    ir._GlobalDebug.flag = True
    if args.diagnostic == "all":
        ir._GlobalDebug.set_types("diagnostic")
    else:
        ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}")
|
| 918 |
+
|
| 919 |
+
def get_location(self):
    """Derive an MLIR location (name + file/line) from the captured caller
    frame, or return None when no frame is available."""
    frame = self.frame
    if frame is None:
        log().debug("Frame is None")
        return None

    # File location first, then wrap it with the caller's function name.
    file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
    return ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
|
| 934 |
+
|
| 935 |
+
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
    """
    Compile and JIT an MLIR module.

    Runs the compiler provider with stdout/stderr temporarily captured so
    diagnostic chatter can be replayed after the streams are restored.
    Any failure is wrapped in a DSLRuntimeError (internal compiler error).
    """

    try:
        self.diagnostic()

        # Capture both streams for the duration of the compile; the
        # originals are restored in the inner `finally` below.
        orig_stdout = sys.stdout
        orig_stderr = sys.stderr
        sys.stderr = redirect_stderr = io.StringIO()
        sys.stdout = redirect_stdout = io.StringIO()

        try:
            kernel = self.compiler_provider.compile_and_jit(
                module,
                pipeline,
                shared_libs=shared_libs,
                cuda_toolkit=self.envar.cuda_toolkit,
                arch=self.envar.arch,
            )

        finally:
            # Always restore the real streams and switch global debug off,
            # even when compilation raises.
            sys.stdout = orig_stdout
            sys.stderr = orig_stderr
            ir._GlobalDebug.flag = False

        # Print captured output.
        # NOTE(review): on a compile failure these replays are skipped, so
        # captured output is dropped — presumably acceptable; confirm.
        print(redirect_stdout.getvalue(), file=sys.stdout, end="")
        print(redirect_stderr.getvalue(), file=sys.stderr, end="")

        return kernel

    except Exception as e:
        # Any error here is an Internal Compiler Error from the user's view.
        raise DSLRuntimeError("🧊🧊🧊 ICE 🧊🧊🧊", cause=e)
    finally:
        pass
|
| 972 |
+
|
| 973 |
+
def preprocess_pipeline(self, pipeline, arch) -> str:
    """Inject environment-derived options (toolkit path, SM arch) into the
    pass-pipeline string and return the updated pipeline."""

    if self.envar.cuda_toolkit is None:
        self.print_warning(
            "CUDA_TOOLKIT_PATH environment variable is not set. Cannot set toolkitPath."
        )

    options = {
        "toolkitPath": self.envar.cuda_toolkit if self.envar.cuda_toolkit else None,
        self.pass_sm_arch_name: arch,
    }

    # Render only the options that actually have a (truthy) value.
    opt_str = ""
    for k, v in options.items():
        if v:
            opt_str += f"{k}={v} "

    if opt_str:
        # Automatically append the pipeline options if any is specified through env var
        # NOTE(review): the regexes are greedy — with multiple `{...}`
        # groups in the pipeline this merges/replaces across them; confirm
        # pipelines only ever carry a single option group.
        pattern = re.compile(r"{(.+)}")
        match = pattern.search(pipeline)
        if match:
            opt_str = f"{{{match[1]} {opt_str}}}"
            pipeline = re.sub(r"{.+}", opt_str, pipeline)
        else:
            # NOTE(review): rstrip(")") removes ALL trailing parens, not
            # just one — relies on the pipeline ending in a single ")".
            pipeline = pipeline.rstrip(")") + f"{{{opt_str}}})"
    log().debug(f"Using pipeline = {pipeline}")
    return pipeline
|
| 1001 |
+
|
| 1002 |
+
def get_shared_libs(self) -> list:
    """Resolve the colon-separated support-library paths from the
    environment.

    Returns the list of existing library paths; warns and returns an empty
    list when the env var is unset.  Raises FileNotFoundError for any path
    that does not exist.
    """
    support_libs = self.envar.shared_libs
    if support_libs is None:
        self.print_warning(f"{self.name}_LIBS environment variable is not set")
        return []

    shared_libs = []
    for lib in support_libs.split(":"):
        if not os.path.exists(lib):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), lib)
        shared_libs.append(lib)
    return shared_libs
|
| 1017 |
+
|
| 1018 |
+
@lru_cache(maxsize=1)
def get_version(self):
    """Return the (cached) base hasher used to seed module hashes."""
    # NOTE(review): the digest is currently empty — presumably a placeholder
    # where DSL version/build info should be folded in; confirm upstream.
    # Caching makes every call return the same hasher object; callers
    # .copy() it before updating (see get_module_hash).
    digest = hashlib.sha256()
    return digest
|
| 1023 |
+
|
| 1024 |
+
def get_module_hash(self, module, function_name):
    """Compute a stable hash for the compiled module.

    The hash covers the module's MLIR bytecode, every non-None environment
    variable value, and the compile options, layered on top of the version
    hasher from get_version().  Used as the JIT-cache key.
    """
    s = io.BytesIO()
    module.operation.write_bytecode(s)
    # FIX: iterate .values() — the attribute names were never used
    # (dict order is insertion order, so the hash input is unchanged).
    for value in self.envar.__dict__.values():
        if value is not None:
            s.write(str(value).encode())
    # Add compile options to the hash
    s.write(self.compile_options.to_str().encode())
    module_hash = self.get_version().copy()
    module_hash.update(s.getvalue())
    module_hash = module_hash.hexdigest()

    log().debug("Bytecode=[%s]", s.getvalue().hex())
    log().debug("Version=[%s]", self.get_version().hexdigest())
    log().info(
        "Function=[%s] Computed module_hash=[%s]", function_name, module_hash
    )
    return module_hash
|
| 1042 |
+
|
| 1043 |
+
def build_module(self, module, function_name: str):
    """
    Build the MLIR module, verify and return the module.

    Optionally persists and/or prints the IR depending on env flags;
    raises DSLRuntimeError when verification fails.
    """

    # Save IR in a file for offline inspection.
    if self.envar.keepIR:
        save_ir(self.name, module, function_name)

    # Echo the generated IR when requested.
    if self.envar.printIR:
        print("\n//===--- ------ Generated IR ------ ---====\n")
        module.operation.print(enable_debug_info=self.envar.generate_source_location)
        print("\n//===--- --- End of Generated IR -- ---====\n")

    # Verify the module; a failure here is an internal compiler error.
    try:
        module.operation.verify()
    except Exception as e:
        raise DSLRuntimeError("🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊", cause=e)

    return module
|
| 1066 |
+
|
| 1067 |
+
def generate_original_ir(
    self,
    ir,
    func,
    funcBody,
    kwargs,
    function_name,
    func_types,
    gpu_module_attrs,
    args,
    args_spec,
):
    """Build the original (pre-compilation) IR module for ``funcBody``.

    Creates the top-level module with a GPU container module and a
    ``func.func``, then traces the user function body inside its entry
    block.  Returns ``(module, module_hash, result)`` where ``result`` is
    whatever the user function returned during tracing.
    """
    # This location is set to None for now; otherwise, calls to the same
    # function on different lines would produce different line numbers,
    # which would break the cache.
    loc = None  # self.get_location()

    def build_ir_module():
        # One fresh module per trace; marked as a GPU container so kernels
        # can be nested inside.
        module = ir.Module.create(loc=loc)
        unit_attr = ir.UnitAttr.get()
        module.operation.attributes["gpu.container_module"] = unit_attr

        with ir.InsertionPoint(module.body):
            # Always generate gpu module. It's canonicalized by the compiler when it's not used.
            self._build_gpu_module(gpu_module_attrs)

            fop = func.FuncOp(function_name, (func_types, []), loc=loc)
            # Needed so the JIT engine can call the function through the
            # C interface wrapper.
            fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            log().debug("Generated Function OP [%s]", fop)
            with ir.InsertionPoint(fop.add_entry_block()):
                # Map the Python call-site arguments onto the entry block's
                # MLIR block arguments.
                ir_args, ir_kwargs = self.generate_execution_arguments(
                    args, kwargs, fop, args_spec
                )
                # Call user function body
                try:
                    result = funcBody(*ir_args, **ir_kwargs)
                    func.ReturnOp([])
                except NameError as name_error:
                    # Typical cause: a variable first assigned inside traced
                    # (dynamic) control flow, which has no initial value.
                    raise DSLRuntimeError(
                        f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥",
                        cause=name_error,
                        suggestion="Using variables defined in dynamic control flow is not supported. Please give an initial value before control flow.",
                    )
                except DSLRuntimeError as dsl_error:
                    # Throw it's already a DSL error
                    raise dsl_error
        return module, result

    # Build IR module (optionally timed).
    profiler = timer(enable=self.envar.jitTimeProfiling)
    module, result = profiler(build_ir_module)()
    # Hash is computed on the original module, before build_module's
    # side effects (save/print/verify).
    module_hash = self.get_module_hash(module, function_name)

    module = self.build_module(module, function_name)

    return module, module_hash, result
|
| 1123 |
+
|
| 1124 |
+
def compile_and_cache(
    self, module, module_hash, function_name, pipeline, args_spec, no_cache
):
    """Compile ``module`` (or reuse a cached compiled module) and wrap the
    result in a JitExecutor.

    On a cache miss (or when caching is disabled) the module is compiled
    and JITed; on an in-file cache hit the cached IR module is re-JITed.
    The resulting executor is stored back into the cache unless
    ``no_cache`` is set.
    """
    arch = self.envar.arch
    pipeline = self.preprocess_pipeline(self._get_pipeline(pipeline), arch)
    shared_libs = self.get_shared_libs()
    profiler = timer(enable=self.envar.jitTimeProfiling)
    if (
        no_cache
        or module_hash not in self.jit_cache
        or self.jit_cache[module_hash].ir_module is None
    ):
        log().info(
            "JIT cache miss function=[%s] module_hash=[%s]",
            function_name,
            module_hash,
        )
        # Compile and JIT MLIR module
        engine = profiler(self.compile_and_jit)(
            module, pipeline, shared_libs, function_name=function_name
        )
    else:
        log().info(
            "JIT cache hit IN-FILE function=[%s] module_hash=[%s]",
            function_name,
            module_hash,
        )
        # Reuse the already-compiled module; only the JIT engine is rebuilt.
        module = self.jit_cache[module_hash].ir_module
        engine = self.compiler_provider.jit(module, shared_libs=shared_libs)
    # Resolve the entry point symbol in the JITed module.
    capi_func = profiler(engine.lookup)(function_name)
    jit_executor = JitExecutor(
        self,
        engine,
        capi_func,
        module,
        args_spec,
        function_name,
        jit_time_profiling=self.envar.jitTimeProfiling,
    )
    # Attach the CUDA modules of kernels generated during this compilation.
    jit_executor = jit_executor.update_jit_cuda_modules(self.kernel_symbols)

    if not no_cache:
        # module stored in cache is compiled.
        self.jit_cache[module_hash] = jit_executor

    return jit_executor
|
| 1170 |
+
|
| 1171 |
+
def post_compilation_cleanup(self):
    """Clean up some internal state after one compilation is completed."""
    # Reset per-compilation bookkeeping so the next JIT invocation starts
    # from a pristine state.
    self.kernel_symbols = []
    self.launch_inner_count = 0
    self.num_kernels = 0
    self.compile_options = CompileOptions()
|
| 1180 |
+
|
| 1181 |
+
def generate_mlir(
    self,
    funcBody,
    kwargs,
    function_name,
    gpu_module_attrs,
    args,
    args_spec,
    pipeline,
    no_cache,
    compile_only,
    loc=None,
):
    """Generate the MLIR module for ``funcBody``, compile it via the
    compiler provider (with JIT caching), and run it unless
    ``compile_only`` or dryrun is requested."""
    with ir.Context(), ir.Location.unknown():
        # Convert input arguments to MLIR arguments
        exe_args, func_types, adapted_args = self.generate_mlir_function_types(
            funcBody, function_name, args, kwargs, args_spec
        )

        # Generate original ir module and its hash value.
        module, module_hash, result = self.generate_original_ir(
            ir,
            func,
            funcBody,
            kwargs,
            function_name,
            func_types,
            gpu_module_attrs,
            args,
            args_spec,
        )

        # dryrun is used to only generate IR
        if self.envar.dryrun:
            return result

        # In-memory cache lookup keyed by the module hash; a cached entry
        # without a resolved capi_func still needs (re)compilation.
        if (
            no_cache
            or module_hash not in self.jit_cache
            or self.jit_cache[module_hash].capi_func is None
        ):
            # no cache or cache miss, do ir generation/compilation/jit engine
            jit_executor = self.compile_and_cache(
                module, module_hash, function_name, pipeline, args_spec, no_cache
            )
        else:
            # cache hit
            log().info(
                "JIT cache hit IN-MEMORY function=[%s] module_hash=[%s]",
                function_name,
                module_hash,
            )
            jit_executor = self.jit_cache[module_hash]

        self.post_compilation_cleanup()
        # If compile_only is set, bypass execution return the jit_executor directly
        if compile_only:
            return jit_executor
        # Run the compiled program
        jit_executor.run_compiled_program(exe_args)

        return result
|
| 1244 |
+
|
| 1245 |
+
def run_preprocessor(self, funcBody):
    """Run the AST preprocessor over ``funcBody`` once.

    Returns the transformed AST on first invocation and None on subsequent
    calls (the function is tagged with ``_preprocessed`` after the first
    pass).
    """
    # Already transformed once — nothing to do.
    if hasattr(funcBody, "_preprocessed"):
        return None

    function_name = funcBody.__name__
    self.funcBody = funcBody
    log().info("Started preprocessing [%s]", function_name)
    exec_globals = self._get_globals()
    transformed_ast = self.preprocessor.transform(funcBody, exec_globals)
    if self.envar.print_after_preprocessor:
        log().info(
            f"# Printing unparsed AST after preprocess of func=`{function_name}` id=`{id(funcBody)}`"
        )
        DSLPreprocessor.print_ast(transformed_ast)
    funcBody._preprocessed = True
    return transformed_ast
|
| 1260 |
+
|
| 1261 |
+
def get_function_ptr(self, original_function):
    """Compile the preprocessed AST of ``original_function`` and obtain the
    resulting callable from the preprocessor's exec."""
    source_file = inspect.getsourcefile(original_function)
    # The preprocessor stored the transformed AST on the function object.
    code_object = compile(
        original_function._transformed_ast, filename=source_file, mode="exec"
    )
    return self.preprocessor.exec(
        original_function.__name__,
        original_function,
        code_object,
        self._get_globals(),
    )
|
| 1272 |
+
|
| 1273 |
+
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
    """
    Binds provided arguments to a function's signature and applies default values.

    E.g. given a function signature `def foo(a, b=2, c=3)`, binding
    `foo(a=1, c=4)` yields a BoundArguments with defaults filled in, in
    signature order.

    Raises DSLRuntimeError (wrapping the original cause) if binding fails.
    """
    try:
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        return bound
    except Exception as e:
        raise DSLRuntimeError(
            f"Failed to bind arguments to function `{func_name}` with signature `{sig}`",
            cause=e,
        )
|
| 1292 |
+
|
| 1293 |
+
def _canonicalize_args(self, sig, *args, **kwargs):
    """
    Canonicalize the input arguments so that returned args only contain
    positional arguments and kwargs only contain keyword arguments.
    """
    bound = self._get_function_bound_args(
        sig, self.funcBody.__name__, *args, **kwargs
    )
    return bound.args, bound.kwargs
|
| 1303 |
+
|
| 1304 |
+
def _check_arg_count(self, *args, **kwargs):
    """Verify every required (no-default) parameter of the user function is
    provided; return the inspected Signature.

    Raises DSLRuntimeError when the function body is unset or a required
    argument is missing.
    """
    if not self.funcBody:
        raise DSLRuntimeError("Function body is not set.")

    # Pass the actual function object to inspect.signature to get the signature.
    sig = inspect.signature(self.funcBody)
    function_name = self.funcBody.__name__
    bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)

    # Check if all non-default arguments are provided
    missing = next(
        (
            param.name
            for param in sig.parameters.values()
            if param.default is inspect.Parameter.empty
            and param.name not in bound_args.arguments
        ),
        None,
    )
    if missing is not None:
        raise DSLRuntimeError(
            f"Missing required argument in `{function_name}`: '{missing}'"
        )

    return sig
|
| 1326 |
+
|
| 1327 |
+
def _func(self, funcBody, *args, **kwargs):
    """Decorator for MLIR functions.
    It cuts the boilerplate code, does the following:
    1. Generates `func.func`
    2. Types translation (numpy arrays -> cute.memref, float -> <f32>, etc.)
    3. Compiles and JITs the MLIR module
    4. Invokes the generated function
    5. Operator overloading (a + b --> arith.addi a, b)
    6. Generates GPU kernel function with GPU module and kernel attributes baked
    """
    # If we are already inside an active MLIR insertion point, this is a
    # nested call during tracing: inline the body instead of re-JITting.
    if ir.Context.current is None:
        pass
    elif ir.InsertionPoint.current is not None:
        return funcBody(*args, **kwargs)

    function_name = funcBody.__name__
    self.funcBody = funcBody

    # DSL-level options travel as keyword arguments; pop them so they are
    # not forwarded to the user function.
    pipeline = kwargs.pop("pipeline", None)
    gpu_module_attrs = kwargs.pop("gpu_module_attrs", {})

    # Disable cache
    no_cache = kwargs.pop("no_cache", False)

    # Always compile(disable cache) and return the result jit_executor
    compile_only = kwargs.pop("compile_only", False)

    if not no_cache and compile_only:
        # compile_only implies bypassing the cache.
        no_cache = True
        self.print_warning("Cache is disabled as user wants to compile only.")

    # Check the number of arguments
    sig = self._check_arg_count(*args, **kwargs)

    args_spec = inspect.getfullargspec(funcBody)

    # Canonicalize the input arguments
    canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
        sig, *args, **kwargs
    )

    # Simple name mangling
    function_name = self.mangle_name(function_name, canonicalized_args, args_spec)

    # Generate MLIR Context and start generating IR
    log().debug(f"Generating MLIR for function '{function_name}'")
    result = self.generate_mlir(
        funcBody,
        canonicalized_kwargs,
        function_name,
        gpu_module_attrs,
        canonicalized_args,
        args_spec,
        pipeline,
        no_cache,
        compile_only,
    )

    return result
|
| 1386 |
+
|
| 1387 |
+
class _KernelGenHelper(ABC):
    """Interface used by sub-DSLs to emit a GPU kernel: the kernel func op,
    its return op, and the launch op at the call site."""

    def __init__(self):
        # Populated by generate_func_op implementations.
        self.func_op = None
        self.func_type = None

    @abstractmethod
    def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
        assert arg_types is not None, "Invalid arg_types!"
        assert kernel_name is not None, "kernel name is empty"
        pass

    @abstractmethod
    def generate_func_ret_op(self):
        pass

    @abstractmethod
    def generate_launch_op(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_func_body_start(self):
        pass

    @abstractmethod
    # NOTE(review): no `self` parameter — callers appear to pass the module
    # directly; signature kept as-is, confirm intent.
    def enter_gpu_module(module):
        """Compute the insertion point into the given module."""
        pass
|
| 1414 |
+
|
| 1415 |
+
@lru_cache(maxsize=1)
def _get_default_stream(self):
    """Returns the default stream 0"""
    # NOTE(review): lru_cache on an instance method keys on `self` and keeps
    # the instance alive (ruff B019); with maxsize=1 this assumes a single
    # long-lived DSL instance — confirm.  The cache means the stream is
    # created once and reused for all default-stream launches.
    from .runtime import cuda as cuda_helpers

    return cuda_helpers.stream_create()
|
| 1421 |
+
|
| 1422 |
+
def _execute_cuda(
    self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None
):
    """
    Executes a specified CUDA kernel from a cubin file, handling module loading,
    kernel retrieval, stream creation, kernel launch, and synchronization.

    When no stream is given, the cached default stream is used and the
    launch is synchronized before returning; with an explicit stream the
    caller owns synchronization.  Kernel arguments come from self.exe_args.
    """
    from .runtime import cuda as cuda_helpers

    # Step 1. Load CUDA Module
    module = cuda_helpers.load_cubin_module(fname_cubin)
    # Step 2. Find CUDA function
    kernel_ptr = cuda_helpers.get_kernel_function(module, kernel_name)

    # Step 3. Pick a stream; remember if we defaulted so we can sync below.
    sync_execution_default = False
    if stream is None:
        stream = self._get_default_stream()
        sync_execution_default = True

    # Step 4. Launch the kernel
    cuda_helpers.launch_kernel(
        kernel_ptr,
        grid_size,
        block_size,
        stream,
        smem_size=smem_size,
        kernel_args=self.exe_args,
    )

    if sync_execution_default:
        # Step 5. Optional Sync cuda stream
        cuda_helpers.stream_sync(stream)
|
| 1454 |
+
|
| 1455 |
+
def _execute_by_cuda_driver(
    self,
    kernel_generator,
    generate_cubin,
    grid_size,
    block_size,
    smem_size,
    stream=None,
):
    """
    This function builds IR and execute the module using cuda driver.
    It doesn't use mlir's cuda runtime.

    ``kernel_generator`` emits the kernel IR (returning a result and the
    kernel name); ``generate_cubin`` lowers the verified module to a cubin
    file, which is then launched through the CUDA driver helpers.
    """
    ret = None

    # Step 1. Build IR
    with ir.Context(), ir.Location.unknown():
        loc = self.get_location()
        module = ir.Module.create(loc=loc)
        unit_attr = ir.UnitAttr.get()
        module.operation.attributes["gpu.container_module"] = unit_attr
        with ir.InsertionPoint(module.body):
            self._build_gpu_module()
            ret, kernel_name = kernel_generator()
            log().debug(
                f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}"
            )

        # Save/print/verify per env flags.
        module = self.build_module(module, kernel_name)

        # dryrun is used to only generate IR
        if self.envar.dryrun:
            return ret

        # Generate cubin
        fname_cubin = generate_cubin(module, kernel_name)

        # Execute a cuda kernel from cubin
        self._execute_cuda(
            fname_cubin, kernel_name, grid_size, block_size, smem_size, stream
        )

    return ret
|
| 1498 |
+
|
| 1499 |
+
def generate_kernel_operands_and_types(
|
| 1500 |
+
self, kernel_func, kernel_name, args_spec, args, kwargs
|
| 1501 |
+
):
|
| 1502 |
+
"""
|
| 1503 |
+
Generate the operands and types for the kernel function
|
| 1504 |
+
"""
|
| 1505 |
+
|
| 1506 |
+
kernel_operands, kernel_arg_types, kernel_arg_attrs = [], [], []
|
| 1507 |
+
|
| 1508 |
+
log().debug(
|
| 1509 |
+
"Processing GPU kernel call in [%s] mode",
|
| 1510 |
+
(
|
| 1511 |
+
f"Only {self.device_jit_decorator_name}"
|
| 1512 |
+
if self.device_compilation_only
|
| 1513 |
+
else f"{self.host_jit_decorator_name} + {self.device_jit_decorator_name}"
|
| 1514 |
+
),
|
| 1515 |
+
)
|
| 1516 |
+
|
| 1517 |
+
if self.device_compilation_only:
|
| 1518 |
+
return kernel_operands, kernel_arg_types, kernel_arg_attrs
|
| 1519 |
+
|
| 1520 |
+
kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = (
|
| 1521 |
+
self._generate_jit_func_args(
|
| 1522 |
+
kernel_func, kernel_name, args, kwargs, args_spec, is_host=False
|
| 1523 |
+
)
|
| 1524 |
+
)
|
| 1525 |
+
|
| 1526 |
+
log().debug("Final kernel_operands: %s", ", ".join(map(str, kernel_operands)))
|
| 1527 |
+
log().debug("Final kernel_arg_types: %s", ", ".join(map(str, kernel_arg_types)))
|
| 1528 |
+
log().debug("Final kernel_arg_attrs: %s", ", ".join(map(str, kernel_arg_attrs)))
|
| 1529 |
+
|
| 1530 |
+
assert (
|
| 1531 |
+
len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs)
|
| 1532 |
+
), "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal"
|
| 1533 |
+
|
| 1534 |
+
return kernel_operands, kernel_arg_types, kernel_arg_attrs
|
| 1535 |
+
|
| 1536 |
+
    def kernel_launcher(self, *dargs, **dkwargs):
        # Decorator factory: supports both @kernel_launcher and
        # @kernel_launcher(requiredArgs=..., kernelGenHelper=...) usage.
        def decorator(funcBody):
            @wraps(funcBody)
            def kernel_wrapper(*args, **kwargs):
                """
                Base decorator for generating kernel function

                This decorator provides a template for kernel function generation
                including kernel function header/body and kernel launch op at call site

                Optional arguments (with default value in <>):
                - requiredArgs <[]>: specifies the mandatory arguments that must present in kernel function signature
                                     the args will be validated and collected as a namedtuple
                - optionalArgs <[]>: specifies the optional arguments that might present in kernel function signature
                                     the args will be collected (if present) as a namedtuple
                - unitAttrNames <[]>: specifies the name(s) of ir.UnitAttr to be set for kernel function op
                - valueAttrDict <{}>: specifies the name(s) and value(s) of ir.Attribute to be set for kernel function op
                - kernelGenHelper <None>: specifies the mandatory customized kernel generation helper class (derived from _KernelGenHelper)

                Return value:
                A namedtuple "KernelReturns" is returned with following fields:
                - kernel_func_ret: the return of the kernel function
                - launch_op_ret: the return of the launch op
                """

                requiredArgs = dkwargs.get("requiredArgs", [])
                optionalArgs = dkwargs.get("optionalArgs", [])
                unitAttrNames = dkwargs.get("unitAttrNames", [])
                valueAttrDict = dkwargs.get("valueAttrDict", {})
                kernelGenHelper = dkwargs.get("kernelGenHelper", None)

                kernel_name = funcBody.__name__
                args_spec = inspect.getfullargspec(funcBody)
                self.funcBody = funcBody

                # Give each kernel a unique name. (The same kernel may be
                # called multiple times, resulting in multiple kernel traces.)
                # The mangled name of Python function is part of the name to
                # improve readability.
                kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}"
                self.num_kernels += 1

                # Step 0. Preprocess the arguments
                def extract_args(argNames, assertIfNone=False) -> list:
                    # Pop the named launch parameters out of **kwargs so they
                    # are not forwarded to the user's kernel body.
                    extracted = []
                    for name in argNames:
                        value = kwargs.pop(name, None)
                        if assertIfNone and value is None:
                            raise DSLRuntimeError(
                                f"{name} is required for {kernel_name}"
                            )
                        extracted.append(value)

                    return extracted

                RequiredArgs = namedtuple("RequiredArgs", requiredArgs)
                req_args = (
                    RequiredArgs._make(extract_args(requiredArgs, assertIfNone=True))
                    if requiredArgs
                    else None
                )
                OptionalArgs = namedtuple("OptionalArgs", optionalArgs)
                opt_args = (
                    OptionalArgs._make(extract_args(optionalArgs))
                    if optionalArgs
                    else None
                )
                assert (
                    kernelGenHelper is not None
                ), "kernelGenHelper should be explicitly specified!"

                # check arguments
                sig = self._check_arg_count(*args, **kwargs)

                # Canonicalize the input arguments
                canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
                    sig, *args, **kwargs
                )

                # Step 1. Compute host-side launch operands and MLIR types.
                kernel_operands, kernel_types, kernel_arg_attrs = (
                    self.generate_kernel_operands_and_types(
                        funcBody,
                        kernel_name,
                        args_spec,
                        canonicalized_args,
                        canonicalized_kwargs,
                    )
                )

                # Step 2. Emit the kernel function inside the gpu module.
                with self._enter_gpu_module():
                    log().debug("Generating device kernel")
                    if self.device_compilation_only:
                        log().debug("Generating cuda-python arguments")
                        # Convert input arguments to MLIR arguments
                        self.exe_args, kernel_types, _ = (
                            self.generate_mlir_function_types(
                                funcBody,
                                kernel_name,
                                canonicalized_args,
                                canonicalized_kwargs,
                                args_spec,
                            )
                        )

                    helper = kernelGenHelper()
                    loc = self.get_location()
                    fop = helper.generate_func_op(
                        kernel_types, kernel_arg_attrs, kernel_name, loc
                    )
                    log().debug(f"Kernel function op: {fop}")
                    # Apply caller-specified attributes to the function op.
                    for attr in unitAttrNames:
                        fop.attributes[attr] = ir.UnitAttr.get()
                    for key, val in valueAttrDict.items():
                        fop.attributes[key] = val

                    fop.sym_visibility = ir.StringAttr.get("public")
                    with ir.InsertionPoint(helper.get_func_body_start()):
                        ir_args, ir_kwargs = self.generate_execution_arguments(
                            canonicalized_args, canonicalized_kwargs, fop, args_spec
                        )
                        log().debug(
                            f"IR arguments - args: {ir_args} ; kwargs: {ir_kwargs}"
                        )
                        # Call user function body
                        kernel_ret = funcBody(*ir_args, **ir_kwargs)
                        helper.generate_func_ret_op()

                # Step 3. Generate call site `launch_func`
                kernel_sym = ir.SymbolRefAttr.get(["kernels", kernel_name])
                launch_ret = helper.generate_launch_op(
                    kernelSym=kernel_sym,
                    kernelOperands=kernel_operands,
                    requiredArgs=req_args,
                    optionalArgs=opt_args,
                )

                KernelReturns = namedtuple(
                    "KernelReturns", ["kernel_func_ret", "launch_op_ret"]
                )
                result = KernelReturns(
                    kernel_func_ret=kernel_ret, launch_op_ret=launch_ret
                )
                log().debug(f"Kernel result: {result}, kernel name: {kernel_name}")
                return result, kernel_name

            return kernel_wrapper

        # Bare decoration (@kernel_launcher) passes the function directly;
        # parameterized decoration returns the decorator itself.
        if len(dargs) == 1 and callable(dargs[0]):
            return decorator(dargs[0])
        else:
            return decorator
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides utilities for the environment variables setup.
|
| 14 |
+
|
| 15 |
+
It provides an EnvironmentVarManager, which reads environment variables for the DSL
|
| 16 |
+
and caches them for efficient access.
|
| 17 |
+
|
| 18 |
+
It also provides utilities to automatically setup a subset of environment variables
|
| 19 |
+
based on heuristics.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import shutil
|
| 25 |
+
import glob
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from functools import lru_cache
|
| 28 |
+
from typing import Any
|
| 29 |
+
|
| 30 |
+
from ..base_dsl.runtime.cuda import get_compute_capability_major_minor
|
| 31 |
+
from .utils.logger import log
|
| 32 |
+
|
| 33 |
+
IS_WINDOWS = sys.platform == "win32"
|
| 34 |
+
CLIB_EXT = ".dll" if IS_WINDOWS else ".so"
|
| 35 |
+
|
| 36 |
+
# =============================================================================
|
| 37 |
+
# Environment Variable Helpers
|
| 38 |
+
# =============================================================================
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@lru_cache(maxsize=None)
def get_str_env_var(var_name, default_value=None):
    """Return the string value of *var_name*, or *default_value* when unset.

    Results are cached per (var_name, default_value) — later changes to the
    process environment are not observed.
    """
    raw = os.getenv(var_name)
    if raw is None:
        return default_value
    return raw
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@lru_cache(maxsize=None)
def get_bool_env_var(var_name, default_value=False):
    """Interpret *var_name* as a boolean flag (cached).

    An unset variable yields *default_value*; the strings "False", "0" and ""
    are false; every other value is true.
    """
    text = get_str_env_var(var_name)
    if text is None:
        return default_value
    return text not in ("False", "0", "")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@lru_cache(maxsize=None)
def get_int_env_var(var_name, default_value=0):
    """Interpret *var_name* as an integer (cached).

    Returns *default_value* when the variable is unset or not a valid
    integer literal.
    """
    value = get_str_env_var(var_name)
    if value is None:
        return default_value
    try:
        # int() accepts signed values and surrounding whitespace; the previous
        # str.isdigit() check silently rejected e.g. "-1".
        return int(value)
    except ValueError:
        return default_value
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@lru_cache(maxsize=None)
def has_env_var(var_name):
    """Report whether *var_name* is present in the environment (cached)."""
    return var_name in os.environ
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def detect_gpu_arch(prefix):
    """
    Attempts to detect the machine's GPU architecture.

    Returns an "sm_<major><minor>[a]" string derived from the CUDA compute
    capability, or the default "sm_100a" when detection fails. The *prefix*
    parameter is accepted for interface compatibility and is not used here.
    """
    cc = (None, None)
    try:
        cc = get_compute_capability_major_minor()
    except Exception as e:
        log().info(f"Failed to get CUDA compute capability: {e}")

    if cc == (None, None):
        # default to sm_100
        cc = (10, 0)

    major, minor = cc
    # Architectures from compute capability 9.x on carry the "a" suffix.
    suffix = "a" if major >= 9 else ""
    return f"sm_{major}{minor}{suffix}"
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def find_libs_in_ancestors(start, target_libs, lib_folder_guesses):
    """
    Search ancestor directories for a candidate library folder containing all required libraries.

    Starting from *start*, each parent directory is visited in turn. For every
    ancestor, each candidate subdirectory from *lib_folder_guesses* is scanned
    for files with the platform library extension (CLIB_EXT); a file's library
    name is its stem with any leading "lib" removed. The first candidate
    directory that contains every name in *target_libs* wins.

    Parameters:
        start (str or Path): The starting directory from which to begin the search.
        target_libs (iterable of str): Required library names (without the "lib" prefix).
        lib_folder_guesses (iterable of str): Relative paths from an ancestor that may contain the libraries.

    Returns:
        list[str] or None: Resolved paths to the required library files, or None if not found.
    """
    for ancestor in Path(start).resolve().parents:
        for rel_path in lib_folder_guesses:
            candidate_dir = ancestor / rel_path
            if not candidate_dir.is_dir():
                continue

            # Track which required libraries are still unaccounted for.
            missing = set(target_libs)
            resolved_paths = []

            for entry in candidate_dir.iterdir():
                if entry.suffix != CLIB_EXT:
                    continue
                # Canonicalize: strip the conventional "lib" prefix.
                lib_name = entry.stem.removeprefix("lib")
                if lib_name in missing:
                    resolved_paths.append(str(entry.resolve()))
                    missing.discard(lib_name)

            # Success only when this single directory satisfied everything.
            if not missing:
                return resolved_paths

    return None
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _find_cuda_home():
    """Find the CUDA installation path using a series of heuristic methods.

    Methods below are checked in order, and the function returns on first match:
    1. Checking the environment variables CUDA_HOME and CUDA_PATH.
    2. Searching for the 'nvcc' compiler in the system PATH and deriving the path of cuda.
    3. Scanning common installation directories based on the operating system.
       - On Windows systems (when IS_WINDOWS is True), it searches in:
         C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*
       - On Unix-like systems, it searches in:
         /usr/local/cuda*

    Returns:
        Optional[str]: The absolute CUDA installation path if found; otherwise, None.

    Note:
        The variable IS_WINDOWS is defined in the module scope.
    """
    # Guess #1: explicit environment variables.
    home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH")
    if home is None:
        # Guess #2: derive from nvcc on PATH (nvcc lives in <home>/bin).
        nvcc = shutil.which("nvcc")
        if nvcc is not None:
            home = os.path.dirname(os.path.dirname(nvcc))
        else:
            # Guess #3: conventional install locations.
            pattern = (
                "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
                if IS_WINDOWS
                else "/usr/local/cuda*"
            )
            matches = glob.glob(pattern)
            home = matches[0] if matches else ""
    # A stale or empty candidate counts as "not found".
    if not os.path.exists(home):
        home = None
    return home
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_cuda_toolkit_path():
    """
    Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if
    set. Otherwise, attempts to discover a valid CUDA toolkit location and
    return. If not found, return None.
    """
    # Check if the environment variable is already set, if so, return it immediately.
    try:
        explicit = get_str_env_var("CUDA_TOOLKIT_PATH")
        if explicit:
            return explicit

        discovered = _find_cuda_home()
        if discovered:
            return discovered
    except Exception as e:
        # BUGFIX: the message previously had no %s placeholder while passing
        # `e` as a lazy formatting argument, so the logging module raised an
        # internal formatting error and the exception detail was never shown.
        log().info("default_env: exception on get_cuda_toolkit_path: %s", e)
    return None
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_prefix_dsl_libs(prefix: str):
    """
    Returns get_str_env_var('{prefix}_LIBS') if set.
    Otherwise, attempts to discover libs based on heuristics and return.
    If not found, return None.
    """
    # Check if the environment variable is already set, if so, return it immediately.
    try:
        explicit = get_str_env_var(f"{prefix}_LIBS")
        if explicit:
            return explicit

        def get_libs_cand(start):
            # All three MLIR runtime libraries must be found together.
            target_libs = {
                "mlir_c_runner_utils",
                "mlir_runner_utils",
                "mlir_cuda_runtime",
            }
            lib_folder_guesses = [
                "lib",
            ]

            libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
            if libs_cand:
                return ":".join(libs_cand)
            return None

        # find from install folder
        dsl_libs = get_libs_cand(__file__)

        if not dsl_libs:
            # try to find from build folder structure
            dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve())

        return dsl_libs

    except Exception as e:
        # BUGFIX: was an f-string with no placeholder plus a stray positional
        # argument, so the exception detail was never rendered in the log.
        log().info("default_env: exception on get_prefix_dsl_libs: %s", e)
        return None
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class EnvironmentVarManager:
    """Manages environment variables for configuration options.

    Printing options:
    - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False)
    - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False)
    - [DSL_NAME]_PRINT_IR: Print generated IR (default: False)
    - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True)
    File options:
    - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False)
    - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False)
    Other options:
    - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1).
    - [DSL_NAME]_DRYRUN: Generates IR only (default: False)
    - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
    - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
    - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
    - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
    - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
    - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
    - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
    - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None)
    - [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False)
    """

    def __init__(self, prefix="DSL"):
        # Env-var name prefix, e.g. "DSL" -> DSL_PRINT_IR, DSL_ARCH, ...
        self.prefix = prefix  # change if needed

        # Printing options
        self.print_after_preprocessor = get_bool_env_var(
            f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
        )
        self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False)
        self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
        # File options
        self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
        # Logging options
        self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
        self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
        # Warn early: a LOG_LEVEL without any enabled sink has no effect.
        if (
            has_env_var(f"{prefix}_LOG_LEVEL")
            and not self.log_to_console
            and not self.log_to_file
        ):
            log().warning(
                f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!"
            )
        self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)

        # Other options
        self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
        # When [prefix]_ARCH is unset, fall back to probing the local GPU.
        self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
        self.warnings_as_errors = get_bool_env_var(
            f"{prefix}_WARNINGS_AS_ERRORS", False
        )
        self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
        self.enable_optimization_warnings = get_bool_env_var(
            f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False
        )
        self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
        self.disable_file_caching = get_bool_env_var(
            f"{prefix}_DISABLE_FILE_CACHING", False
        )
        self.file_caching_capacity = get_int_env_var(
            f"{prefix}_FILE_CACHING_CAPACITY", 1000
        )
        # Inverted flag: setting [prefix]_NO_SOURCE_LOCATION disables locations.
        self.generate_source_location = not get_bool_env_var(
            f"{prefix}_NO_SOURCE_LOCATION", False
        )
        # set cuda
        self.cuda_toolkit = get_cuda_toolkit_path()

        # set mlir shared libraries
        self.shared_libs = get_prefix_dsl_libs(prefix)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides jit executor related classes
|
| 14 |
+
"""
|
| 15 |
+
import ctypes
|
| 16 |
+
import inspect
|
| 17 |
+
import io
|
| 18 |
+
from typing import get_origin
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
# MLIR modules imports
|
| 23 |
+
from .._mlir import ir
|
| 24 |
+
|
| 25 |
+
# Local modules imports
|
| 26 |
+
from . import typing as t
|
| 27 |
+
from .common import DSLRuntimeError
|
| 28 |
+
from .runtime import cuda as cuda_helpers
|
| 29 |
+
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
|
| 30 |
+
from .typing import get_c_pointers
|
| 31 |
+
from .utils.logger import log
|
| 32 |
+
from .utils.timer import timer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CudaSingleModule:
    """Pairs one loaded CUDA module handle with the pointer of a kernel in it."""

    def __init__(self, cuda_module, kernel_ptr):
        # Handle of the loaded cubin module.
        self.cuda_module = cuda_module
        # Pointer to the kernel function inside that module.
        self.kernel_ptr = kernel_ptr
|
| 40 |
+
|
| 41 |
+
class CudaModules:
    """Bundle of CudaSingleModule instances plus extra launch arguments."""

    def __init__(self, modules, args):
        # list of CudaSingleModule
        self.modules = modules
        # extra kernel ptr arguments for launch
        self.args = args
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class JitExecutor:
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
dsl,
|
| 53 |
+
engine,
|
| 54 |
+
capi_func,
|
| 55 |
+
ir_module,
|
| 56 |
+
args_spec,
|
| 57 |
+
function_name,
|
| 58 |
+
cuda_modules: CudaModules = None,
|
| 59 |
+
jit_time_profiling=False,
|
| 60 |
+
):
|
| 61 |
+
self.dsl = dsl
|
| 62 |
+
self.engine = engine
|
| 63 |
+
self.capi_func = capi_func
|
| 64 |
+
self.ir_module = ir_module
|
| 65 |
+
self.args_spec = args_spec
|
| 66 |
+
self.function_name = function_name
|
| 67 |
+
if args_spec is not None:
|
| 68 |
+
self.original_args_spec = args_spec
|
| 69 |
+
self.args_spec = self.filter_runtime_arg_spec(args_spec)
|
| 70 |
+
# cuda kernels
|
| 71 |
+
self.cuda_modules = cuda_modules
|
| 72 |
+
self.jit_time_profiling = jit_time_profiling
|
| 73 |
+
|
| 74 |
+
    def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
        """Return a copy of *arg_spec* containing only runtime arguments.

        Arguments classified as compile-time by is_arg_spec_constexpr are
        dropped together with their annotations and defaults; varargs and
        varkw pass through unchanged.
        """
        runtime_args = []
        runtime_annotations = {}
        runtime_defaults = []

        # Calculate the offset where defaults start in the original args
        if arg_spec.defaults:
            defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
        else:
            defaults_start_idx = len(arg_spec.args)

        # Filter arguments and maintain their properties
        for i, arg_name in enumerate(arg_spec.args):
            arg_type = arg_spec.annotations.get(arg_name, None)

            # Skip compile-time arguments
            if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
                continue

            # Keep runtime arguments
            runtime_args.append(arg_name)
            if arg_name in arg_spec.annotations:
                runtime_annotations[arg_name] = arg_type

            # Keep corresponding default if it exists
            if i >= defaults_start_idx:
                default_idx = i - defaults_start_idx
                runtime_defaults.append(arg_spec.defaults[default_idx])

        # Filter kwonlyargs and their defaults
        runtime_kwonlyargs = []
        runtime_kwonlydefaults = {}

        if arg_spec.kwonlyargs:
            for kwarg in arg_spec.kwonlyargs:
                arg_type = arg_spec.annotations.get(kwarg, None)

                # Apply same filtering logic
                # NOTE(review): `i` here is the stale index left over from the
                # positional-args loop above, not the position of `kwarg`;
                # it is also unbound when arg_spec.args is empty. Confirm
                # whether is_arg_spec_constexpr uses the index for
                # keyword-only arguments.
                if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
                    continue

                runtime_kwonlyargs.append(kwarg)
                if kwarg in arg_spec.annotations:
                    runtime_annotations[kwarg] = arg_type
                if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
                    runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg]

        # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec)
        runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None

        return inspect.FullArgSpec(
            args=runtime_args,
            varargs=arg_spec.varargs,  # Keep original varargs
            varkw=arg_spec.varkw,  # Keep original varkw
            defaults=runtime_defaults,
            kwonlyargs=runtime_kwonlyargs,
            kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None,
            annotations=runtime_annotations,
        )
|
| 133 |
+
|
| 134 |
+
def __del__(self):
    """Release every CUDA module this executor loaded, once per module."""
    if not self.cuda_modules:
        return
    # Several kernel entries may reference the same underlying module;
    # deduplicate so each module is unloaded exactly once.
    distinct_modules = {entry.cuda_module for entry in self.cuda_modules.modules}
    for cubin_module in distinct_modules:
        cuda_helpers.unload_cubin_module(cubin_module)
|
| 139 |
+
|
| 140 |
+
def get_constexpr_args(self) -> list[dict[str, int | str]]:
|
| 141 |
+
"""
|
| 142 |
+
This function returns the constexpr args that have been pruned from the original function signature.
|
| 143 |
+
The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
|
| 144 |
+
|
| 145 |
+
:return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
|
| 146 |
+
:rtype: list[dict[str, int | str]]
|
| 147 |
+
"""
|
| 148 |
+
if self.original_args_spec is None:
|
| 149 |
+
return list()
|
| 150 |
+
constexpr_args = list()
|
| 151 |
+
for i, arg_name in enumerate(self.original_args_spec.args):
|
| 152 |
+
if arg_name not in self.args_spec.args:
|
| 153 |
+
constexpr_args.append({"argument_index": i, "argument_name": arg_name})
|
| 154 |
+
|
| 155 |
+
if self.original_args_spec.kwonlyargs:
|
| 156 |
+
for kwarg in self.original_args_spec.kwonlyargs:
|
| 157 |
+
if kwarg not in self.args_spec.kwonlyargs:
|
| 158 |
+
constexpr_args.append(
|
| 159 |
+
{"argument_index": None, "argument_name": kwarg}
|
| 160 |
+
)
|
| 161 |
+
return constexpr_args
|
| 162 |
+
|
| 163 |
+
def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec):
    """
    This function is the prune version of `generate_mlir_function_types` which only generates execution args
    to get rid of mlir context.

    Returns ``(exe_args, adapted_args)``: ``exe_args`` is the flat list of C
    pointers to hand to the compiled entry point; ``adapted_args`` keeps any
    adapter-produced objects alive so their pointers stay valid for the call.
    """

    # Process positional arguments with defaults
    rectified_args = list(args)
    # Fill the trailing positional slots from the spec's defaults.
    if args_spec.defaults and len(args) < len(args_spec.args):
        rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :])
    # Keyword arguments that name positional parameters overwrite (or append
    # to) the positional list at the parameter's index.
    for k, v in kwargs.items():
        if k in args_spec.args:
            idx = args_spec.args.index(k)
            if idx < len(rectified_args):
                rectified_args[idx] = v
            else:
                rectified_args.append(v)

    # Process keyword arguments
    rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args}
    # NOTE(review): this applies ALL kwonly defaults whenever any kwonly arg
    # is missing; explicit values above are not overwritten only because dict
    # update order is checked by the length test — confirm intended.
    if args_spec.kwonlydefaults and len(rectified_kwargs) < len(
        args_spec.kwonlyargs
    ):
        rectified_kwargs.update(args_spec.kwonlydefaults)

    # args/kwargs must match arg_specs
    if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len(
        args_spec.kwonlyargs
    ):
        raise DSLRuntimeError(
            "input args/kwargs length does not match runtime function signature!",
            context={
                "input args length": len(rectified_args),
                "input kwargs length": len(rectified_kwargs),
                "function signature args length": len(args_spec.args),
                "function signature kwonlyargs length": len(args_spec.kwonlyargs),
            },
        )

    exe_args = []
    adapted_args = []
    # Positional values first, then keyword-only values, paired with names
    # so annotations can drive conversion.
    input_args = rectified_args + list(rectified_kwargs.values())
    input_arg_names = args_spec.args + args_spec.kwonlyargs
    for arg, arg_name in zip(input_args, input_arg_names):
        # short-cut for args already converted
        if hasattr(arg, "__c_pointers__"):
            exe_args.extend(arg.__c_pointers__())
            continue

        arg_type = args_spec.annotations.get(arg_name, None)

        # Implicit cast to NumericMeta
        if isinstance(arg_type, t.NumericMeta):
            arg = t.cast(arg, arg_type)
        else:
            # If not any known type, try registered adapter to do the conversion
            adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
            if adapter:
                arg = adapter(arg)
                # Keep the adapted object referenced so its pointers survive.
                adapted_args.append(arg)

        exe_args.extend(get_c_pointers(arg))

    return exe_args, adapted_args
|
| 227 |
+
|
| 228 |
+
def __call__(self, *args, **kwargs):
    """Run the compiled program with the caller's runtime arguments."""
    # `adapted` must remain referenced until the call returns so that any
    # adapter-owned memory backing the pointers stays alive.
    exe_args, adapted = self.generate_execution_args(args, kwargs, self.args_spec)
    self.run_compiled_program(exe_args)
|
| 234 |
+
|
| 235 |
+
# Assume each execution arg has type `c_void_p` to reduce the overhead of `ctypes.cast`.
def get_invoke_packed_args(self, exe_args):
    """
    Pack the execution args (plus any preloaded CUDA kernel pointers) into a
    ctypes ``c_void_p`` array for the C-API invocation.

    Fix: the original extended ``exe_args`` in place (``+=``), mutating the
    caller's list on every call; a fresh list is built instead.
    """
    all_args = list(exe_args)
    if self.cuda_modules:
        all_args.extend(self.cuda_modules.args)
    packed_args = (ctypes.c_void_p * len(all_args))()
    for slot, value in enumerate(all_args):
        packed_args[slot] = value
    return packed_args
|
| 243 |
+
|
| 244 |
+
def run_compiled_program(self, exe_args):
    """
    Pack the execution args and invoke the compiled C-API entry point.

    When ``self.jit_time_profiling`` is set, both the packing step and the
    call are wrapped with the ``timer`` profiler. Any failure is re-raised
    as ``DSLRuntimeError`` with the original exception attached as cause.
    """
    # Fix: the original duplicated the whole try/except for the profiled and
    # non-profiled paths; a no-op wrapper unifies them.
    wrap = timer(enable=True) if self.jit_time_profiling else (lambda func: func)
    try:
        packed_args = wrap(self.get_invoke_packed_args)(exe_args)
        wrap(self.capi_func)(packed_args)
    except Exception as e:
        raise DSLRuntimeError("💥💥💥 Runtime Crash 💥💥💥", cause=e)
|
| 258 |
+
|
| 259 |
+
def update_jit_cuda_modules(self, kernel_symbols):
    """
    Preload the CUDA module for each kernel symbol embedded in the compiled IR.

    For every symbol: walk the MLIR module for its cubin, load it with the
    driver, resolve the kernel function pointer, and collect the pointers as
    extra execution args. The result is stored on ``self.cuda_modules``;
    ``self`` is returned for chaining.
    """
    # preload cuda module from compiled cubin in ir and store to jit_executor.kernels.
    if len(kernel_symbols) > 0:
        extra_args = []
        module = self.ir_module
        cuda_kernel_cache = dict()
        cuda_driver_version = cuda_helpers.get_driver_version()
        for sym in kernel_symbols:
            if sym not in cuda_kernel_cache:
                log().debug(f"Loading CUDA module for symbol: {sym}")

                # load cuda module/get function pointer from module and cache.
                # The callback fills cuda_kernel_cache as a side effect; the
                # `sym` parameter shadows (and receives) the loop variable.
                def walk_callback(sym, func_sym, cubin_data):
                    cubin_module = cuda_helpers.load_cubin_module_data(cubin_data)
                    kernel_ptr = cuda_helpers.get_kernel_function(
                        cubin_module, func_sym
                    )
                    # Enable non-portable cluster size for CUDA version 11.8 or higher.
                    if cuda_driver_version >= 11080:
                        cuda_helpers.set_kernel_attribute(
                            kernel_ptr,
                            cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED,
                            1,
                        )
                    cuda_kernel_cache[sym] = CudaSingleModule(
                        cubin_module, kernel_ptr
                    )

                self.walk_module_and_get_cubin_data(module, sym, walk_callback)
            else:
                log().debug(f"Symbol {sym} already in cache")
            # check if kernel is empty (the walk may not have found a cubin
            # for this symbol, in which case no pointer is appended).
            if sym in cuda_kernel_cache:
                extra_args.append(
                    ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr())
                )
        # store to the jit result if jit result is cached.
        self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args)

    return self
|
| 299 |
+
|
| 300 |
+
def _get_escaped_cubin_bytes(self, cubin_data):
    """
    Unescape cubin data from MLIR raw bytecode text into executable binary bytes.

    MLIR renders the cubin blob as a quoted string where arbitrary bytes
    appear as ``\\HH`` hex escapes and a literal backslash as ``\\\\``;
    all other bytes are copied verbatim.
    """

    def ishex(code):
        # ASCII '0'-'9', 'a'-'f', 'A'-'F'
        return (
            code in range(0x30, 0x3A)
            or code in range(0x61, 0x67)
            or code in range(0x41, 0x47)
        )

    converted = bytearray()
    idx = 0
    size = len(cubin_data)
    while idx < size:
        if cubin_data[idx] == 0x5C:  # backslash: start of an escape
            if (
                idx + 2 < size
                and ishex(cubin_data[idx + 1])
                and ishex(cubin_data[idx + 2])
            ):
                # \HH hex escape -> one raw byte
                converted += bytearray.fromhex(
                    cubin_data[idx + 1 : idx + 3].decode()
                )
                idx += 3
                continue
            if idx + 1 < size and cubin_data[idx + 1] == 0x5C:
                # \\ -> literal backslash
                converted.append(cubin_data[idx])
                idx += 2
                continue
            # Fix: the original indexed past the end on a trailing backslash
            # (IndexError) and never advanced idx on an unrecognized escape
            # (infinite loop); fall through and copy the byte verbatim.
        # no escape, directly write
        converted.append(cubin_data[idx])
        idx += 1
    return bytes(converted)
|
| 329 |
+
|
| 330 |
+
def walk_module_and_get_cubin_data(self, module, sym, callback):
    """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.

    ``callback`` receives ``(sym, func_sym, cubin_data)`` where ``cubin_data``
    is the raw (unescaped) cubin bytes for the matching ``gpu.binary`` op.
    """

    def walk_gpu_binary_op(op):
        # Only gpu.binary ops carry embedded cubins.
        if op.name != "gpu.binary":
            return ir.WalkResult.ADVANCE
        s = io.BytesIO()
        op.write_bytecode(s)
        cubin_data = s.getvalue()
        # Cheap textual pre-filter: skip binaries that do not mention `sym`.
        if sym.encode() not in cubin_data:
            return ir.WalkResult.ADVANCE

        # Accept the op only when it is the shared "kernels" container or is
        # named after the symbol itself.
        if (
            "kernels" != op.opview.sym_name.value
            and sym != op.opview.sym_name.value
        ):
            return ir.WalkResult.ADVANCE
        # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir
        func_sym = sym
        # NOTE(review): for per-symbol binaries not ending in "_kernel", the
        # trailing "_<suffix>" is stripped to recover the function symbol —
        # confirm this matches the emitter's naming scheme.
        if sym == op.opview.sym_name.value and not sym.endswith("_kernel"):
            func_sym = sym.rsplit("_", 1)[0]

        # Extract the escaped cubin payload from the printed attribute text.
        cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0]
        cubin_data = self._get_escaped_cubin_bytes(cubin_data)
        callback(sym, func_sym, cubin_data)
        return ir.WalkResult.ADVANCE

    module.operation.walk(walk_gpu_binary_op)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides runtime utility functions that are needed by
the DSL.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from . import dlpack_types
|
| 18 |
+
from . import cuda
|
| 19 |
+
from . import jit_arg_adapters
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"dlpack_types",
|
| 23 |
+
"cuda",
|
| 24 |
+
"jit_arg_adapters",
|
| 25 |
+
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides CUDA Python helper functions
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from functools import lru_cache
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import List, Optional
|
| 20 |
+
import numpy as np
|
| 21 |
+
import os
|
| 22 |
+
import ctypes
|
| 23 |
+
|
| 24 |
+
import cuda.bindings.driver as cuda
|
| 25 |
+
import cuda.bindings.nvrtc as nvrtc
|
| 26 |
+
|
| 27 |
+
# MLIR imports
|
| 28 |
+
from ..._mlir import ir
|
| 29 |
+
from ..._mlir.dialects import gpu
|
| 30 |
+
|
| 31 |
+
# Local module imports
|
| 32 |
+
from ..utils.logger import log as _log
|
| 33 |
+
from ..common import *
|
| 34 |
+
from .jit_arg_adapters import JitArgAdapterRegistry
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# =============================================================================
|
| 38 |
+
# Utils
|
| 39 |
+
# =============================================================================
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _cudaGetErrorEnum(error):
    """Map a CUDA driver or NVRTC status code to its human-readable name."""
    if isinstance(error, cuda.CUresult):
        err, name = cuda.cuGetErrorName(error)
        if err == cuda.CUresult.CUDA_SUCCESS:
            return name
        return "<unknown>"
    if isinstance(error, nvrtc.nvrtcResult):
        # nvrtcGetErrorString returns (status, message); keep the message.
        return nvrtc.nvrtcGetErrorString(error)[1]
    raise DSLRuntimeError("Unknown error type: {}".format(error))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_gpu_arch_info(major, minor):
    """Get GPU architecture information and compatibility details.

    Returns ``(arch_name, sm_arch, compatible_archs)``; unknown compute
    capabilities fall back to a generated ``sm_{major}{minor}`` triple.
    """
    known_archs = {
        (7, 0): ("Volta", "sm_70", ["sm_70"]),  # V100
        (7, 5): ("Turing", "sm_75", ["sm_75"]),  # RTX 20 Series, Quadro RTX
        (8, 0): ("Ampere", "sm_80", ["sm_80"]),  # A100
        (8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]),  # RTX 30 Series
        (8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]),  # RTX 40 Series
        (8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]),  # A10, A40
        (9, 0): ("Hopper", "sm_90a", ["sm_90a"]),  # H100
        (10, 0): ("Blackwell", "sm_100a", ["sm_100a"]),  # B200
    }
    fallback = ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])
    return known_archs.get((major, minor), fallback)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_compute_capability_major_minor(device_id: int = 0):
    """
    Return the compute capability of the CUDA device as a (major, minor) tuple.
    For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell.
    Returns (None, None) on failure.
    """
    attrs = cuda.CUdevice_attribute
    try:
        checkCudaErrors(cuda.cuInit(0))
        device = checkCudaErrors(cuda.cuDeviceGet(device_id))
        major = checkCudaErrors(
            cuda.cuDeviceGetAttribute(
                attrs.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
                device,
            )
        )
        minor = checkCudaErrors(
            cuda.cuDeviceGetAttribute(
                attrs.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
                device,
            )
        )
    except RuntimeError as e:
        _log().info(f"Failed to get CUDA compute capability: {e}")
        return None, None
    return major, minor
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass
class DeviceInfo:
    """Data class to store CUDA device information.

    All optional fields stay ``None`` when the corresponding driver query
    fails; ``get_device_info`` fills them best-effort.
    """

    # Number of CUDA devices visible to the driver (0 when none, or on failure).
    device_count: int = 0
    # Index of the device used for queries (defaults to 0 if no context bound).
    current_device: int = 0
    device_name: Optional[str] = None
    # Compute capability, e.g. 9 / 0 for Hopper.
    major_version: Optional[int] = None
    minor_version: Optional[int] = None
    # Marketing architecture name ("Hopper", "Ampere", ...).
    arch_name: Optional[str] = None
    # Compiler target, e.g. "sm_90a".
    sm_arch: Optional[str] = None
    compatible_archs: Optional[List[str]] = None
    memory_gb: Optional[float] = None
    target_arch: Optional[str] = None
    # Set when a query raised; reported by pretty_str.
    error_message: Optional[str] = None
    # Set when cuInit itself failed.
    initialization_failed: bool = False

    def pretty_str(self) -> str:
        """
        Convert DeviceInfo to a formatted string for display.

        Uses ANSI color escapes from ``Colors``; failure states short-circuit
        to a single-line message.
        """
        info = ""

        if self.initialization_failed:
            return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}"

        if self.error_message:
            return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}"

        if self.device_count > 0:
            info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n"

            if self.major_version is not None and self.minor_version is not None:
                info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n"
                info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n"

                if self.memory_gb is not None:
                    info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n"

            else:
                info += f"- Compute capability: unknown\n"
                info += f"- SM arch: unknown{Colors.RESET}\n"
        else:
            info += f"- No devices available\n"

        return info
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_device_info() -> DeviceInfo:
    """
    Get detailed information about CUDA devices.
    Returns a DeviceInfo dataclass with device information. Every query is
    best-effort: individual failures leave the corresponding fields at their
    defaults instead of raising.
    """
    device_info = DeviceInfo()

    # Initialize CUDA if not already initialized.
    # Fix: the original used bare `except:` clauses here and below, which also
    # swallow SystemExit/KeyboardInterrupt; narrowed to `except Exception`.
    try:
        result = cuda.cuInit(0)
        if result[0].value:  # non-zero CUresult means failure
            device_info.initialization_failed = True
            return device_info
    except Exception:
        pass

    try:
        # Get device count
        result = cuda.cuDeviceGetCount()
        device_info.device_count = result[1] if result[0].value == 0 else 0

        if device_info.device_count > 0:
            # Get current device (stays at the default 0 if no context bound)
            try:
                result = cuda.cuCtxGetDevice()
                if result[0].value == 0:
                    device_info.current_device = result[1]
            except Exception:
                pass

            # Get device name
            try:
                name_result = cuda.cuDeviceGetName(100, device_info.current_device)
                if name_result[0].value == 0:
                    device_info.device_name = name_result[1]
            except Exception:
                pass

            # Get compute capability and architecture info
            try:
                major, minor = get_compute_capability_major_minor(
                    device_info.current_device
                )

                # Check if we successfully got the compute capability
                if major is not None and minor is not None:
                    device_info.major_version = major
                    device_info.minor_version = minor

                    arch_name, sm_arch, compatible_archs = _get_gpu_arch_info(
                        device_info.major_version, device_info.minor_version
                    )

                    device_info.arch_name = arch_name
                    device_info.sm_arch = sm_arch
                    device_info.compatible_archs = compatible_archs

                    # Get memory info.
                    # NOTE(review): queried via cuDeviceGetAttribute with
                    # CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY — confirm this attribute
                    # exists in the bindings (cuDeviceTotalMem is the usual API).
                    try:
                        total_mem = cuda.cuDeviceGetAttribute(
                            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY,
                            device_info.current_device,
                        )
                        if total_mem[0].value == 0:
                            device_info.memory_gb = total_mem[1] / (
                                1024 * 1024 * 1024
                            )  # Convert to GB
                    except Exception:
                        pass

            except Exception:
                pass  # Compute capability info will remain None

    except Exception as e:
        device_info.error_message = str(e)

    return device_info
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def checkCudaErrors(result):
    """Validate a cuda-python call result and unwrap its payload.

    cuda-python calls return a tuple whose first element is the status code.
    Raises DSLCudaRuntimeError on failure; otherwise returns None (no
    payload), the single payload value, or the payload tuple.
    """
    status = result[0]
    if status.value:
        raise DSLCudaRuntimeError(status.value, _cudaGetErrorEnum(status))

    payload = result[1:]
    if not payload:
        return None
    if len(payload) == 1:
        return payload[0]
    return payload
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# =============================================================================
|
| 240 |
+
# Driver Helpers
|
| 241 |
+
# =============================================================================
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@lru_cache(maxsize=1)
def initialize_cuda_context(device_id: int = 0, flags: int = 0):
    """
    Initialize the CUDA driver and create a context on ``device_id``.

    Cached (maxsize=1): repeated calls with the same arguments return the
    already-created context instead of creating another one.
    """
    _log().info(f"cuInit {flags}")
    checkCudaErrors(cuda.cuInit(flags))

    _log().info(f"cuDeviceGet {device_id}")
    cu_device = checkCudaErrors(cuda.cuDeviceGet(device_id))
    _log().info(f"{cu_device} <-- cuDeviceGet")

    _log().info(f"cuCtxCreate {0} {cu_device}")
    if cuda.CUDA_VERSION >= 13000:
        # CTK 13 removed the v2/v3 entry points; the v4 cuCtxCreate takes an
        # explicit CUctxCreateParams (None here).
        # See https://github.com/NVIDIA/cuda-python/pull/792
        context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cu_device))
    else:
        context = checkCudaErrors(cuda.cuCtxCreate(0, cu_device))
    _log().info(f"{context} <-- cuCtxCreate")

    return context
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def load_cubin_module(cubin_file):
    """
    Loads a CUBIN file and returns the module.
    """
    # Load CUBIN file as binary data
    _log().info(f"read cubin {cubin_file}")
    with open(cubin_file, "rb") as f:
        cubin_data = f.read()
    # Fix: delegate to load_cubin_module_data instead of duplicating the
    # cuModuleLoadData logic in two places.
    return load_cubin_module_data(cubin_data)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def unload_cubin_module(module):
    """Unload a previously loaded CUBIN module from the driver."""
    _log().info("cuModuleUnload %s", module)
    checkCudaErrors(cuda.cuModuleUnload(module))
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def load_cubin_module_data(cubin_data):
    """
    Loads a CUBIN from in-memory data and returns the module.
    """
    # Fix: the original built two separate np.char.array temporaries — the
    # pointer it logged was not the pointer it passed to cuModuleLoadData.
    # Keep one array referenced for the duration of the call instead.
    image = np.char.array(cubin_data)
    _log().info(f"cuModuleLoadData {image.ctypes.data}")
    module = checkCudaErrors(cuda.cuModuleLoadData(image.ctypes.data))
    return module
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def get_kernel_function(module, kernel_name):
    """Resolve kernel ``kernel_name`` inside ``module`` and return its handle."""
    _log().info("cuModuleGetFunction %s %s", module, kernel_name)
    encoded_name = kernel_name.encode("utf-8")
    kernel = checkCudaErrors(cuda.cuModuleGetFunction(module, encoded_name))
    _log().info("%s <-- cuModuleGetFunction", kernel)
    return kernel
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
    """
    Launch ``kernel`` on ``stream`` with the given grid/block configuration
    and dynamic shared memory size.
    """
    _log().info(
        f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}"
    )
    grid_x, grid_y, grid_z = grid_dims[0], grid_dims[1], grid_dims[2]
    block_x, block_y, block_z = block_dims[0], block_dims[1], block_dims[2]
    checkCudaErrors(
        cuda.cuLaunchKernel(
            kernel,
            grid_x,
            grid_y,
            grid_z,
            block_x,
            block_y,
            block_z,
            smem_size,  # Shared memory size
            stream,
            kernel_args,
            0,  # Extra parameters
        )
    )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def stream_sync(stream):
    """Block until all work queued on ``stream`` has completed."""
    _log().info("cuStreamSynchronize %s", stream)
    checkCudaErrors(cuda.cuStreamSynchronize(stream))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def stream_create(id=0):
    """Create a CUDA stream and return it.

    ``id`` is forwarded to ``cuStreamCreate`` (the driver's stream creation
    flags argument).
    """
    _log().info("cuStreamCreate %s", id)
    new_stream = checkCudaErrors(cuda.cuStreamCreate(id))
    _log().info("%s <-- cuStreamCreate", new_stream)
    return new_stream
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def stream_destroy(stream):
    """Destroy the given CUDA stream."""
    _log().info("cuStreamDestroy %s", stream)
    checkCudaErrors(cuda.cuStreamDestroy(stream))
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def context_destroy(context):
    """Destroy the given CUDA context."""
    _log().info("cuCtxDestroy %s", context)
    checkCudaErrors(cuda.cuCtxDestroy(context))
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def allocate(size_in_bytes: int, stream=None):
    """
    Allocate ``size_in_bytes`` of device memory; asynchronous on ``stream``
    when one is given, synchronous otherwise.
    """
    _log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream)
    if stream is not None:
        device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream))
    else:
        device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
    _log().info("Allocated [%s]", device_memory)
    return device_memory
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def deallocate(device_pointer, stream=None):
    """
    Free the specified device memory pointer; asynchronous on ``stream``
    when one is given, synchronous otherwise.
    """
    _log().info(
        "Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream
    )
    if stream is not None:
        checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream))
    else:
        checkCudaErrors(cuda.cuMemFree(device_pointer))
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None):
    """
    Copy data from host to device memory.

    Synchronous (cuMemcpyHtoD) when ``stream`` is None, otherwise
    asynchronous (cuMemcpyHtoDAsync) on ``stream``.
    """
    # Fix: log key was "host_pointer[%s]" while every other field uses the
    # "name=[%s]" form; made consistent.
    _log().info(
        "Copy host-to-device host_pointer=[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]",
        hex(host_pointer),
        hex(int(device_pointer)),
        size_in_bytes,
        stream,
    )
    if stream is None:
        checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes))
    else:
        checkCudaErrors(
            cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream)
        )
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None):
    """
    Copy data from device to host memory.

    Synchronous (cuMemcpyDtoH) when ``stream`` is None, otherwise
    asynchronous (cuMemcpyDtoHAsync) on ``stream``.
    """
    # Fix: log message said "Copy device-host-to" and used "host_pointer[%s]";
    # corrected the wording and made the key/value form consistent.
    _log().info(
        "Copy device-to-host device_pointer=[%s] host_pointer=[%s] size_in_bytes=[%s] stream=[%s]",
        hex(int(device_pointer)),
        hex(host_pointer),
        size_in_bytes,
        stream,
    )
    if stream is None:
        checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes))
    else:
        checkCudaErrors(
            cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream)
        )
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def default_stream():
    """Return the default CUDA stream (null stream, handle 0)."""
    null_handle = 0
    return cuda.CUstream(null_handle)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def get_driver_version():
    """
    Return the CUDA driver version as reported by cuDriverGetVersion.
    """
    version = checkCudaErrors(cuda.cuDriverGetVersion())
    return version
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def set_kernel_attribute(kernel, attribute, value):
    """
    Set ``attribute`` to ``value`` on ``kernel`` via cuFuncSetAttribute.
    """
    result = checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value))
    return result
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream)
class StreamAdapter:
    """
    Adapter that exposes a CUDA stream to the JIT argument machinery.

    Registered for ``cuda.CUstream`` so streams can be passed directly to
    JIT-compiled entry points.
    """

    def __init__(self, arg):
        self._arg = arg
        # Cache the raw stream pointer once; this is what the C ABI receives.
        self._c_pointer = arg.getPtr()

    def __new_from_mlir_values__(self, values):
        # A stream corresponds to exactly one MLIR value.
        assert len(values) == 1
        return values[0]

    def __c_pointers__(self):
        return [self._c_pointer]

    def __get_mlir_types__(self):
        # Streams are modeled as GPU async tokens on the MLIR side.
        return [gpu.AsyncTokenType.get()]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
from . import cuda as cuda_helpers
|
| 15 |
+
from .tensor_descriptor import *
|
| 16 |
+
from ..common import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def allocate(tensor: TensorDescriptor, stream=None):
    """
    Allocate GPU memory for ``tensor`` and store the device pointer on it.

    Raises:
        DSLRuntimeError: if the tensor's memory is framework-managed, or if
            device memory is already allocated for it.
    """
    if tensor._check_is_managed_by_framework():
        raise DSLRuntimeError(
            "GPU tensors are managed by the framework and cannot be modified."
        )
    # Idiom fix: "is not None" instead of "not ... is None".
    if tensor.device_pointer is not None:
        raise DSLRuntimeError("Tensor is already allocated on the device.")

    tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream)

    log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def deallocate(tensor: TensorDescriptor, stream=None):
    """
    Deallocate the GPU memory held by ``tensor`` and clear its pointer.

    Raises:
        DSLRuntimeError: if the tensor's memory is framework-managed, or if
            no device memory is currently allocated for it.
    """
    if tensor._check_is_managed_by_framework():
        raise DSLRuntimeError(
            "GPU tensors are managed by the framework and cannot be modified."
        )
    if tensor.device_pointer is None:
        raise DSLRuntimeError("Tensor is not allocated on the device.")

    # Fixed: the original logged "Deallocating done" *before* the memory was
    # actually freed; log the intent here instead.
    log().info(
        "Deallocating tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer
    )

    cuda_helpers.deallocate(tensor.device_pointer, stream)
    tensor.device_pointer = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None):
    """
    Copy the tensor's host data into GPU memory.

    When ``do_allocate`` is True the device buffer is allocated first.
    Returns the same descriptor for convenient chaining.
    """
    log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
    if do_allocate:
        allocate(tensor, stream)
    host_ptr = tensor.data_ptr
    dev_ptr = tensor.device_pointer
    nbytes = tensor.size_in_bytes
    cuda_helpers.memcpy_h2d(host_ptr, dev_ptr, nbytes, stream)
    log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
    return tensor
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None):
    """
    Copy the tensor's data from GPU memory back to the host buffer.

    When ``do_deallocate`` is True the device buffer is released afterwards.

    Raises:
        DSLRuntimeError: if the tensor's memory is framework-managed, or if
            no device memory is currently allocated for it.
    """
    log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
    if tensor._check_is_managed_by_framework():
        raise DSLRuntimeError(
            "GPU tensors are managed by the framework and cannot be modified."
        )
    if tensor.device_pointer is None:
        raise DSLRuntimeError("Tensor is not allocated on the device.")

    host_ptr = tensor.data_ptr
    dev_ptr = tensor.device_pointer
    nbytes = tensor.size_in_bytes
    cuda_helpers.memcpy_d2h(host_ptr, dev_ptr, nbytes, stream)
    if do_deallocate:
        deallocate(tensor, stream)
    log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def to_gpu(tensor, stream=None) -> TensorDescriptor:
    """
    Return a new TensorDescriptor whose data has been copied host -> GPU.

    Accepts either an existing TensorDescriptor (shallow-copied first) or
    any DLPack-capable tensor object.

    Raises:
        DSLRuntimeError: for objects that are neither.
    """
    if isinstance(tensor, TensorDescriptor):
        descriptor = copy.copy(tensor)
    elif TensorDescriptor.can_transformed_to_dlpack(tensor):
        descriptor = TensorDescriptor(tensor)
    else:
        raise DSLRuntimeError("Unsupported type")

    copy_to_gpu(descriptor, stream=stream)
    return descriptor
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def from_gpu(tensor, stream=None) -> TensorDescriptor:
    """
    Copies the tensor from GPU memory back to Host memory.

    Accepts either an existing TensorDescriptor (shallow-copied first) or a
    DLPack-capable tensor object; raises DSLRuntimeError otherwise.

    NOTE(review): the original docstring said "to the GPU memory from Host"
    (copied from to_gpu), but this calls copy_from_gpu (device -> host).
    """
    if isinstance(tensor, TensorDescriptor):
        new_tensor = copy.copy(tensor)
        copy_from_gpu(new_tensor, stream=stream)
        return new_tensor

    if TensorDescriptor.can_transformed_to_dlpack(tensor):
        new_tensor = TensorDescriptor(tensor)
        copy_from_gpu(new_tensor, stream=stream)
        return new_tensor

    raise DSLRuntimeError("Unsupported type")
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides helper structs for dlpack.
|
| 14 |
+
DLPack is an open standard for in-memory tensor structures, enabling
|
| 15 |
+
seamless sharing of tensors across different frameworks.
|
| 16 |
+
Learn more at: https://github.com/dmlc/dlpack
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import ctypes
|
| 20 |
+
import enum
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DLDeviceType(enum.IntEnum):
    """Enums for device types based on the DLPack specification."""

    kDLCPU = 1  # Host (CPU) memory
    kDLGPU = 2  # CUDA device memory (named kDLCUDA in newer DLPack headers)
    kDLCPUPinned = 3  # Page-locked host memory accessible from the device
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DLDataTypeCode:
    """Enums for data type codes based on the DLPack specification.

    see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
    """

    kDLInt = 0  # signed integer
    kDLUInt = 1  # unsigned integer
    kDLFloat = 2  # IEEE floating point
    kDLOpaqueHandle = 3  # opaque handle (pointer-sized)
    kDLBfloat = 4  # bfloat16
    kDLComplex = 5  # complex number
    kDLBool = 6  # boolean
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DLDevice(ctypes.Structure):
    """Structure representing the device information in DLPack.

    Mirrors the C ``DLDevice`` struct layout from dlpack.h.
    """

    _fields_ = [
        ("device_type", ctypes.c_int),  # kDLCPU, kDLGPU, etc.
        ("device_id", ctypes.c_int),  # Device ID (e.g., GPU ID)
    ]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class DLDataType(ctypes.Structure):
    """Structure representing the data type in DLPack.

    Mirrors the C ``DLDataType`` struct layout from dlpack.h.
    """

    _fields_ = [
        ("code", ctypes.c_uint8),  # Data type code (e.g., kDLFloat)
        ("bits", ctypes.c_uint8),  # Number of bits per value
        ("lanes", ctypes.c_uint16),  # Number of lanes (vectorized types)
    ]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class DLTensor(ctypes.Structure):
    """Structure representing the DLTensor in DLPack.

    Mirrors the C ``DLTensor`` struct layout from dlpack.h; field order
    must match the C definition exactly.
    """

    _fields_ = [
        ("data", ctypes.c_void_p),  # Pointer to tensor data
        ("device", DLDevice),  # Device info
        ("ndim", ctypes.c_int),  # Number of dimensions
        ("dtype", DLDataType),  # Data type
        ("shape", ctypes.POINTER(ctypes.c_int64)),  # Shape of tensor
        ("strides", ctypes.POINTER(ctypes.c_int64)),  # Strides (in elements)
        ("byte_offset", ctypes.c_uint64),  # Byte offset to tensor data
    ]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides runtime utilities for JIT argument conversion in DSL.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from typing import get_origin
|
| 18 |
+
|
| 19 |
+
# Local modules imports
|
| 20 |
+
from ..common import DSLRuntimeError
|
| 21 |
+
from ..typing import (
|
| 22 |
+
Constexpr,
|
| 23 |
+
Int32,
|
| 24 |
+
Float32,
|
| 25 |
+
Boolean,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
    """
    Check if the argument spec is a constexpr.

    True when the parameter is a reserved first argument (``self``, or
    ``cls`` on a classmethod), or when its annotation is ``Constexpr``
    itself or a subscripted form like ``Constexpr[...]``.
    """

    def _is_reserved_first_arg(index, name, func):
        # Only the very first parameter can be a reserved one.
        if index != 0:
            return False
        if name == "self":
            return True
        bound_classmethod = isinstance(func, classmethod) or (
            hasattr(func, "__func__") and isinstance(func.__func__, classmethod)
        )
        return name == "cls" and bound_classmethod

    if _is_reserved_first_arg(arg_index, arg_name, owning_func):
        return True
    if isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr):
        return True
    return get_origin(arg_spec) is Constexpr
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func):
    """
    Check if the argument is a constexpr.

    True when the spec itself is constexpr, when the argument is a bare
    type (with no annotation or a ``Type[...]`` annotation), or when the
    argument is ``None``.
    """

    def _looks_like_type_argument(value, annotation):
        # A class object passed where the annotation is absent or Type[X].
        return isinstance(value, type) and (
            annotation is None or get_origin(annotation) is type
        )

    if arg is None:
        return True
    if is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func):
        return True
    return _looks_like_type_argument(arg, arg_spec)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class JitArgAdapterRegistry:
    """
    A registry to keep track of the JIT argument adapters.

    An adapter is a callable that converts a Python type to a type with following protocols supported:
    - JitArgument
    - DynamicExpression
    The converted type can then be further processed by DSL to generate arguments for JIT functions.
    """

    # A dictionary with key=type and value=callable.
    # NOTE: class-level, so the registry is shared across the process.
    jit_arg_adapter_registry = {}

    @classmethod
    def register_jit_arg_adapter(cls, *dargs, **dkwargs):
        """
        Register a JIT argument adapter callable.

        This can be used as a decorator on any callable like:

        @register_jit_arg_adapter(my_py_type)
        def my_adapter_for_my_py_type(arg):
            ...

        @register_jit_arg_adapter(my_py_type)
        class MyAdapterForMyPythonType:
            ...

        The adapters are registered per type. If a type is already registered, an error will be raised.
        """

        def decorator(*dargs, **dkwargs):
            # The single positional argument is the Python type being adapted.
            darg_python_ty = dargs[0]

            @wraps(darg_python_ty)
            def wrapper(*args, **kwargs):
                # When the decorator is applied, the decorated object (the
                # adapter callable or class) arrives as the sole argument.
                if len(args) != 1 or not callable(args[0]):
                    raise DSLRuntimeError(
                        "a callable must be provided for registering JIT argument adapter"
                    )
                adapter = args[0]

                # Exactly one adapter per type; re-registration is an error.
                if darg_python_ty in cls.jit_arg_adapter_registry:
                    raise DSLRuntimeError(
                        f"JIT argument adapter for {darg_python_ty} is already registered!",
                        context={
                            "Registered adapter": cls.jit_arg_adapter_registry[
                                darg_python_ty
                            ],
                            "Adapter to be registered": adapter,
                        },
                    )
                cls.jit_arg_adapter_registry[darg_python_ty] = adapter
                # Return the adapter unchanged so decorators can be stacked.
                return adapter

            return wrapper

        if len(dargs) > 0:
            return decorator(*dargs, **dkwargs)
        else:
            raise DSLRuntimeError(
                "a Python type must be provided for registering JIT argument adapter"
            )

    @classmethod
    def get_registered_adapter(cls, ty):
        """
        Get the registered JIT argument adapter for the given type.

        Returns None when no adapter is registered for ``ty``.
        """
        return cls.jit_arg_adapter_registry.get(ty, None)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# =============================================================================
|
| 151 |
+
# JIT Argument Adapters
|
| 152 |
+
# =============================================================================
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@JitArgAdapterRegistry.register_jit_arg_adapter(int)
@JitArgAdapterRegistry.register_jit_arg_adapter(float)
@JitArgAdapterRegistry.register_jit_arg_adapter(bool)
def _convert_python_scalar(arg):
    """
    Convert a Python scalar to a DSL type.

    Uses an exact type(...) lookup, so bool does not fall into the int
    entry despite being a subclass of int.
    """
    dsl_type_for = {
        bool: Boolean,
        int: Int32,
        float: Float32,
    }
    dsl_type = dsl_type_for.get(type(arg))
    return dsl_type(arg)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@JitArgAdapterRegistry.register_jit_arg_adapter(tuple)
@JitArgAdapterRegistry.register_jit_arg_adapter(list)
def _convert_python_sequence(arg):
    """
    Convert each element of a list/tuple using its registered adapter (if
    any) so DSL can generate the corresponding JIT argument(s); elements
    with no registered adapter pass through unchanged.
    """

    def _convert(elem):
        adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem))
        return elem if adapter is None else adapter(elem)

    converted = [_convert(elem) for elem in arg]
    assert len(converted) == len(arg)
    # Rebuild the result with the original container type (list or tuple).
    return type(arg)(converted)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
# Helpers
|
| 13 |
+
import itertools, operator
|
| 14 |
+
import ctypes
|
| 15 |
+
from . import dlpack_types as _dpack
|
| 16 |
+
from .dlpack_runtime import (
|
| 17 |
+
dlpack_to_tensor_desc,
|
| 18 |
+
get_tensor_desc_data_ptr,
|
| 19 |
+
get_tensor_desc_is_in_device,
|
| 20 |
+
get_tensor_desc_element_type,
|
| 21 |
+
get_tensor_desc_shape,
|
| 22 |
+
get_tensor_desc_stride,
|
| 23 |
+
get_tensor_desc_element_size_in_bytes,
|
| 24 |
+
get_tensor_desc_ndim,
|
| 25 |
+
get_tensor_desc_dtype_code,
|
| 26 |
+
get_tensor_desc_dtype_bits,
|
| 27 |
+
get_tensor_desc_device_type,
|
| 28 |
+
get_tensor_desc_device_id,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from ..utils.logger import log
|
| 32 |
+
from ..common import *
|
| 33 |
+
from ..typing import (
|
| 34 |
+
Boolean,
|
| 35 |
+
Float8E5M2,
|
| 36 |
+
Int64,
|
| 37 |
+
Int32,
|
| 38 |
+
Int16,
|
| 39 |
+
Int8,
|
| 40 |
+
Uint64,
|
| 41 |
+
Uint32,
|
| 42 |
+
Uint16,
|
| 43 |
+
Uint8,
|
| 44 |
+
Float64,
|
| 45 |
+
Float32,
|
| 46 |
+
Float16,
|
| 47 |
+
BFloat16,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class TensorDescriptor:
    """
    Wraps a DLPack-capable tensor and exposes its metadata (data pointer,
    shape, strides, dtype, device) via the dlpack_runtime helpers.
    """

    def __init__(self, tensor):
        """Initialize with a tensor that supports the DLPack protocol.

        Args:
            tensor: Any tensor object that implements __dlpack__ and __dlpack_device__

        Raises:
            DSLRuntimeError: if the tensor lives on an unsupported device type.
        """

        self.tensor = tensor
        self._capsule = dlpack_to_tensor_desc(tensor)

        self.data_ptr = get_tensor_desc_data_ptr(self._capsule)
        self.device_type = get_tensor_desc_device_type(self._capsule)
        self.device_type = _dpack.DLDeviceType(self.device_type)

        if self.device_type == _dpack.DLDeviceType.kDLGPU:
            # GPU tensor: data already lives on the device.
            self.device_pointer = self.data_ptr
        elif self.device_type == _dpack.DLDeviceType.kDLCPU:
            # Host tensor: no device allocation yet.
            self.device_pointer = None
        else:
            # Fixed: the original referenced self.dl_tensor.device.device_type,
            # an attribute that does not exist, so building the error message
            # itself raised AttributeError instead of the intended error.
            raise DSLRuntimeError(
                f"DLPack device type is not supported {self.device_type}"
            )

        log().info("TensorDescriptor is created = [%s]", self)

    @staticmethod
    def can_transformed_to_dlpack(dl_tensor):
        """Return True if the object implements the DLPack protocol."""
        return hasattr(dl_tensor, "__dlpack__") and hasattr(
            dl_tensor, "__dlpack_device__"
        )

    @property
    def is_in_device(self):
        """Check if the tensor is stored on a device."""
        # Idiom fix: "is not None" instead of "not ... is None".
        return self.device_pointer is not None

    @property
    def device_id(self):
        """Return the device id where the tensor resides, or -1 for host."""
        if self.is_in_device:
            return get_tensor_desc_device_id(self._capsule)
        return -1

    @property
    def element_type(self):
        """Return the corresponding DSL type based on DLPack dtype metadata.

        Raises:
            KeyError: for element types with no DSL equivalent.
        """
        str_element_type = get_tensor_desc_element_type(self._capsule)
        dtype_map = {
            # bool is 8bit from numpy and torch
            "Bool": Boolean,
            "Int64": Int64,
            "Int32": Int32,
            "Int16": Int16,
            "Int8": Int8,
            "UInt64": Uint64,
            "UInt32": Uint32,
            "UInt16": Uint16,
            "UInt8": Uint8,
            "Float64": Float64,
            "Float32": Float32,
            "Float16": Float16,
            "BFloat16": BFloat16,
            "Float8E5M2": Float8E5M2,
        }

        if str_element_type not in dtype_map:
            raise KeyError(
                f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}"
            )

        return dtype_map[str_element_type]

    @property
    def shape(self):
        """Return the shape of the tensor."""
        return get_tensor_desc_shape(self._capsule)

    @property
    def rank(self):
        """Return the rank (number of dimensions) of the tensor."""
        return get_tensor_desc_ndim(self._capsule)

    @property
    def strides(self):
        """Return the strides of the tensor.

        (Fixed docstring: the original said "Return the rank", a
        copy-paste from the property above.)
        """
        return get_tensor_desc_stride(self._capsule)

    @property
    def element_size_in_bytes(self):
        """Calculate the element size in bytes of the DLPack tensor."""
        return get_tensor_desc_element_size_in_bytes(self._capsule)

    @property
    def size_in_bytes(self):
        """Calculate the total size in bytes of the DLPack tensor."""
        # Number of elements is the product of the shape extents.
        ndim = get_tensor_desc_ndim(self._capsule)
        shape = get_tensor_desc_shape(self._capsule)
        num_elements = 1
        for i in range(ndim):
            num_elements *= shape[i]

        # Total bytes
        total_bytes = self.element_size_in_bytes * num_elements
        return total_bytes

    def __str__(self):
        """Return a compact string representation, e.g. tensor<4x4xf32>_cpu."""
        # Extract shape
        shape = "x".join(map(str, self.shape))

        # Extract dtype: integers render as iN, everything else as fN.
        dtype_code = get_tensor_desc_dtype_code(self._capsule)
        dtype_bits = get_tensor_desc_dtype_bits(self._capsule)
        dtype = (
            f"i{dtype_bits}"
            if dtype_code == _dpack.DLDataTypeCode.kDLInt
            else f"f{dtype_bits}"
        )

        # Extract device
        device_type = "cpu" if not self.is_in_device else "gpu"

        return f"tensor<{shape}x{dtype}>_{device_type}"

    def _check_is_managed_by_framework(self):
        """
        Return True when the tensor's memory is owned by an external
        framework (i.e. it is a GPU tensor) and must not be
        allocated/deallocated by us.

        (Fixed docstring: the original claimed this raises; it returns a
        bool that callers use to decide whether to raise.)
        """
        return self.device_type == _dpack.DLDeviceType.kDLGPU

    @staticmethod
    def is_compatible(maybe_tensor_descriptor) -> bool:
        """Check if the object is a TensorDescriptor or can be converted to one."""
        return isinstance(
            maybe_tensor_descriptor, TensorDescriptor
        ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def from_tensor(tensor) -> TensorDescriptor:
    """Create a TensorDescriptor from a tensor object."""
    descriptor = TensorDescriptor(tensor)
    return descriptor
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def to_tensor(tensor_descriptor: TensorDescriptor):
    """Return the underlying tensor object wrapped by the descriptor."""
    underlying = tensor_descriptor.tensor
    return underlying
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py
ADDED
|
@@ -0,0 +1,1962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import ctypes
|
| 13 |
+
import numpy as np
|
| 14 |
+
import operator
|
| 15 |
+
from typing_extensions import deprecated
|
| 16 |
+
from functools import reduce
|
| 17 |
+
from typing import (
|
| 18 |
+
Generic,
|
| 19 |
+
Protocol,
|
| 20 |
+
Union,
|
| 21 |
+
Any,
|
| 22 |
+
List,
|
| 23 |
+
Type,
|
| 24 |
+
TypeVar,
|
| 25 |
+
overload,
|
| 26 |
+
runtime_checkable,
|
| 27 |
+
get_origin,
|
| 28 |
+
)
|
| 29 |
+
from types import FunctionType
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from abc import ABC, abstractmethod
|
| 32 |
+
|
| 33 |
+
from .common import *
|
| 34 |
+
from .ast_helpers import const_expr
|
| 35 |
+
from ._mlir_helpers import arith as arith_helper, lru_cache_ir
|
| 36 |
+
from ._mlir_helpers.arith import ArithValue
|
| 37 |
+
|
| 38 |
+
from .._mlir import ir
|
| 39 |
+
from .._mlir.extras import types as T
|
| 40 |
+
from .._mlir.dialects import arith, math
|
| 41 |
+
|
| 42 |
+
# =============================================================================
|
| 43 |
+
# Dynamic Expression Protocol
|
| 44 |
+
# =============================================================================
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@runtime_checkable
class DynamicExpression(Protocol):
    """Structural interface for objects that carry dynamic (MLIR-backed) values.

    Implementing this protocol lets a custom Python class participate in DSL
    JIT features: being passed as an argument when one JIT function calls
    another, being returned from a JIT function, and flowing through traced
    control-flow constructs (if-else, while-loop, ...).

    Implementers must provide two hooks:

    * ``__extract_mlir_values__`` -- flatten the object into a list of
      ``ir.Value``s.
    * ``__new_from_mlir_values__`` -- rebuild an equivalent object from such a
      list (e.g. from block arguments while tracing).

    Minimal example::

        class CustomData(metaclass=DslType):
            def __init__(self, int_value):
                self.int_value = int_value

            def __extract_mlir_values__(self):
                return [self.int_value]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0])

    Used from a JIT-compiled function::

        @jit
        def caller():
            x = CustomData(1)
            return foo(x)

    the DSL extracts the MLIR values automatically and generates IR such as::

        func @caller() -> i32 {
            %0 = func.call @foo(%arg0) : (i32) -> i32
            return %0 : i32
        }
    """

    def __extract_mlir_values__(self):
        """Return the list of MLIR values (``List[ir.Value]``) backing this object."""
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """Return a fresh instance rebuilt from ``values`` (a list of ``ir.Value``s)."""
        raise NotImplementedError
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@runtime_checkable
class JitArgument(Protocol):
    """Structural interface for objects usable as JIT function arguments.

    A class satisfying this protocol gives the DSL JIT executor everything it
    needs to call a compiled function with the object:

    * ``__get_mlir_types__`` -- the MLIR types used to declare the generated
      function signature (e.g. ``func.func @foo(%arg0: i32, ...)``).
    * ``__new_from_mlir_values__`` -- rebuild the object from the entry
      block's ``ir.Value``s, so the traced Python body reads IR values
      (``%arg0``) instead of the host Python values it was constructed with.
    * ``__c_pointers__`` -- ctypes pointers to the underlying data, handed to
      the JIT engine when the compiled function is executed at runtime.

    Example::

        class CustomData:
            def __init__(self, int_value, ...):
                self.int_value = int_value
                ...

            def __c_pointers__(self):
                return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...]

            def __get_mlir_types__(self):
                return [ir.IntegerType.get(32), ...]

            def __new_from_mlir_values__(self, values):
                return CustomData(values[0], ...)

        @jit
        def foo(x: CustomData):
            a = x.int_value + 1
            ...

        foo(CustomData(1, ...))
    """

    def __c_pointers__(self):
        """Return ctypes pointers (``List[ctypes.c_void_p]``) to this object's data."""
        raise NotImplementedError

    def __get_mlir_types__(self):
        """Return the MLIR types (``List[ir.Type]``) describing this object."""
        raise NotImplementedError

    def __new_from_mlir_values__(self, values):
        """Return a new object that represents the given MLIR values."""
        raise NotImplementedError
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_c_pointers(obj):
    """Recursively collect every ctypes pointer carried by `obj`.

    Objects implementing ``__c_pointers__`` contribute their own pointers,
    tuples and lists are flattened element-wise, sets are rejected because
    their iteration order is not deterministic, and anything else contributes
    nothing.
    """
    if hasattr(obj, "__c_pointers__"):
        return obj.__c_pointers__()
    if isinstance(obj, (tuple, list)):
        collected = []
        for element in obj:
            collected.extend(get_c_pointers(element))
        return collected
    if isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_c_pointers to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    return []
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_mlir_types(obj):
    """Recursively collect the MLIR types of every value carried by `obj`.

    Resolution order: an explicit ``__get_mlir_types__`` hook wins; otherwise
    the types are derived from ``__extract_mlir_values__``; a bare
    ``ir.Value`` contributes its own type; tuples and lists are flattened
    element-wise; sets are rejected (no stable ordering); anything else
    contributes nothing.
    """
    if hasattr(obj, "__get_mlir_types__"):
        return obj.__get_mlir_types__()
    if hasattr(obj, "__extract_mlir_values__"):
        return [value.type for value in obj.__extract_mlir_values__()]
    if isinstance(obj, ir.Value):
        return [obj.type]
    if isinstance(obj, (tuple, list)):
        collected = []
        for element in obj:
            collected.extend(get_mlir_types(element))
        return collected
    if isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_mlir_types to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    return []
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class DslType(type):
    """Metaclass shared by every DSL type.

    Provides the type-system plumbing common to DSL types (MLIR type mapping,
    NumPy conversions) and records whether a type is abstract.

    Classes built with this metaclass are expected to supply:

    * ``__str__`` (classmethod): string representation of the type
    * ``__c_pointers__`` (optional): ctypes pointers used to invoke a JIT
      function
    * ``__get_mlir_types__``: MLIR types of the values held by an instance
    * ``__extract_mlir_values__``: MLIR values held by an instance
    * ``__new_from_mlir_values__``: build a new instance from MLIR values

    :param is_abstract: Whether this type is abstract, defaults to False
    :type is_abstract: bool, optional
    """

    _is_abstract: bool

    def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs):
        # Only the standard (name, bases, attrs) triple is forwarded to
        # `type`; DSL-specific keyword arguments are consumed here.
        created = super().__new__(cls, name, bases, attrs)
        created._is_abstract = is_abstract
        return created

    @property
    def is_abstract(cls):
        # Queried on the class itself (metaclass property).
        return cls._is_abstract
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class NumericMeta(DslType):
    """Metaclass for numeric types providing width and numpy dtype information.

    :param width: Bit width of the numeric type, defaults to 8
    :type width: int
    :param np_dtype: Corresponding NumPy dtype
    :type np_dtype: numpy.dtype, optional
    :param mlir_type: Corresponding MLIR type
    :type mlir_type: Any, optional
    :param is_abstract: Whether the type is abstract, defaults to False
    :type is_abstract: bool, optional

    :ivar width: Bit width of the numeric type
    :type width: int
    :ivar _np_dtype: Corresponding NumPy dtype
    :type _np_dtype: Union[numpy.dtype, None]

    :property numpy_dtype: Returns the corresponding NumPy dtype
    :rtype numpy_dtype: numpy.dtype
    """

    width: int

    # Placeholder; concrete classes install a real callable via the
    # `mlir_type` keyword argument in __new__ below.
    _mlir_type = Any
    _np_dtype: Union[np.dtype, None]

    def __new__(
        cls,
        name,
        bases,
        attrs,
        width=8,
        np_dtype=None,
        mlir_type=None,
        is_abstract=False,
        **kwargs,
    ):
        # Default DynamicExpression hooks injected into every numeric class.
        def _extract_mlir_values(self):
            # assumes instances expose `ir_value()` -- provided by the
            # Numeric hierarchy defined elsewhere in this module.
            return [self.ir_value()]

        def _new_from_mlir_values(self, values: list) -> "Numeric":
            res_ty = type(self)
            return res_ty(values[0])

        new_attrs = {
            "__extract_mlir_values__": _extract_mlir_values,
            "__new_from_mlir_values__": _new_from_mlir_values,
        }
        # `new_attrs | attrs`: an explicit class attribute takes precedence
        # over the injected defaults.
        new_cls = super().__new__(
            cls,
            name,
            bases,
            new_attrs | attrs,
            is_abstract=is_abstract,
            **kwargs,
        )

        if mlir_type is not None:
            # staticmethod so attribute access does not bind the class.
            new_cls._mlir_type = staticmethod(mlir_type)

        new_cls.width = width
        new_cls._np_dtype = np_dtype
        return new_cls

    @property
    def numpy_dtype(cls):
        return cls._np_dtype

    @property
    def is_integer(cls) -> bool: ...

    @property
    def is_float(cls) -> bool: ...

    def is_same_kind(cls, other: Type) -> bool:
        # True when both types are integers or both are floats.
        return cls.is_integer == other.is_integer or cls.is_float == other.is_float

    @staticmethod
    def from_python(value: Any) -> Type["Numeric"]:
        """
        Deduce the DSL type from a Python value.

        :raises DSLRuntimeError: if `value` is not a bool, int, or float
        """
        # BUGFIX: bool must be tested before int. `bool` is a subclass of
        # `int`, so with the previous int-first ordering the Boolean branch
        # was unreachable and from_python(True) incorrectly returned Int32.
        if isinstance(value, bool):
            return Boolean
        elif isinstance(value, int):
            return Int32
        elif isinstance(value, float):
            return Float32
        raise DSLRuntimeError(
            f"Could not deduce Type[Numeric] from python value: {value} :{type(value)}"
        )

    @property
    def mlir_type(cls):
        return cls._mlir_type()  # type: ignore
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
Value = TypeVar("Value")


def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric":
    """Cast an object to the specified numeric type.

    :param obj: Object to be cast
    :type obj: Union[bool, int, float, Value]
    :param type_: Target numeric type
    :type type_: Type[Numeric]
    :raises TypeError: If casting to an abstract type or unsupported type conversion
    :return: Object cast to the target numeric type
    :rtype: Numeric

    Example::
        >>> x = cast(5, Int32)  # Cast integer to Int32
        >>> y = cast(3.14, Float32)  # Cast float to Float32
    """
    if not type_.is_abstract:
        # Concrete target: implicit conversion through the type's constructor.
        return type_(obj)
    if isinstance(obj, type_):
        # Abstract target and obj already satisfies it: return it unchanged.
        return obj
    raise TypeError(
        f"can't cast {obj} to {type_}. Pass in concrete type instead, "
        "e.g. Int32, Float32, etc."
    )
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# Option 1: use ir.Value as base
|
| 443 |
+
# class IntegerMeta(DslType, type(ir.Value)):
|
| 444 |
+
class IntegerMeta(NumericMeta):
    """Metaclass for integer types providing signedness information.

    :param width: Bit width of the integer type, defaults to 32
    :type width: int
    :param signed: Whether the integer type is signed, defaults to True
    :type signed: bool
    :param mlir_type: Corresponding MLIR type, defaults to None
    :type mlir_type: Any, optional

    :ivar signed: Whether the integer type is signed
    :vartype signed: bool
    """

    signed: bool

    def __new__(
        cls,
        name,
        bases,
        attrs,
        width=32,
        signed=True,
        mlir_type=None,
        is_abstract=False,
    ):
        # Choose the numpy dtype matching the width/signedness combination.
        if width == 1:
            # 1-bit integers map to numpy booleans.
            np_dtype = np.bool_
        elif width == 128:
            # numpy has no 128-bit integer dtype.
            np_dtype = None
        elif signed:
            np_dtype = getattr(np, f"int{width}")
        else:
            np_dtype = getattr(np, f"uint{width}")

        def _c_pointers(self):
            # Closure over `width` and `signed` of the class being built.
            # NOTE(review): reads `self.value` -- assumes instances expose a
            # `value` attribute (presumably set by Numeric subclasses defined
            # elsewhere in this module); confirm.
            if width == 1:
                c_value = ctypes.c_bool(self.value)
            elif signed:
                c_value = getattr(ctypes, f"c_int{width}")(self.value)
            else:
                c_value = getattr(ctypes, f"c_uint{width}")(self.value)

            return [ctypes.cast(ctypes.pointer(c_value), ctypes.c_void_p)]

        new_attrs = {
            "__c_pointers__": _c_pointers,
        }
        # NOTE(review): merge order `attrs | new_attrs` lets the injected
        # __c_pointers__ override a user-supplied one, the opposite of
        # NumericMeta's `new_attrs | attrs` -- confirm the asymmetry is
        # intended.
        new_cls = super().__new__(
            cls, name, bases, attrs | new_attrs, width, np_dtype, mlir_type, is_abstract
        )
        new_cls.signed = signed
        return new_cls

    def __str__(cls):
        return f"{cls.__name__}"

    @property
    def is_integer(cls) -> bool:
        return True

    @property
    def is_float(cls) -> bool:
        return False

    @property
    def zero(cls) -> int:
        # Additive identity for integer types.
        return 0

    @property
    def min(cls) -> int:
        # Two's-complement lower bound for signed types; 0 for unsigned.
        if cls.signed:
            return -(2 ** (cls.width - 1))
        else:
            return 0

    @property
    def max(cls) -> int:
        # Two's-complement upper bound.
        if cls.signed:
            return 2 ** (cls.width - 1) - 1
        else:
            return 2**cls.width - 1

    def recast_width(cls, width):
        # Map a bit width to the corresponding signed integer DSL type
        # (Int8..Int128 are defined elsewhere in this module).
        type_map = {
            8: Int8,
            16: Int16,
            32: Int32,
            64: Int64,
            128: Int128,
        }
        if width not in type_map:
            raise TypeError(f"Unsupported width: {width}")
        return type_map[width]
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class FloatMeta(NumericMeta):
    """Metaclass for floating-point types.

    This metaclass provides type system infrastructure for floating-point types in the DSL,
    handling MLIR type mappings and NumPy type conversions.

    :param width: Bit width of the float type, defaults to 32
    :type width: int
    :param mlir_type: Corresponding MLIR type, defaults to None
    :type mlir_type: Any, optional
    :param is_abstract: Whether this is an abstract base class, defaults to False
    :type is_abstract: bool, optional
    """

    # Exponent/mantissa bit counts parsed from the class name; only set on
    # concrete classes whose name matches the Float<W>E<e>M<m> pattern.
    _exponent_width: int
    _mantissa_width: int

    def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False):
        # e.g. "Float32" -> np.float32; None when numpy has no matching dtype
        # (bfloat16, tfloat32, FP8 variants, ...).
        np_dtype = getattr(np, name.lower(), None)
        new_cls = super().__new__(
            cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract
        )
        # Extract exponent and mantissa bits from class name if it follows Float<E><M> pattern
        # For example: Float8E4M3 -> exponent_width=4, mantissa_width=3
        import re

        if not is_abstract:
            match = re.match(r"Float(\d+)E(\d+)M(\d+)(?:.*)", name)
            if match:
                exp_bits = int(match.group(2))
                mant_bits = int(match.group(3))

                # Store extracted values as class attributes
                # NOTE(review): classes whose name does not match (e.g.
                # Float16, Float32) never set these, so the exponent_width /
                # mantissa_width properties would raise AttributeError for
                # them -- confirm callers only query E<e>M<m>-named types.
                new_cls._exponent_width = exp_bits
                new_cls._mantissa_width = mant_bits
        # Don't have 1-to-1 mapping of narrow precision types like bfloat16, tfloat32, etc.
        return new_cls

    def __str__(cls):
        return f"{cls.__name__}"

    @property
    def is_integer(cls) -> bool:
        return False

    @property
    def is_float(cls) -> bool:
        return True

    @property
    def zero(cls) -> float:
        # Additive identity for float types.
        return 0.0

    @property
    def inf(cls) -> float:
        return float("inf")

    @property
    def nan(cls) -> float:
        return float("nan")

    @property
    def exponent_width(cls) -> int:
        return cls._exponent_width

    @property
    def mantissa_width(cls) -> int:
        return cls._mantissa_width

    def recast_width(cls, width):
        # Map a bit width to the standard IEEE float DSL type (Float16/32/64
        # are defined elsewhere in this module).
        type_map = {
            16: Float16,
            32: Float32,
            64: Float64,
        }
        if width not in type_map:
            raise TypeError(f"Unsupported width: {width}")
        return type_map[width]
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def _arith_signless_to_int(a, target_type):
|
| 626 |
+
# is_signed: sign of result type
|
| 627 |
+
if target_type.width > a.type.width:
|
| 628 |
+
# arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL
|
| 629 |
+
if target_type.signed and a.type.width > 1:
|
| 630 |
+
return arith.extsi(target_type.mlir_type, a)
|
| 631 |
+
else:
|
| 632 |
+
return arith.extui(target_type.mlir_type, a)
|
| 633 |
+
elif target_type.width < a.type.width:
|
| 634 |
+
return arith.trunci(target_type.mlir_type, a)
|
| 635 |
+
else:
|
| 636 |
+
return a
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def _binary_op_type_promote(a, b, promote_bool: bool = False):
    """Promote two numeric operands following type promotion rules.

    :param a: First numeric operand
    :type a: Numeric
    :param b: Second numeric operand
    :type b: Numeric
    :param promote_bool: Whether to promote boolean types to Int32 for arithmetic operations, defaults to False
    :type promote_bool: bool, optional
    :raises ValueError: If implicit float promotion is not supported between the given types
    :return: Tuple containing promoted operands and their resulting type
    :rtype: tuple[Numeric, Numeric, Type[Numeric]]

    Type promotion rules:
    1. If operands are same type and not bools needing promotion:
       - No promotion needed, return original types
    2. If either operand is float:
       a. If one is float and one is int:
          - Convert int to the float type
       b. If both are float:
          - Promote to higher precision float if width >= 16
          - For same width, promote to more general type (Float32 over TFloat32)
          - Otherwise raise ValueError for unsupported promotion
    3. Otherwise, both operands are integers. Integer promotion rules:
       a. If promote_bool is True and either operand is bool:
          - Promote bool to Int32 for arithmetic operations

    Exceptions for numpy dtype casting:
    - array(dtype=np.bool_) + array(dtype=np.bool_) -> array(dtype=np.bool_)

    What is not supported:
    - promotion with narrow precision float types which requires explicit cast by user
    """
    a_type = a.dtype
    b_type = b.dtype

    # Early return for same types (except when they're bools that need promotion)
    if a_type == b_type and not (promote_bool and a_type is Boolean):
        return a, b, a_type

    # Handle floating point promotions
    if a_type.is_float or b_type.is_float:
        # Widths of the ORIGINAL dtypes; an Integer dtype with no `width`
        # attribute contributes 0.  The comparisons below deliberately use
        # these pre-recast widths.
        a_width = getattr(a_type, "width", 0)
        b_width = getattr(b_type, "width", 0)

        # If one type is integer, convert it to the float type (widened to
        # cover the integer's bit width via recast_width).
        if a_type.is_float and not b_type.is_float:
            b_type = a_type.recast_width(max(a_width, b_width))
        elif b_type.is_float and not a_type.is_float:
            a_type = b_type.recast_width(max(a_width, b_width))

        # Both are float types - handle precision promotion
        if a_width > b_width and a_width >= 16:
            res_type = a_type
        elif b_width > a_width and b_width >= 16:
            res_type = b_type
        elif a_width == b_width:
            # Same bitwidth - prefer the more general type, e.g.
            # TFloat32 -> Float32 and BFloat16 -> Float16.
            if a_type is Float64 or b_type is Float64:
                res_type = Float64
            elif a_type is Float32 or b_type is Float32:
                res_type = Float32
            elif a_type is Float16 or b_type is Float16:
                res_type = Float16
            else:
                # Two different narrow-precision types of the same width
                # (e.g. two distinct fp8 formats): no implicit promotion.
                raise ValueError(
                    f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly"
                )
        else:
            # Widths differ but the wider one is sub-16-bit (narrow precision).
            raise ValueError(
                f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly"
            )

        # Only convert if type is different
        new_a = a.to(res_type) if a.dtype != res_type else a
        new_b = b.to(res_type) if b.dtype != res_type else b
        return new_a, new_b, res_type

    # Handle bool promotion for arithmetic operations
    if promote_bool:
        if a_type is Boolean and b_type is Boolean:
            # Only promote to Int32 when both are bool
            a = a.to(Int32)
            b = b.to(Int32)
            a_type = b_type = a.dtype

        # If both were bools, they're now same type (Int32)
        if a_type == b_type:
            return a, b, a_type

    # NOTE(review): this check appears unreachable — equal non-bool types
    # return at the top of the function, and the bool/bool case returns in
    # the branch above.  Kept as defensive code.
    if a_type == b_type:
        return a, b, a_type

    a_signed = a_type.signed
    b_signed = b_type.signed
    a_width = a_type.width
    b_width = b_type.width

    # Mixed signedness case
    if a_signed != b_signed:
        unsigned_type = a_type if not a_signed else b_type
        signed_type = a_type if a_signed else b_type
        unsigned_width = a_width if not a_signed else b_width

        if unsigned_width >= signed_type.width:
            # Promote both to unsigned of larger width (C-style promotion:
            # the unsigned type wins when it is at least as wide).
            res_type = unsigned_type
        else:
            # Promote both to signed of larger width
            res_type = signed_type

        new_a = a.to(res_type) if a.dtype != res_type else a
        new_b = b.to(res_type) if b.dtype != res_type else b
        return new_a, new_b, res_type

    # Same signedness, different width - promote to larger width
    if a_width >= b_width:
        return a, b.to(a.dtype), a.dtype
    else:
        return a.to(b.dtype), b, b.dtype
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
def _binary_op(op, promote_operand=True, promote_bool=False, flip=False):
    """Build a wrapper that applies a binary operation to Numeric operands.

    The returned wrapper handles operand canonicalization, type promotion,
    operation execution, and result-type determination for binary operations
    between Numeric types.

    :param op: The binary operation to perform (e.g., operator.add, operator.sub)
    :type op: callable
    :param promote_operand: Whether to promote operands to the same type, defaults to True
    :type promote_operand: bool, optional
    :param promote_bool: Whether boolean operands participate in arithmetic
        promotion (see _binary_op_type_promote), defaults to False
    :type promote_bool: bool, optional
    :param flip: Whether to flip the operands when calling the operation
        (used to implement reflected dunders such as __rsub__), defaults to False
    :type flip: bool, optional

    :raises TypeError: When an unsupported operation is attempted on specific numeric types

    .. note::
        Not all operations are supported for all numeric types. In particular:

        - Subtraction is not fully supported for Integer types
        - Multiplication, floor division, and modulo operations may have limited support
        - Division (truediv) with integer types is not fully supported and converts to Float32
    """

    def wrapper(lhs, rhs, *, loc=None, ip=None):
        # Remember the pre-promotion classes; used for the bool-op-bool
        # result-type special case below.
        orig_lhs_type = type(lhs)
        orig_rhs_type = type(rhs)

        # When called directly with self and other
        ty = type(lhs)
        # Canonicalize to Numeric type for promotion
        if not isinstance(rhs, Numeric):
            if not isinstance(rhs, (ArithValue, int, float, bool)):
                # This allows rhs class to implement __rmul__
                return NotImplemented

            if isinstance(rhs, ArithValue):
                if isinstance(rhs.type, ir.VectorType):
                    # Defer to the vector type's own reflected operator.
                    return NotImplemented

            rhs = as_numeric(rhs)

        # default result type to left-hand-side
        res_type = ty

        if promote_operand:
            lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool)
        else:
            rhs = ty(rhs)

        # Comparisons always yield Boolean regardless of operand types.
        if op in (
            operator.lt,
            operator.le,
            operator.gt,
            operator.ge,
            operator.eq,
            operator.ne,
        ):
            res_type = Boolean
        elif op == operator.truediv and isinstance(lhs, Integer):
            # True division of integers has no exact integer result.
            res_type = Float32
        elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean:
            # bool op bool keeps a Boolean result (numpy-style), even though
            # the operands were promoted to Int32 for the computation.
            res_type = Boolean

        # Re-attach DSL signedness to raw MLIR integer values so the
        # underlying ArithValue operator selects signed vs unsigned ops.
        if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer):
            lhs_val = lhs.value.with_signedness(lhs.signed)
        else:
            lhs_val = lhs.value

        if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer):
            rhs_val = rhs.value.with_signedness(rhs.signed)
        else:
            rhs_val = rhs.value

        if flip:
            lhs_val, rhs_val = rhs_val, lhs_val

        # May raise if the operation is unsupported by the operand values.
        res_val = op(lhs_val, rhs_val)
        return res_type(res_val, loc=loc, ip=ip)

    return wrapper
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
class Numeric(metaclass=NumericMeta, is_abstract=True):
    """Base class for all numeric types in the DSL.

    This class provides the foundation for both Integer and Float types,
    implementing basic arithmetic operations.

    :param value: The value to store in the numeric type
    :type value: Union[bool, int, float, Value]

    :ivar value: The stored numeric value
    :vartype value: Union[bool, int, float, Value]
    """

    def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None):
        # `value` is either a static Python scalar or a dynamic MLIR Value.
        self.value = value

    def __str__(self) -> str:
        # Use member's pretty-str method if member object has one.
        # This can be extended in future to have better support for IDE,
        # jupyter notebook, etc.
        pretty_str = getattr(self.value, "pretty_str", None)
        if pretty_str is not None:
            return pretty_str()
        else:
            return "?"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.value)})"

    def __hash__(self):
        # NOTE(review): `type(self).__class__` is the *metaclass* (shared by
        # sibling Numeric subclasses), so the first term does not distinguish
        # e.g. Int32 from Float32 — likely intended `hash(type(self))`;
        # confirm before changing, since it may be relied upon.
        return hash(type(self).__class__) ^ hash(self.value)

    @property
    def dtype(self) -> Type["Numeric"]:
        """The concrete Numeric subclass of this value."""
        return type(self)

    @overload
    def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ...

    @overload
    def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ...

    @overload
    def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ...

    @overload
    def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ...

    @overload
    def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ...

    def to(self, dtype: Type, *, loc=None, ip=None):
        """Convert this numeric value to another numeric type.

        If the target type is the same as the current type, returns self.
        Otherwise, creates a new instance of the target type with the same value.

        :param dtype: The target numeric type to convert to
        :type dtype: Union[Type["Numeric"], Type[int], Type[float], Type[bool]]
        :return: A new instance of the target type, or self if types match
        :rtype: Numeric
        :raises TypeError: If trying to convert an MLIR value to a static Python type
        :raises TypeError: If trying to convert to unsupported float types like Float8E4M3,
            Float8E4M3B11FNUZ, Float4E2M1FN, Float6E3M2FN, or Float6E2M3FN

        .. note::

            Unsupported destination float types:
            - Float8E4M3
            - Float8E4M3B11FNUZ
            - Float4E2M1FN
            - Float6E3M2FN
            - Float6E2M3FN

        Example::

        .. code-block:: python

            # Convert between DSL numeric types
            x = Int32(5)
            y = x.to(Float32)  # Converts to Float32(5.0)

            # Convert to Python primitive types
            # They are considered as static values at JIT time
            z = x.to(int)    # Returns Python int 5
            w = y.to(float)  # Returns Python float 5.0

            # This will raise a ValueError
            mlir_val = arith.constant(T.i32(), 42)
            num = Int32(mlir_val)
            num.to(int)  # ValueError: unable to convert MLIR value to static type: <class 'int'>
        """
        if dtype in _unsupported_dst_float_types:
            raise TypeError(f"Unsupported destination float type: {dtype}")

        # NOTE(review): this tests whether the *class object* `dtype` is an
        # instance of this value's class — with the metaclass arrangement it
        # is essentially never true, so same-type conversion falls through to
        # `dtype(self)` below.  Likely intended `dtype is type(self)`; confirm.
        if isinstance(dtype, type(self)):
            return self
        elif isinstance(dtype, NumericMeta):
            # Target is another DSL numeric type: delegate to its constructor.
            return dtype(self)
        elif dtype is ir.Value:
            # Materialize a static scalar as an MLIR constant, or pass a
            # dynamic value through unchanged.
            if isinstance(self.value, (int, float, bool)):
                res = arith_helper.const(
                    self.value, self.dtype.mlir_type, loc=loc, ip=ip
                )
            elif isinstance(self.value, ir.Value):
                res = self.value
            else:
                raise ValueError(
                    f"cannot convert {type(self)} to {dtype}, "
                    f"self.value is {self.value.type}"
                )

            if not isinstance(res, ArithValue):
                raise ValueError(f"Expected ArithValue, got {type(res)} as {res.type}")

            # Tag the raw value with this type's signedness (None for floats).
            return res.with_signedness(getattr(type(self), "signed", None))
        elif dtype in (int, float, bool):
            # Static conversions are only possible for compile-time scalars.
            if isinstance(self.value, ir.Value):
                raise ValueError(
                    f"unable to convert {self.value} to static type: {dtype}"
                )
            return dtype(self.value)
        else:
            raise ValueError(f"unable to convert {type(self)} to {dtype}")

    def ir_value(self, *, loc=None, ip=None) -> ir.Value:
        """Shorthand for ``self.to(ir.Value)``: get the underlying MLIR value."""
        return self.to(ir.Value, loc=loc, ip=ip)

    @property
    def zero(self) -> "Numeric": ...

    def __dsl_not__(self, *, loc=None, ip=None):
        """DSL implementation of Python's `not` operator.

        Returns True if the value is equal to zero, False otherwise.
        This matches Python's behavior where any non-zero number is considered True.

        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: The result of the logical not operation
        :rtype: Boolean
        """
        if isinstance(self.value, (int, float, bool)):
            # Static path: evaluate at compile time.
            return not self.value
        else:
            # Dynamic path: emit `self == 0` in IR.
            ty = type(self)
            zero_val = arith.constant(ty.mlir_type, ty.zero)
            return self.__eq__(ty(zero_val), loc=loc, ip=ip)

    def __dsl_and__(self, other, *, loc=None, ip=None):
        """DSL implementation of Python's `and` operator.

        Returns the second operand if the first is truthy, otherwise returns the first operand.
        A numeric value is considered truthy if it is non-zero.

        :param other: The right-hand operand
        :type other: Numeric
        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: The result of the logical and operation
        :rtype: Boolean

        Example::

            5 and 3 -> 3
            0 and 3 -> 0
            3 and 0 and ... -> 0
        """
        is_true = self.__dsl_bool__(loc=loc, ip=ip)

        def and_op(lhs, rhs):
            # Python semantics: `lhs and rhs` selects rhs when lhs is truthy.
            if isinstance(lhs, (int, float, bool)):
                if isinstance(rhs, (int, float, bool)):
                    # Both static: plain Python `and`.
                    return lhs and rhs
                else:
                    # Promote the static side to a constant, then select.
                    lhs = arith.constant(rhs.type, lhs)
                    return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip)
            else:
                if isinstance(rhs, (int, float, bool)):
                    rhs = arith.constant(lhs.type, rhs)
                    return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip)
                else:
                    return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip)

        return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __dsl_or__(self, other, *, loc=None, ip=None):
        """DSL implementation of Python's `or` operator.

        Returns the first operand if it is truthy, otherwise returns the second operand.
        A numeric value is considered truthy if it is non-zero.

        :param other: The right-hand operand
        :type other: Numeric
        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: The result of the logical or operation
        :rtype: Boolean

        Example::

            5 or 3 -> 5
            0 or 3 -> 3
            3 or 0 -> 3
        """
        is_true = self.__dsl_bool__(loc=loc, ip=ip)

        def or_op(lhs, rhs):
            # Python semantics: `lhs or rhs` selects lhs when lhs is truthy.
            if isinstance(lhs, (int, float, bool)):
                if isinstance(rhs, (int, float, bool)):
                    # Both static: plain Python `or`.
                    return lhs or rhs
                else:
                    lhs = arith.constant(rhs.type, lhs)
                    return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)
            else:
                if isinstance(rhs, (int, float, bool)):
                    rhs = arith.constant(lhs.type, rhs)
                    return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)
                else:
                    return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip)

        return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean":
        """DSL implementation of Python's __bool__ method.

        Returns a Boolean indicating whether this value is considered truthy.
        For numeric types, returns True if the value is non-zero.

        :param loc: The source location information, defaults to None
        :type loc: Optional[Location]
        :param ip: The insertion point for the operation, defaults to None
        :type ip: Optional[InsertionPoint]
        :return: True if this value is truthy (non-zero), False otherwise
        :rtype: Boolean
        """
        # `zero` here is the metaclass-level property (a Python scalar).
        zero = type(self).zero
        return self.__ne__(zero, loc=loc, ip=ip)

    def __bool__(self):
        # Only static values can be truth-tested by Python itself; dynamic
        # (MLIR) values require the DSL preprocessor to rewrite the test.
        if isinstance(self.value, (int, float, bool)):
            return bool(self.value)
        else:
            raise DSLRuntimeError(
                f"Unable to convert dynamic `{type(self).__name__}` value to bool at compile time.",
                suggestion=[
                    "Decorate the parent function with `jit` decorator and with `preprocess` enabled.",
                    "Ensure not using patterns that DSL does not support.",
                    "Otherwise, please file a bug report.",
                ],
            )

    def __index__(self):
        if isinstance(self.value, (int, float, bool)):
            # NOTE(review): __index__ must return an int per the data model;
            # a float/bool `self.value` is returned unwrapped here — confirm
            # callers only hit this with int-valued instances.
            return self.value
        else:
            raise DSLRuntimeError(
                f"'{type(self.value)}' object cannot be interpreted as an integer",
                suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator",
            )

    def __neg__(self, *, loc=None, ip=None):
        # NOTE(review): `self` is a Numeric, never a bool/int/float, so the
        # first branch appears dead; both branches negate `self.value`.
        if isinstance(self, (bool, int, float)):
            return type(self)(-self.value)  # type: ignore
        else:
            return type(self)(-self.value, loc=loc, ip=ip)  # type: ignore

    @staticmethod
    def _from_python_value(value):
        """Map a Python scalar or ArithValue to the matching Numeric subclass."""
        if value is None or isinstance(value, Numeric):
            return value
        if isinstance(value, bool):
            res_type = Boolean
        elif isinstance(value, int):
            res_type = Int32
        elif isinstance(value, float):
            res_type = Float32
        elif isinstance(value, ArithValue):
            res_type = Numeric.from_mlir_type(value.type)
        else:
            raise ValueError(
                f"unable to convert {value} in type {type(value)} to Numeric"
            )
        return res_type(value)

    def __add__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __sub__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __mul__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.floordiv, promote_bool=True)(
            self, other, loc=loc, ip=ip
        )

    def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.truediv, promote_bool=True)(
            self, other, loc=loc, ip=ip
        )

    def __mod__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip)

    def __radd__(self, other, *, loc=None, ip=None) -> "Numeric":
        # Addition is commutative; reuse __add__.
        return self.__add__(other, loc=loc, ip=ip)

    def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric":
        # Non-commutative: flip the operand order inside the wrapper.
        return _binary_op(operator.sub, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric":
        return self.__mul__(other, loc=loc, ip=ip)

    def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.floordiv, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.truediv, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.mod, promote_bool=True, flip=True)(
            self, other, loc=loc, ip=ip
        )

    def __eq__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.eq)(self, other, loc=loc, ip=ip)  # type: ignore

    def __ne__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.ne)(self, other, loc=loc, ip=ip)  # type: ignore

    def __lt__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.lt)(self, other, loc=loc, ip=ip)  # type: ignore

    def __le__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.le)(self, other, loc=loc, ip=ip)  # type: ignore

    def __gt__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.gt)(self, other, loc=loc, ip=ip)  # type: ignore

    def __ge__(self, other, *, loc=None, ip=None) -> "Boolean":
        return _binary_op(operator.ge)(self, other, loc=loc, ip=ip)  # type: ignore

    def __pow__(self, other, *, loc=None, ip=None) -> "Numeric":
        return _binary_op(operator.pow)(self, other, loc=loc, ip=ip)  # type: ignore

    def __c_pointers__(self):
        # NOTE(review): "{8, 16, 32, 64}" / "{32, 64}" in this f-string are
        # *set literals* whose repr is interpolated, not escaped braces.
        raise ValueError(
            f"only support built-in types: bool, (u)int{8, 16, 32, 64}, float{32, 64}, but got {type(self)}"
        )

    def __get_mlir_types__(self):
        return [type(self).mlir_type]

    @staticmethod
    def from_mlir_type(mlir_type):
        """Map an MLIR scalar type to the corresponding DSL Numeric subclass.

        :raises DSLRuntimeError: If ``mlir_type`` has no DSL equivalent.
        """
        type_map = {
            T.bool(): Boolean,
            T.f64(): Float64,
            T.f32(): Float32,
            T.tf32(): TFloat32,
            T.f16(): Float16,
            T.bf16(): BFloat16,
            # Signless integers map to the signed DSL types.
            T.i(128): Int128,
            T.i64(): Int64,
            T.i32(): Int32,
            T.i16(): Int16,
            T.i8(): Int8,
            T.si(128): Int128,
            T.si64(): Int64,
            T.si32(): Int32,
            T.si16(): Int16,
            T.si8(): Int8,
            T.ui(128): Uint128,
            T.ui64(): Uint64,
            T.ui32(): Uint32,
            T.ui16(): Uint16,
            T.ui8(): Uint8,
            T.f8E5M2(): Float8E5M2,
            T.f8E4M3(): Float8E4M3,
            T.f8E4M3FN(): Float8E4M3FN,
            T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ,
            T.f4E2M1FN(): Float4E2M1FN,
            T.f6E2M3FN(): Float6E2M3FN,
            T.f6E3M2FN(): Float6E3M2FN,
            T.f8E8M0FNU(): Float8E8M0FNU,
        }
        if mlir_type not in type_map:
            raise DSLRuntimeError(f"Unsupported DSL type: {mlir_type}")
        return type_map[mlir_type]
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
def as_numeric(obj: Union[bool, int, float, ir.Value, Numeric]) -> Numeric:
    """Coerce a Python scalar (or MLIR value) into the matching Numeric type.

    :param obj: Python primitive value to convert
    :type obj: Union[bool, int, float]
    :return: The converted Numeric object
    :rtype: Numeric

    Example::

    .. code-block:: python

        x = as_numeric(5)     # Converts to Int32
        y = as_numeric(3.14)  # Converts to Float32
        z = as_numeric(True)  # Converts to Boolean
    """
    # Values that are already Numeric pass through untouched; everything
    # else is mapped by Numeric._from_python_value.
    return obj if isinstance(obj, Numeric) else Numeric._from_python_value(obj)
|
| 1275 |
+
|
| 1276 |
+
|
| 1277 |
+
class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True):
|
| 1278 |
+
"""A class representing integer values with specific width and signedness.
|
| 1279 |
+
|
| 1280 |
+
This class provides functionality to create and manipulate integer values with
|
| 1281 |
+
configurable width and signedness. It supports conversion from various input types
|
| 1282 |
+
including Python scalars, MLIR Values, and other numeric types.
|
| 1283 |
+
|
| 1284 |
+
:param x: The input value to convert to this integer type
|
| 1285 |
+
:type x: Union[bool, int, float, ir.Value, Integer, Float]
|
| 1286 |
+
|
| 1287 |
+
:return: A new Integer instance with the converted value
|
| 1288 |
+
:rtype: Integer
|
| 1289 |
+
|
| 1290 |
+
:raises AssertionError: If the type's numpy_dtype is None
|
| 1291 |
+
:raises NotImplementedError: If converting between different Integer types
|
| 1292 |
+
:raises ValueError: If the input type is not supported for conversion
|
| 1293 |
+
:raises OverflowError: If converting float infinity to integer
|
| 1294 |
+
|
| 1295 |
+
Type conversion behavior:
|
| 1296 |
+
|
| 1297 |
+
* Python scalars (bool, int, float):
|
| 1298 |
+
* Converted through numpy dtype casting
|
| 1299 |
+
* NaN and infinity values are rejected
|
| 1300 |
+
* Example: Int8(256) -> -256 (overflow behavior)
|
| 1301 |
+
|
| 1302 |
+
* MLIR Value with IntegerType:
|
| 1303 |
+
* Width differences handled by signless to signed/unsigned conversion
|
| 1304 |
+
* Example: i8 -> i8/ui8 depending on target type
|
| 1305 |
+
|
| 1306 |
+
* MLIR Value with FloatType:
|
| 1307 |
+
* Uses MLIR float-to-int conversion
|
| 1308 |
+
* NaN and infinity values is undefined behavior
|
| 1309 |
+
* Example: f32 -> i32/ui32 depending on target type
|
| 1310 |
+
|
| 1311 |
+
* Integer:
|
| 1312 |
+
* Uses MLIR float-to-int conversion or numpy dtype casting
|
| 1313 |
+
* Example: Int32(Int32(5)) => 5
|
| 1314 |
+
|
| 1315 |
+
* Float:
|
| 1316 |
+
* Uses MLIR float-to-int conversion
|
| 1317 |
+
* Example: Int32(Float(5.7)) -> 5
|
| 1318 |
+
|
| 1319 |
+
Example usage:
|
| 1320 |
+
|
| 1321 |
+
.. code-block:: python
|
| 1322 |
+
|
| 1323 |
+
x = Int32(5) # From integer
|
| 1324 |
+
y = Int32(True) # From boolean
|
| 1325 |
+
z = Int32(3.7) # From float (truncates)
|
| 1326 |
+
w = Int32(x) # From same Integer type
|
| 1327 |
+
c5 = arith.constant(5, T.i32())
|
| 1328 |
+
a = Int32(c5) # Treat c5 as int32 bitwise
|
| 1329 |
+
"""
|
| 1330 |
+
|
| 1331 |
+
def __init__(self, x, *, loc=None, ip=None):
|
| 1332 |
+
ty = type(self)
|
| 1333 |
+
|
| 1334 |
+
if isinstance(x, (bool, int, float)):
|
| 1335 |
+
# Add check for NaN before numpy conversion
|
| 1336 |
+
if isinstance(x, float):
|
| 1337 |
+
if np.isnan(x):
|
| 1338 |
+
raise ValueError("Cannot convert float NaN to integer")
|
| 1339 |
+
elif np.isinf(x):
|
| 1340 |
+
raise OverflowError("Cannot convert float infinity to integer")
|
| 1341 |
+
|
| 1342 |
+
np_dtype = ty.numpy_dtype
|
| 1343 |
+
assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}"
|
| 1344 |
+
x_val = int(np.array(x).astype(np_dtype))
|
| 1345 |
+
elif type(x) == ty:
|
| 1346 |
+
x_val = x.value
|
| 1347 |
+
elif isinstance(x, ir.Value): # type: ignore
|
| 1348 |
+
x_val = x
|
| 1349 |
+
if isinstance(x.type, ir.IntegerType): # type: ignore
|
| 1350 |
+
if x.type.width != ty.width:
|
| 1351 |
+
# signless -> (u)int
|
| 1352 |
+
x_val = _arith_signless_to_int(x, ty)
|
| 1353 |
+
elif isinstance(x.type, ir.FloatType): # type: ignore
|
| 1354 |
+
# float -> (u)int
|
| 1355 |
+
x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip)
|
| 1356 |
+
elif isinstance(x, Integer):
|
| 1357 |
+
if isinstance(x.value, ir.Value):
|
| 1358 |
+
x_val = arith_helper.int_to_int(x.ir_value(), ty)
|
| 1359 |
+
else:
|
| 1360 |
+
# For non-MLIR values, use numpy casting
|
| 1361 |
+
src_val = np.array(x.value, dtype=type(x).numpy_dtype)
|
| 1362 |
+
x_val = int(src_val.astype(ty.numpy_dtype))
|
| 1363 |
+
elif isinstance(x, Float):
|
| 1364 |
+
# float -> int is handled by Integer.__init__ recursively
|
| 1365 |
+
Integer.__init__(self, x.value)
|
| 1366 |
+
return
|
| 1367 |
+
else:
|
| 1368 |
+
raise DSLRuntimeError(f"{x} to integer conversion is not supported")
|
| 1369 |
+
|
| 1370 |
+
super().__init__(x_val)
|
| 1371 |
+
|
| 1372 |
+
def __invert__(self, *, loc=None, ip=None):
|
| 1373 |
+
res_type = type(self)
|
| 1374 |
+
return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip))
|
| 1375 |
+
|
| 1376 |
+
def __lshift__(self, *args, **kwargs):
    """Left shift (``self << other``) via the shared binary-op helper."""
    shift_op = _binary_op(operator.lshift)
    return shift_op(self, *args, **kwargs)
|
| 1378 |
+
|
| 1379 |
+
def __rlshift__(self, other, *, loc=None, ip=None):
    """Reflected left shift: evaluates ``other << self``."""
    other_ = as_numeric(other)
    if isinstance(other_, Integer):
        return other_.__lshift__(self, loc=loc, ip=ip)
    # Shifting a non-integer numeric is meaningless.
    raise ValueError(f"Cannot left shift {other_} with {self}")
|
| 1384 |
+
|
| 1385 |
+
def __rshift__(self, *args, **kwargs):
    """Right shift (``self >> other``) via the shared binary-op helper."""
    shift_op = _binary_op(operator.rshift)
    return shift_op(self, *args, **kwargs)
|
| 1387 |
+
|
| 1388 |
+
def __rrshift__(self, other, *, loc=None, ip=None):
    """Reflected right shift: evaluates ``other >> self``."""
    other_ = as_numeric(other)
    if isinstance(other_, Integer):
        return other_.__rshift__(self, loc=loc, ip=ip)
    # Shifting a non-integer numeric is meaningless.
    raise ValueError(f"Cannot right shift {other_} with {self}")
|
| 1393 |
+
|
| 1394 |
+
def __and__(self, *args, **kwargs):
    """Bitwise AND via the shared binary-op helper."""
    and_op = _binary_op(operator.and_)
    return and_op(self, *args, **kwargs)
|
| 1396 |
+
|
| 1397 |
+
def __rand__(self, other, *, loc=None, ip=None):
    """Reflected bitwise AND; AND is commutative, so delegate to __and__."""
    return self.__and__(other, loc=loc, ip=ip)
|
| 1399 |
+
|
| 1400 |
+
def __or__(self, *args, **kwargs):
    """Bitwise OR via the shared binary-op helper."""
    or_op = _binary_op(operator.or_)
    return or_op(self, *args, **kwargs)
|
| 1402 |
+
|
| 1403 |
+
def __ror__(self, other, *, loc=None, ip=None):
    """Reflected bitwise OR; OR is commutative, so delegate to __or__."""
    return self.__or__(other, loc=loc, ip=ip)
|
| 1405 |
+
|
| 1406 |
+
def __xor__(self, *args, **kwargs):
    """Bitwise XOR via the shared binary-op helper."""
    xor_op = _binary_op(operator.xor)
    return xor_op(self, *args, **kwargs)
|
| 1408 |
+
|
| 1409 |
+
def __rxor__(self, other, *, loc=None, ip=None):
    """Reflected bitwise XOR; XOR is commutative, so delegate to __xor__."""
    return self.__xor__(other, loc=loc, ip=ip)
|
| 1411 |
+
|
| 1412 |
+
|
| 1413 |
+
class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True):
    """A class representing floating-point values.

    :param x: The input value to convert to this float type.
    :type x: Union[bool, int, float, ir.Value, Integer, Float]

    Type conversion behavior:

    1. Python scalars (bool, int, float):
        - Converted with ``float()``
        - Example: Float32(1.7) -> 1.7

    2. MLIR Value with FloatType:
        - If the type differs: converts between float types
        - Example: f16 -> f32

    3. MLIR Value with IntegerType:
        - Not supported, raises DSLRuntimeError

    4. Integer:
        - Converts using MLIR int-to-float operation
        - Example: Float32(Int32(5)) -> 5.0

    5. Float:
        - Direct conversion between float types
        - Example: Float32(Float32(1.5)) -> 1.5

    .. note::
        The following narrow precision types are only supported in device code:

        8-bit float types:
            - Float8E5M2
            - Float8E4M3
            - Float8E4M3FN
            - Float8E8M0FNU
            - Float8E4M3B11FNUZ

        6-bit float types:
            - Float6E3M2FN
            - Float6E2M3FN

        4-bit float types:
            - Float4E2M1FN

    :raises DSLRuntimeError: If conversion from the input type is not supported
    """

    def __init__(self, x, *, loc=None, ip=None):
        ty = type(self)

        if isinstance(x, (bool, int, float)):  # type: ignore
            # Python scalars: a plain float() is enough; rounding to the target
            # precision happens when the value is materialized in IR.
            super().__init__(float(x))
        elif isinstance(x, ir.Value):  # type: ignore
            if isinstance(x.type, ir.IntegerType):  # type: ignore
                raise DSLRuntimeError("signless to float conversion is not implemented")
            elif isinstance(x.type, ir.FloatType):  # type: ignore
                if x.type != ty.mlir_type:
                    # Insert an MLIR float->float conversion to the target type.
                    x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip)
                super().__init__(x)
            else:
                # Fix: this branch previously fell through silently, leaving the
                # object uninitialized and failing later with a confusing error.
                raise DSLRuntimeError(f"{x} to Float conversion is not supported")
        elif isinstance(x, Integer):
            if isinstance(x.value, ir.Value):  # type: ignore
                # Dynamic integer: emit an int-to-float op, honoring signedness.
                x = arith_helper.itofp(
                    x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip
                )
            else:
                x = float(x.value)
            super().__init__(x)
        elif isinstance(x, Float):
            # Re-dispatch on the wrapped value (either an ir.Value or a float).
            Float.__init__(self, x.value)
        else:
            raise DSLRuntimeError(f"{x} to Float conversion is not supported")
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.bool):
    """Boolean type representation in the DSL.

    This class represents boolean values in the DSL, with a width of 1 bit.
    It supports conversion from various types to boolean values.

    :param a: Value to convert to Boolean
    :type a: Union[bool, int, float, "Value", Numeric]
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location], optional
    :param ip: Insertion point for MLIR operations, defaults to None
    :type ip: Optional[InsertionPoint], optional
    :raises DSLRuntimeError: If the input value cannot be converted to Boolean

    Conversion rules:

    1. Python bool/int/float:
        - Converted using Python's bool() function
        - Example: Boolean(1) -> True, Boolean(0) -> False

    2. Numeric:
        - Uses the Numeric.value to construct Boolean recursively

    3. MLIR Value with IntegerType:
        - If width is 1: Direct assignment
        - Otherwise: Compares with 0 using arith.cmpi

    4. MLIR Value with FloatType:
        - Compares with 0.0 using arith.cmpf
        - Uses unordered comparison to handle NaN values
    """

    def __init__(
        self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None
    ):
        value = None
        if isinstance(a, (bool, int, float)):
            value = bool(a)
        elif isinstance(a, Numeric):
            # Unwrap the DSL wrapper and re-dispatch on the raw value.
            Boolean.__init__(self, a.value, loc=loc, ip=ip)
            return
        elif isinstance(a, ArithValue):
            if a.type == T.bool():
                # Already an i1 value: use it directly.
                value = a
            else:
                # Any other arith value: lower to an `a != 0` comparison.
                value = a != arith_helper.const(0, a.type, loc=loc, ip=ip)
        if value is None:
            raise DSLRuntimeError(f"Cannot convert {a} to Boolean")
        super().__init__(value, loc=loc, ip=ip)
        # Cache for the widened int8 form, filled lazily by ir_value_int8().
        self._value_int8 = None

    def ir_value_int8(self, *, loc=None, ip=None):
        """
        Returns int8 ir value of Boolean.
        When we need to store Boolean tensor element, use ir_value_int8().

        :param loc: Source location information, defaults to None
        :type loc: Optional[Location], optional
        :param ip: Insertion point for MLIR operations, defaults to None
        :type ip: Optional[InsertionPoint], optional
        :return: The int8 value of this Boolean
        :rtype: ir.Value
        """
        # Memoized: the conversion emits IR, so only do it once per instance.
        if self._value_int8 is not None:
            return self._value_int8
        self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value()
        return self._value_int8

    def __neg__(self, *, loc=None, ip=None):
        """Negation operator is not supported for boolean type.

        :param loc: Source location information, defaults to None
        :type loc: Optional[Location], optional
        :param ip: Insertion point for MLIR operations, defaults to None
        :type ip: Optional[InsertionPoint], optional
        :raises TypeError: Always raises this error as negation is not supported
        """
        raise TypeError("Negation, the operator `-` is not supported for boolean type")
|
| 1571 |
+
|
| 1572 |
+
|
| 1573 |
+
# Concrete fixed-width integer types. Signed and unsigned variants of the same
# width share the same signless MLIR storage type (i8/i16/...); the `signed`
# flag steers how conversions interpret the bits (see Integer/Float __init__).
class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8): ...


class Int16(Integer, metaclass=IntegerMeta, width=16, signed=True, mlir_type=T.i16): ...


class Int32(Integer, metaclass=IntegerMeta, width=32, signed=True, mlir_type=T.i32): ...


class Int64(Integer, metaclass=IntegerMeta, width=64, signed=True, mlir_type=T.i64): ...


class Int128(
    Integer, metaclass=IntegerMeta, width=128, signed=True, mlir_type=lambda: T.i(128)
): ...


class Uint8(Integer, metaclass=IntegerMeta, width=8, signed=False, mlir_type=T.i8): ...


class Uint16(
    Integer, metaclass=IntegerMeta, width=16, signed=False, mlir_type=T.i16
): ...


class Uint32(
    Integer, metaclass=IntegerMeta, width=32, signed=False, mlir_type=T.i32
): ...


class Uint64(
    Integer, metaclass=IntegerMeta, width=64, signed=False, mlir_type=T.i64
): ...


class Uint128(
    Integer, metaclass=IntegerMeta, width=128, signed=False, mlir_type=lambda: T.i(128)
): ...
|
| 1611 |
+
|
| 1612 |
+
|
| 1613 |
+
class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64):
    """64-bit IEEE-754 floating point."""

    def __c_pointers__(self):
        """Return the ctypes pointers for passing this value as a C double."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        boxed = ctypes.c_double(self.value)
        return [ctypes.cast(ctypes.pointer(boxed), ctypes.c_void_p)]
|
| 1621 |
+
|
| 1622 |
+
|
| 1623 |
+
class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32):
    """32-bit IEEE-754 floating point."""

    @staticmethod
    def _get_c_pointer(value: float):
        """Box *value* as a C float and return an opaque pointer to it."""
        boxed = ctypes.c_float(value)
        return ctypes.cast(ctypes.pointer(boxed), ctypes.c_void_p)

    def __c_pointers__(self):
        """Return the ctypes pointers for passing this value as a C float."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        return [Float32._get_c_pointer(self.value)]
|
| 1633 |
+
|
| 1634 |
+
|
| 1635 |
+
class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32):
    """TF32 type; marshalled on the host as a plain 32-bit C float."""

    def __c_pointers__(self):
        """Reuse Float32's host marshalling after validating the value."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        return [Float32._get_c_pointer(self.value)]
|
| 1640 |
+
|
| 1641 |
+
|
| 1642 |
+
class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16):
    @staticmethod
    def _get_c_pointer(value: float):
        """Return a c_void_p to the raw IEEE-754 half-precision bits of *value*.

        :param value: Python float to round to half precision
        :return: opaque pointer to a 16-bit buffer holding the bit pattern
        """
        # Round to half precision with numpy, then reinterpret the raw bits.
        f16_val = np.float16(value)
        bits = int(f16_val.view(np.uint16))
        # Fix: hold the bits in an UNSIGNED 16-bit object. The previous
        # ctypes.c_short relied on ctypes' silent wrap-around for bit
        # patterns >= 0x8000 (negative halves); c_uint16 stores them directly.
        c_val = ctypes.c_uint16(bits)
        # ctypes.cast keeps `c_val` alive via the returned pointer's _objects,
        # so the buffer does not dangle after this function returns.
        return ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p)

    def __c_pointers__(self):
        """Return the ctypes pointers for passing this value as a C half."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        return [Float16._get_c_pointer(self.value)]
|
| 1658 |
+
|
| 1659 |
+
|
| 1660 |
+
class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16):
    """bfloat16 floating point; host marshalling defers to the Float base."""

    def __c_pointers__(self):
        """Validate the value, then delegate to the generic Float marshalling."""
        if not isinstance(self.value, float):
            raise ValueError("only float is supported")
        return Float.__c_pointers__(self)
|
| 1666 |
+
|
| 1667 |
+
|
| 1668 |
+
# Narrow-precision (8/6/4-bit) float formats; per the Float docstring these are
# only usable in device code.
class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2): ...


class Float8E4M3FN(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3FN): ...


class Float8E4M3B11FNUZ(
    Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3B11FNUZ
): ...


# Added missing float types
class Float8E4M3(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3): ...


class Float8E8M0FNU(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E8M0FNU): ...


class Float4E2M1FN(Float, metaclass=FloatMeta, width=4, mlir_type=T.f4E2M1FN): ...


class Float6E3M2FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E3M2FN): ...


class Float6E2M3FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E2M3FN): ...
|
| 1694 |
+
|
| 1695 |
+
|
| 1696 |
+
# Narrow-precision formats that cannot be used as a conversion destination.
_unsupported_dst_float_types = [
    Float8E4M3,
    Float8E4M3B11FNUZ,
    Float4E2M1FN,
    Float6E3M2FN,
    Float6E2M3FN,
]
|
| 1703 |
+
|
| 1704 |
+
|
| 1705 |
+
# Registry of every concrete DSL numeric type.
ALL_DTYPES = {
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    Uint8,
    Uint16,
    Uint32,
    Uint64,
    Uint128,
    BFloat16,
    Float16,
    Float32,
    TFloat32,
    Float64,
    Float8E5M2,
    Float8E4M3,
    Float8E4M3FN,
    Float8E8M0FNU,
    Float8E4M3B11FNUZ,
    Float4E2M1FN,
    Float6E2M3FN,
    Float6E3M2FN,
}
# Map from class name (e.g. "Float32") to the class, used by dtype() lookups.
__STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES}
|
| 1731 |
+
|
| 1732 |
+
|
| 1733 |
+
def dtype(dtype_) -> Type[Numeric]:
    """Look up a DSL numeric type by its class name string.

    :param dtype_: name of a DSL numeric type, e.g. ``"Float32"``
    :raises TypeError: if ``dtype_`` is not a recognized type name
    :return: the matching Numeric subclass from ALL_DTYPES
    """
    t = None
    # const_expr marks the test as compile-time for the DSL preprocessor, so
    # the branch is resolved during tracing rather than lowered to IR.
    if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__):
        t = __STR_TO_DTYPE__[dtype_]
    else:
        raise TypeError(f"can't interpret {dtype_} as data type")

    return t
|
| 1741 |
+
|
| 1742 |
+
|
| 1743 |
+
##############################################################
|
| 1744 |
+
# Tensor
|
| 1745 |
+
##############################################################
|
| 1746 |
+
|
| 1747 |
+
|
| 1748 |
+
class TensorMeta(DslType):
    """Metaclass attaching an element type and a static shape to Tensor types.

    Examples:
        >>> Tensor[Int32, (3,)]
        >>> Tensor[Float32, (3, 4)]
        >>> T = TypeVar("T")
        >>> Tensor[T, (3, 4, 5)]
    """
    # Fix: this text used to sit AFTER the attributes below as a bare string
    # expression, so it was dead code rather than the class docstring.

    # Defaults, overwritten per-class by __new__.
    _element_type = Any
    _shape = Any

    def __new__(cls, name, bases, attrs, element_type=Any, shape=Any):
        """Create the class, then record its element type and shape."""
        new_cls = super().__new__(cls, name, bases, attrs)
        new_cls._element_type = element_type
        new_cls._shape = shape
        return new_cls
|
| 1765 |
+
|
| 1766 |
+
|
| 1767 |
+
# Generic type
|
| 1768 |
+
# Generic type parameter shared by the annotation helper classes below.
TY = TypeVar("TY")


class Constexpr(Generic[TY]):
    """Value is passed and computed by python interpreter"""

    pass
|
| 1775 |
+
|
| 1776 |
+
|
| 1777 |
+
class align:
    """Power-of-two alignment annotation used as a Pointer type parameter."""

    def __init__(self, value: int):
        """Validate and store the alignment; it must be a positive power of two."""
        # A positive power of two has exactly one bit set.
        is_pow2 = value > 0 and value & (value - 1) == 0
        if not is_pow2:
            raise DSLRuntimeError("expects align be power of 2 as positive value")
        self._value = value

    def __str__(self):
        return f"align({self._value})"
|
| 1785 |
+
|
| 1786 |
+
|
| 1787 |
+
class PointerMeta(DslType):
    """Metaclass parameterizing Pointer types by pointee type and alignment.

    Each specialization lowers to an MLIR unranked memref of the pointee's
    element type (the "0" attribute is presumably the memory space -- TODO
    confirm against the MLIR memref dialect usage elsewhere in the project).
    """

    def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)):
        new_cls = super().__new__(
            cls,
            name,
            bases,
            attrs,
            # Lazy so the MLIR context need not exist at class-creation time.
            mlir_type=lambda: getattr(ir, "UnrankedMemRefType").get(
                value_type.mlir_type, getattr(ir, "Attribute").parse("0")
            ),
        )
        new_cls._value_type = value_type
        new_cls._align = align_
        return new_cls

    def __eq__(cls, other):
        # NOTE(review): returns False (not NotImplemented) for non-PointerMeta
        # operands, which blocks reflected comparison -- confirm intentional.
        if not isinstance(other, PointerMeta):
            return False
        return (
            cls._value_type == other._value_type
            and cls._align._value == other._align._value
        )  # Compare alignment values

    def __hash__(cls):
        return hash((cls._value_type, cls._align._value))  # Hash alignment value

    def __getitem__(cls, params) -> Type["Pointer"]:
        # Expects exactly a (value_type, align) pair, e.g. Pointer[Int32, align(8)].
        value_type, align_ = params

        if not isinstance(align_, align):
            raise DSLRuntimeError(f"expects align but got {align_}")

        # Create new class with proper name and parameters
        new_cls = type(
            f"Pointer[{value_type.__name__}, {align_}]",
            (Pointer,),
            {},
            value_type=value_type,
            align_=align_,  # Pass alignment to __new__
        )
        return new_cls

    def __str__(cls):
        return f"ptr<{cls._value_type}, {cls._align}>"
|
| 1831 |
+
|
| 1832 |
+
|
| 1833 |
+
class Pointer(metaclass=PointerMeta):
    """
    A pointer to a memory location.

    Examples:

        def foo(a : Pointer[Int32, align=8]):
            ...

    """

    def __init__(self, value):
        # The wrapped pointer value (its concrete kind is defined by callers).
        self.value = value

    def __str__(self):
        return f"{self.value} : {type(self)}"
|
| 1849 |
+
|
| 1850 |
+
|
| 1851 |
+
class IRConst(Generic[TY]):
    """Value is passed as MLIR constant value for (arith.constant)."""

    def __init__(self, ty: TY):
        # The type the constant will be materialized with.
        self.ty = ty
|
| 1856 |
+
|
| 1857 |
+
|
| 1858 |
+
class IRValue(Generic[TY]):
    """Value is passed as MLIR dynamic value."""

    def __init__(self, ty: TY):
        # The type of the dynamic value.
        self.ty = ty
|
| 1863 |
+
|
| 1864 |
+
|
| 1865 |
+
class IRVariadic:
    """Helper for passing a variadic number of arguments to a function."""

    def __init__(self, operands):
        """Wrap a sequence of variadic operands; they must be dynamic values."""
        self.operands = operands

    def block_arg_types(self):
        """Return the list of types, one per wrapped operand."""
        return [operand.type for operand in self.operands]

    def set_func_args(self, block_args):
        """Hook invoked after entering a function.

        `block_args` are the block arguments corresponding to the passed
        operands. Subclasses may override this to expose convenience getters;
        the base implementation intentionally does nothing.
        """
        pass

    def __len__(self):
        """Return how many variadic operands are wrapped."""
        return len(self.operands)
|
| 1896 |
+
|
| 1897 |
+
|
| 1898 |
+
class FuncArgWithAttr(IRValue):
    """
    This derived class is specifically for func op arg with attr
    """

    def __init__(self, ty, attr_name, attr_ty, attr_value=None):
        super().__init__(ty)
        # The name is mandatory, and at least one of type/value must be given.
        # `assert` is kept to preserve the original failure mode (AssertionError).
        has_payload = attr_ty is not None or attr_value is not None
        assert (
            attr_name is not None and has_payload
        ), "Invalid attr_name and/or attr_ty and/or attr_value for FuncArgWithAttr"
        self.attr_name = attr_name
        self.attr_ty = attr_ty
        self.attr_value = attr_value
|
| 1911 |
+
|
| 1912 |
+
|
| 1913 |
+
|
| 1914 |
+
def implicitDowncastNumericType(value):
    """Unwrap a Numeric into its underlying IR value; pass anything else through."""
    return value.ir_value() if isinstance(value, Numeric) else value
|
| 1918 |
+
|
| 1919 |
+
|
| 1920 |
+
# Public API of this module. Fix: "Float8E4M3" was listed twice.
__all__ = [
    "DslType",
    "Numeric",
    "NumericMeta",
    "IntegerMeta",
    "FloatMeta",
    "Boolean",
    "Integer",
    "Int16",
    "Int32",
    "Int64",
    "Int128",
    "Int8",
    "Uint8",
    "Uint16",
    "Uint32",
    "Uint64",
    "Uint128",
    "Float",
    "Float16",
    "BFloat16",
    "TFloat32",
    "Float32",
    "Float64",
    "Float8E5M2",
    "Float8E4M3",
    "Float8E4M3FN",
    "Float8E4M3B11FNUZ",
    "Float8E8M0FNU",
    "Float4E2M1FN",
    "Float6E2M3FN",
    "Float6E3M2FN",
    "as_numeric",
    "align",
    "Pointer",
    "dtype",
    "Constexpr",
    "IRConst",
    "IRValue",
    "IRVariadic",
    "implicitDowncastNumericType",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from . import stacktrace
from . import logger
from . import timer
# Submodules re-exported by `from ...utils import *`.
__all__ = [
    "logger",
    "timer",
    "stacktrace",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides logging helper functions
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
logger = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def log():
    """Return the module-level logger most recently configured by setup_log()."""
    return logger
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def setup_log(
    name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1
):
    """Set up and configure a logger with console and/or file handlers.

    :param name: Name of the logger to create
    :type name: str
    :param log_to_console: Whether to enable logging to console, defaults to False
    :type log_to_console: bool, optional
    :param log_to_file: Whether to enable logging to file, defaults to False
    :type log_to_file: bool, optional
    :param log_file_path: Path to the log file, required if log_to_file is True
    :type log_file_path: str, optional
    :param log_level: Logging level to set, defaults to 1
    :type log_level: int, optional
    :raises ValueError: If log_to_file is True but log_file_path is not provided
    :return: Configured logger instance
    :rtype: logging.Logger
    """
    # Create (or fetch) the named logger and publish it as the module logger.
    global logger
    logger = logging.getLogger(name)
    if log_to_console or log_to_file:
        logger.setLevel(log_level)
    else:
        # No sink requested: raise the threshold above CRITICAL so that
        # logging is effectively OFF.
        logger.setLevel(logging.CRITICAL + 1)

    # Clear existing handlers to prevent duplicate logs
    if logger.hasHandlers():
        logger.handlers.clear()

    # Define formatter. Fix: this was an f-string with no placeholders; the
    # %(...)s tokens are logging's own interpolation syntax, not Python's.
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s"
    )

    # Add console handler if enabled
    if log_to_console:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(log_level)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    # Add file handler if enabled
    if log_to_file:
        if not log_file_path:
            # Fix: the message previously referenced a non-existent
            # `enable_file` parameter.
            raise ValueError("log_file_path must be provided when log_to_file is True")
        file_handler = logging.FileHandler(log_file_path)
        file_handler.setLevel(log_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


# Default module logger: named "generic", all sinks disabled.
logger = setup_log("generic")
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides stacktrace helper functions
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def walk_to_top_module(start_path):
    """
    Walk up from the start_path to find the top-level Python module.

    :param start_path: The path to start from.
    :return: The path of the top-level module.
    """
    # NOTE(review): the loop ascends WHILE __init__.py is present, so the value
    # returned is the first directory WITHOUT __init__.py, i.e. the directory
    # that CONTAINS the top-level package -- confirm callers expect this.
    current_path = start_path

    while True:
        # Check if we are at the root directory
        if os.path.dirname(current_path) == current_path:
            break

        # Check for __init__.py
        init_file_path = os.path.join(current_path, "__init__.py")
        if os.path.isfile(init_file_path):
            # If __init__.py exists, move up one level
            current_path = os.path.dirname(current_path)
        else:
            # If no __init__.py, we are not in a module; stop
            break

    # If we reached the root without finding a module, return None
    if os.path.dirname(current_path) == current_path and not os.path.isfile(
        os.path.join(current_path, "__init__.py")
    ):
        return None

    # Return the path of the top-level module
    return current_path
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _filter_internal_frames(traceback, internal_path):
    """
    Filter out stack frames from the traceback that belong to the specified module path.

    This function removes stack frames from the traceback whose file paths start with
    the given prefix_path, effectively hiding internal implementation details from
    the error traceback shown to users.
    """
    # `iter_prev` tracks the last KEPT frame; removing a frame means pointing
    # iter_prev.tb_next past it (or advancing the traceback head).
    iter_prev = None
    iter_tb = traceback
    while iter_tb is not None:
        if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith(
            internal_path
        ):
            if iter_tb.tb_next:
                if iter_prev:
                    # Splice the internal frame out of the linked list.
                    iter_prev.tb_next = iter_tb.tb_next
                else:
                    # Internal frame at the head: advance the head instead.
                    traceback = iter_tb.tb_next
            # NOTE(review): an internal frame that is the LAST frame is kept --
            # presumably so the raising frame is never dropped; confirm.
        else:
            # User frame: keep it and remember it as the splice anchor.
            iter_prev = iter_tb
        iter_tb = iter_tb.tb_next
    return traceback
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
_generated_function_names = re.compile(
|
| 79 |
+
r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _filter_duplicated_frames(traceback):
    """
    Filter out duplicated stack frames from the traceback.
    The function filters out consecutive frames that are in the same file and have the same line number.
    In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame.
    """
    # Two-pointer walk over the tb_next chain; at each step decide whether the
    # current frame or its successor is the redundant duplicate to splice out.
    iter_prev = None
    iter_tb = traceback
    while iter_tb is not None:
        skip_current = False
        skip_next = False
        if iter_tb.tb_next:
            current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename)
            next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename)
            # if in the same file, check if the line number is the same
            if current_filename == next_filename:
                current_lineno = iter_tb.tb_lineno
                next_lineno = iter_tb.tb_next.tb_lineno
                if current_lineno == next_lineno:
                    # Same file and line number, check name, if current is generated, skip current, otherwise skip next
                    name = iter_tb.tb_frame.f_code.co_name
                    is_generated = bool(_generated_function_names.match(name))
                    if is_generated:
                        # Skip current
                        skip_current = True
                    else:
                        # Skip next if it's generated, otherwise keep both
                        next_name = iter_tb.tb_next.tb_frame.f_code.co_name
                        skip_next = bool(_generated_function_names.match(next_name))
        if skip_current:
            # Unlink the current frame; iter_prev deliberately stays put so a
            # run of generated duplicates is spliced out one by one.
            if iter_prev:
                iter_prev.tb_next = iter_tb.tb_next
            else:
                # Current frame is the head; advance the returned head.
                traceback = iter_tb.tb_next
        elif skip_next:
            # if next is last frame, don't skip
            if iter_tb.tb_next.tb_next:
                iter_tb.tb_next = iter_tb.tb_next.tb_next
                iter_prev = iter_tb
        else:
            # Nothing removed: current frame becomes the predecessor.
            iter_prev = iter_tb
        iter_tb = iter_tb.tb_next

    return traceback
|
| 129 |
+
def filter_stackframe(traceback, prefix_path):
    """
    Filter out stack frames from the traceback that belong to the specified module path.

    This function removes stack frames from the traceback whose file paths start with
    the given prefix_path, effectively hiding internal implementation details from
    the error traceback shown to users.

    :param traceback: The traceback object to filter.
    :param prefix_path: The path prefix to filter out from the traceback.
    :return: The filtered traceback with internal frames removed.
    """
    # First drop frames originating under the internal module path, then
    # collapse consecutive frames that point at the same source line.
    filtered = _filter_internal_frames(traceback, prefix_path)
    return _filter_duplicated_frames(filtered)
|
| 147 |
+
|
| 148 |
+
def filter_exception(value, module_dir):
    """
    Filter out internal implementation details from exception traceback.

    This function recursively processes an exception and its cause chain,
    removing stack frames that belong to the specified module directory.
    This helps to present cleaner error messages to users by hiding
    implementation details.

    :param value: The exception object to filter.
    :param module_dir: The module directory path to filter out from tracebacks.
    :return: The filtered exception with internal frames removed.
    """
    if hasattr(value, "__cause__") and value.__cause__:
        filter_exception(value.__cause__, module_dir)

    if hasattr(value, "__traceback__"):
        # filter_stackframe returns a possibly *new* head frame (when the
        # original first frame is filtered out). The result must be assigned
        # back, otherwise filtering of the head frame is silently lost.
        value.__traceback__ = filter_stackframe(value.__traceback__, module_dir)

    # Return the exception as documented so callers can chain/re-raise it.
    return value
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
This module provides a timing helper functions
|
| 14 |
+
"""
|
| 15 |
+
from functools import wraps
|
| 16 |
+
|
| 17 |
+
from .logger import log
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# TODO: revisit this part when mlir timing manager is ready for pybind.
|
| 21 |
+
def timer(*dargs, **kwargs):
    """Decorator that logs a function's execution time in microseconds.

    Usable both bare (``@timer``) and with arguments (``@timer(enable=False)``).

    :param enable: Keyword-only flag; when False the wrapped callable is
                   invoked directly with no timing overhead (default True).
    :return: The decorated callable, or the decorator itself when called
             with arguments.
    """
    enable = kwargs.get("enable", True)

    def decorator(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            if not enable:
                return func(*args, **kwargs)
            # perf_counter is monotonic and higher-resolution than time.time,
            # so elapsed measurements cannot go negative on clock adjustments.
            # The import stays local so only the timed path pays for it.
            from time import perf_counter

            start = perf_counter()
            result = func(*args, **kwargs)
            end = perf_counter()

            # Convert time from seconds to us
            spend_us = (end - start) * 1e6

            # Determine the function type and format the log message
            if hasattr(func, "__name__"):
                func_name = func.__name__
                log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs"
            elif "CFunctionType" in str(type(func)):
                log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs"
            else:
                log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs"

            log().info(log_message)

            return result

        return func_wrapper

    # Bare usage: @timer with no parentheses passes the callable directly.
    if len(dargs) == 1 and callable(dargs[0]):
        return decorator(dargs[0])
    else:
        return decorator
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .cutlass_dsl import (
    Constexpr,
    as_numeric,
    min,
    max,
    and_,
    or_,
    all_,
    any_,
    not_,
    select_,
    # Control-flow without AST pre-processor
    if_generate,
    for_generate,
    LoopUnroll,
    while_generate,
    yield_out,
    # Control-flow with AST pre-processor
    range_constexpr,
    range_dynamic,
    const_expr,
    dynamic_expr,
    # Data types
    dtype,  # Provides conversions to types inheriting from NumericType
    DSLRuntimeError,
    JitArgAdapterRegistry,
    # Construction utilities for user-defined classes
    extract_mlir_values,
    new_from_mlir_values,
)

from .cute.typing import *

# Utilities not belonging to CuTe
from . import utils as utils

# Used as internal symbol
from . import cutlass_dsl as _dsl

# Aliases
LaunchConfig = _dsl.BaseDSL.LaunchConfig
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
gpu = _dsl.cutlass_gpu
cuda = _dsl.cuda_helpers

CACHE_FILE = "compiled_cache.db"
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
# Use the auto-generated enum AddressSpace
|
| 13 |
+
from cutlass._mlir.dialects.cute import AddressSpace
|
| 14 |
+
|
| 15 |
+
# Explicitly import types that might be directly used by other modules.
|
| 16 |
+
# This is a fix for using Sphinx to generate documentation
|
| 17 |
+
# Because Sphinx processes each module in isolation, it won't be able to rely
|
| 18 |
+
# on re-exported symbols via wildcard imports (from .typing import *) in the
|
| 19 |
+
# same way that Python does at runtime.
|
| 20 |
+
from .typing import (
|
| 21 |
+
Shape,
|
| 22 |
+
Stride,
|
| 23 |
+
IntTuple,
|
| 24 |
+
Coord,
|
| 25 |
+
Tile,
|
| 26 |
+
XTuple,
|
| 27 |
+
Tiler,
|
| 28 |
+
Layout,
|
| 29 |
+
Pointer,
|
| 30 |
+
Tensor,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# Import everything else
|
| 34 |
+
from .typing import *
|
| 35 |
+
|
| 36 |
+
from .core import (
|
| 37 |
+
assume,
|
| 38 |
+
is_integer,
|
| 39 |
+
is_int_tuple,
|
| 40 |
+
is_static,
|
| 41 |
+
size,
|
| 42 |
+
has_underscore,
|
| 43 |
+
slice_,
|
| 44 |
+
make_ptr,
|
| 45 |
+
make_layout,
|
| 46 |
+
recast_layout,
|
| 47 |
+
make_fragment_like,
|
| 48 |
+
depth,
|
| 49 |
+
rank,
|
| 50 |
+
flatten_to_tuple,
|
| 51 |
+
flatten,
|
| 52 |
+
unflatten,
|
| 53 |
+
product,
|
| 54 |
+
product_like,
|
| 55 |
+
shape,
|
| 56 |
+
size_in_bytes,
|
| 57 |
+
make_identity_layout,
|
| 58 |
+
make_ordered_layout,
|
| 59 |
+
make_composed_layout,
|
| 60 |
+
make_layout_tv,
|
| 61 |
+
make_swizzle,
|
| 62 |
+
recast_ptr,
|
| 63 |
+
make_tensor,
|
| 64 |
+
make_identity_tensor,
|
| 65 |
+
make_fragment,
|
| 66 |
+
recast_tensor,
|
| 67 |
+
get,
|
| 68 |
+
select,
|
| 69 |
+
front,
|
| 70 |
+
is_major,
|
| 71 |
+
leading_dim,
|
| 72 |
+
find,
|
| 73 |
+
find_if,
|
| 74 |
+
coalesce,
|
| 75 |
+
group_modes,
|
| 76 |
+
cosize,
|
| 77 |
+
dice,
|
| 78 |
+
product_each,
|
| 79 |
+
prepend,
|
| 80 |
+
append,
|
| 81 |
+
prepend_ones,
|
| 82 |
+
append_ones,
|
| 83 |
+
ceil_div,
|
| 84 |
+
slice_and_offset,
|
| 85 |
+
crd2idx,
|
| 86 |
+
domain_offset,
|
| 87 |
+
elem_less,
|
| 88 |
+
transform_leaf,
|
| 89 |
+
filter_zeros,
|
| 90 |
+
filter,
|
| 91 |
+
tile_to_shape,
|
| 92 |
+
shape_div,
|
| 93 |
+
composition,
|
| 94 |
+
complement,
|
| 95 |
+
right_inverse,
|
| 96 |
+
left_inverse,
|
| 97 |
+
max_common_layout,
|
| 98 |
+
max_common_vector,
|
| 99 |
+
logical_product,
|
| 100 |
+
zipped_product,
|
| 101 |
+
tiled_product,
|
| 102 |
+
flat_product,
|
| 103 |
+
raked_product,
|
| 104 |
+
blocked_product,
|
| 105 |
+
flat_divide,
|
| 106 |
+
logical_divide,
|
| 107 |
+
zipped_divide,
|
| 108 |
+
tiled_divide,
|
| 109 |
+
local_partition,
|
| 110 |
+
local_tile,
|
| 111 |
+
printf,
|
| 112 |
+
print_tensor,
|
| 113 |
+
# tiled mma/tiled copy
|
| 114 |
+
make_mma_atom,
|
| 115 |
+
make_tiled_mma,
|
| 116 |
+
make_copy_atom,
|
| 117 |
+
make_tiled_copy_tv,
|
| 118 |
+
make_tiled_copy,
|
| 119 |
+
make_tiled_copy_S,
|
| 120 |
+
make_tiled_copy_D,
|
| 121 |
+
make_tiled_copy_A,
|
| 122 |
+
make_tiled_copy_B,
|
| 123 |
+
make_tiled_copy_C,
|
| 124 |
+
make_tiled_copy_C_atom,
|
| 125 |
+
basic_copy,
|
| 126 |
+
basic_copy_if,
|
| 127 |
+
autovec_copy,
|
| 128 |
+
copy,
|
| 129 |
+
copy_atom_call,
|
| 130 |
+
gemm,
|
| 131 |
+
# Wrapper classes
|
| 132 |
+
ComposedLayout,
|
| 133 |
+
Swizzle,
|
| 134 |
+
E,
|
| 135 |
+
Atom,
|
| 136 |
+
MmaAtom,
|
| 137 |
+
CopyAtom,
|
| 138 |
+
TiledCopy,
|
| 139 |
+
TiledMma,
|
| 140 |
+
TensorSSA,
|
| 141 |
+
ReductionOp,
|
| 142 |
+
full,
|
| 143 |
+
full_like,
|
| 144 |
+
empty_like,
|
| 145 |
+
ones_like,
|
| 146 |
+
zeros_like,
|
| 147 |
+
where,
|
| 148 |
+
any_,
|
| 149 |
+
all_,
|
| 150 |
+
# User defined struct
|
| 151 |
+
struct,
|
| 152 |
+
pretty_str,
|
| 153 |
+
make_layout_image_mask,
|
| 154 |
+
repeat_like,
|
| 155 |
+
round_up,
|
| 156 |
+
is_congruent,
|
| 157 |
+
is_weakly_congruent,
|
| 158 |
+
ScaledBasis,
|
| 159 |
+
get_divisibility,
|
| 160 |
+
Ratio,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
from . import arch
|
| 164 |
+
from . import nvgpu
|
| 165 |
+
from . import testing
|
| 166 |
+
from . import runtime
|
| 167 |
+
|
| 168 |
+
# Export all math ops without "math."
|
| 169 |
+
from .math import *
|
| 170 |
+
|
| 171 |
+
# Used as internal symbol
|
| 172 |
+
from .. import cutlass_dsl as _dsl
|
| 173 |
+
|
| 174 |
+
# Aliases
|
| 175 |
+
jit = _dsl.CuTeDSL.jit
|
| 176 |
+
kernel = _dsl.CuTeDSL.kernel
|
| 177 |
+
register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter
|
| 178 |
+
compile = _dsl.compile
|
| 179 |
+
|
| 180 |
+
# Explicitly export all symbols for documentation generation
|
| 181 |
+
__all__ = [
|
| 182 |
+
# Core types
|
| 183 |
+
"AddressSpace",
|
| 184 |
+
"Tensor",
|
| 185 |
+
"Layout",
|
| 186 |
+
"ComposedLayout",
|
| 187 |
+
"Swizzle",
|
| 188 |
+
"E",
|
| 189 |
+
"Atom",
|
| 190 |
+
"MmaAtom",
|
| 191 |
+
"CopyAtom",
|
| 192 |
+
"TiledCopy",
|
| 193 |
+
"TiledMma",
|
| 194 |
+
"TensorSSA",
|
| 195 |
+
# Basic utility functions
|
| 196 |
+
"assume",
|
| 197 |
+
"is_integer",
|
| 198 |
+
"is_int_tuple",
|
| 199 |
+
"is_static",
|
| 200 |
+
"size",
|
| 201 |
+
"has_underscore",
|
| 202 |
+
"slice_",
|
| 203 |
+
"depth",
|
| 204 |
+
"rank",
|
| 205 |
+
"shape",
|
| 206 |
+
"printf",
|
| 207 |
+
"print_tensor",
|
| 208 |
+
"pretty_str",
|
| 209 |
+
# Layout functions
|
| 210 |
+
"make_layout",
|
| 211 |
+
"recast_layout",
|
| 212 |
+
"make_identity_layout",
|
| 213 |
+
"make_ordered_layout",
|
| 214 |
+
"make_composed_layout",
|
| 215 |
+
"make_layout_tv",
|
| 216 |
+
"make_layout_image_mask",
|
| 217 |
+
# Tensor functions
|
| 218 |
+
"make_ptr",
|
| 219 |
+
"make_tensor",
|
| 220 |
+
"make_identity_tensor",
|
| 221 |
+
"make_fragment",
|
| 222 |
+
"make_fragment_like",
|
| 223 |
+
"recast_ptr",
|
| 224 |
+
"recast_tensor",
|
| 225 |
+
# Tensor manipulation
|
| 226 |
+
"get",
|
| 227 |
+
"select",
|
| 228 |
+
"front",
|
| 229 |
+
"is_major",
|
| 230 |
+
"leading_dim",
|
| 231 |
+
"find",
|
| 232 |
+
"find_if",
|
| 233 |
+
"coalesce",
|
| 234 |
+
"group_modes",
|
| 235 |
+
"cosize",
|
| 236 |
+
"size_in_bytes",
|
| 237 |
+
# Tuple operations
|
| 238 |
+
"flatten_to_tuple",
|
| 239 |
+
"flatten",
|
| 240 |
+
"product",
|
| 241 |
+
"product_like",
|
| 242 |
+
"product_each",
|
| 243 |
+
"prepend",
|
| 244 |
+
"append",
|
| 245 |
+
"prepend_ones",
|
| 246 |
+
"append_ones",
|
| 247 |
+
# Math operations
|
| 248 |
+
"ceil_div",
|
| 249 |
+
"round_up",
|
| 250 |
+
# Layout operations
|
| 251 |
+
"slice_and_offset",
|
| 252 |
+
"crd2idx",
|
| 253 |
+
"domain_offset",
|
| 254 |
+
"elem_less",
|
| 255 |
+
"filter_zeros",
|
| 256 |
+
"filter",
|
| 257 |
+
"tile_to_shape",
|
| 258 |
+
"shape_div",
|
| 259 |
+
"dice",
|
| 260 |
+
# Layout algebra
|
| 261 |
+
"composition",
|
| 262 |
+
"complement",
|
| 263 |
+
"right_inverse",
|
| 264 |
+
"left_inverse",
|
| 265 |
+
"max_common_layout",
|
| 266 |
+
"max_common_vector",
|
| 267 |
+
"is_congruent",
|
| 268 |
+
"is_weakly_congruent",
|
| 269 |
+
# Product operations
|
| 270 |
+
"logical_product",
|
| 271 |
+
"zipped_product",
|
| 272 |
+
"tiled_product",
|
| 273 |
+
"flat_product",
|
| 274 |
+
"raked_product",
|
| 275 |
+
"blocked_product",
|
| 276 |
+
# Division operations
|
| 277 |
+
"flat_divide",
|
| 278 |
+
"logical_divide",
|
| 279 |
+
"zipped_divide",
|
| 280 |
+
"tiled_divide",
|
| 281 |
+
"local_partition",
|
| 282 |
+
"local_tile",
|
| 283 |
+
# MMA and Copy operations
|
| 284 |
+
"make_mma_atom",
|
| 285 |
+
"make_tiled_mma",
|
| 286 |
+
"make_copy_atom",
|
| 287 |
+
"make_tiled_copy_tv",
|
| 288 |
+
"make_tiled_copy",
|
| 289 |
+
"make_tiled_copy_C_atom",
|
| 290 |
+
"basic_copy",
|
| 291 |
+
"basic_copy_if",
|
| 292 |
+
"autovec_copy",
|
| 293 |
+
"copy",
|
| 294 |
+
"copy_atom_call",
|
| 295 |
+
"gemm",
|
| 296 |
+
# Tensor creation
|
| 297 |
+
"full",
|
| 298 |
+
"full_like",
|
| 299 |
+
"empty_like",
|
| 300 |
+
"ones_like",
|
| 301 |
+
"zeros_like",
|
| 302 |
+
"where",
|
| 303 |
+
"any_",
|
| 304 |
+
"all_",
|
| 305 |
+
"repeat_like",
|
| 306 |
+
"ScaledBasis",
|
| 307 |
+
# User defined struct
|
| 308 |
+
"struct",
|
| 309 |
+
# Modules
|
| 310 |
+
"arch",
|
| 311 |
+
"nvgpu",
|
| 312 |
+
"testing",
|
| 313 |
+
"runtime",
|
| 314 |
+
# Decorators and code generation
|
| 315 |
+
"jit",
|
| 316 |
+
"kernel",
|
| 317 |
+
"register_jit_arg_adapter",
|
| 318 |
+
"compile",
|
| 319 |
+
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .elect import *
|
| 13 |
+
from .mbar import *
|
| 14 |
+
from .nvvm_wrappers import *
|
| 15 |
+
from .smem import *
|
| 16 |
+
from .tmem import *
|
| 17 |
+
|
| 18 |
+
# __all__ is required here for documentation generation
|
| 19 |
+
__all__ = [
|
| 20 |
+
#
|
| 21 |
+
# elect.py
|
| 22 |
+
#
|
| 23 |
+
"make_warp_uniform",
|
| 24 |
+
"elect_one",
|
| 25 |
+
#
|
| 26 |
+
# mbar.py
|
| 27 |
+
#
|
| 28 |
+
"mbarrier_init",
|
| 29 |
+
"mbarrier_init_fence",
|
| 30 |
+
"mbarrier_arrive_and_expect_tx",
|
| 31 |
+
"mbarrier_expect_tx",
|
| 32 |
+
"mbarrier_wait",
|
| 33 |
+
"mbarrier_try_wait",
|
| 34 |
+
"mbarrier_conditional_try_wait",
|
| 35 |
+
"mbarrier_arrive",
|
| 36 |
+
#
|
| 37 |
+
# nvvm_wrappers.py
|
| 38 |
+
#
|
| 39 |
+
"lane_idx",
|
| 40 |
+
"warp_idx",
|
| 41 |
+
"thread_idx",
|
| 42 |
+
"block_dim",
|
| 43 |
+
"block_idx",
|
| 44 |
+
"grid_dim",
|
| 45 |
+
"cluster_idx",
|
| 46 |
+
"cluster_dim",
|
| 47 |
+
"block_in_cluster_idx",
|
| 48 |
+
"block_in_cluster_dim",
|
| 49 |
+
"block_idx_in_cluster",
|
| 50 |
+
"shuffle_sync",
|
| 51 |
+
"shuffle_sync_up",
|
| 52 |
+
"shuffle_sync_down",
|
| 53 |
+
"shuffle_sync_bfly",
|
| 54 |
+
"barrier",
|
| 55 |
+
"barrier_arrive",
|
| 56 |
+
"sync_threads",
|
| 57 |
+
"sync_warp",
|
| 58 |
+
"fence_acq_rel_cta",
|
| 59 |
+
"fence_acq_rel_cluster",
|
| 60 |
+
"fence_acq_rel_gpu",
|
| 61 |
+
"fence_acq_rel_sys",
|
| 62 |
+
"cp_async_commit_group",
|
| 63 |
+
"cp_async_wait_group",
|
| 64 |
+
"cp_async_bulk_commit_group",
|
| 65 |
+
"cp_async_bulk_wait_group",
|
| 66 |
+
"cluster_wait",
|
| 67 |
+
"cluster_arrive",
|
| 68 |
+
"cluster_arrive_relaxed",
|
| 69 |
+
"fence_proxy",
|
| 70 |
+
"vote_ballot_sync",
|
| 71 |
+
"popc",
|
| 72 |
+
"fence_view_async_tmem_load",
|
| 73 |
+
"fence_view_async_tmem_store",
|
| 74 |
+
"warpgroup_reg_alloc",
|
| 75 |
+
"warpgroup_reg_dealloc",
|
| 76 |
+
"fma_packed_f32x2",
|
| 77 |
+
"mul_packed_f32x2",
|
| 78 |
+
"add_packed_f32x2",
|
| 79 |
+
"fmax",
|
| 80 |
+
"rcp_approx",
|
| 81 |
+
"exp2",
|
| 82 |
+
# Constants
|
| 83 |
+
"WARP_SIZE",
|
| 84 |
+
# Forward from auto-generated nvvm python
|
| 85 |
+
"ProxyKind",
|
| 86 |
+
"SharedSpace",
|
| 87 |
+
"RoundingModeKind",
|
| 88 |
+
#
|
| 89 |
+
# smem.py
|
| 90 |
+
#
|
| 91 |
+
"alloc_smem",
|
| 92 |
+
"get_dyn_smem",
|
| 93 |
+
"get_dyn_smem_size",
|
| 94 |
+
#
|
| 95 |
+
# tmem.py
|
| 96 |
+
#
|
| 97 |
+
"retrieve_tmem_ptr",
|
| 98 |
+
"alloc_tmem",
|
| 99 |
+
"relinquish_tmem_alloc_permit",
|
| 100 |
+
"dealloc_tmem",
|
| 101 |
+
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
|
| 13 |
+
|
| 14 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 15 |
+
from cutlass._mlir.dialects import nvvm, scf
|
| 16 |
+
from cutlass._mlir import ir
|
| 17 |
+
|
| 18 |
+
from ..typing import Int, Int32
|
| 19 |
+
from ...impl_utils import check_value_in
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dsl_user_op
def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32:
    """
    Creates a warp-uniform value from the given integer input.

    :param value: The integer to make warp uniform.
    :type value: Int
    :return: The warp-uniform value equal to the input.
    :rtype: Int32
    """
    # Lower the input to an IR value, wrap it in the uniformizing op, and
    # re-wrap the result as an Int32.
    ir_input = Int32(value).ir_value(loc=loc, ip=ip)
    uniform_value = _cute_nvgpu_ir.arch_make_warp_uniform(ir_input, loc=loc, ip=ip)
    return Int32(uniform_value)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class IfOpRegion:
    """Context manager wrapping the body block of an ``scf.if`` operation.

    On exit, the region is terminated by emitting ``scf.yield([])`` before the
    previous insertion point is restored.
    """

    def __init__(self, block, *, loc=None, ip=None):
        self.block = block
        self.loc = loc
        self.ip = ip
        self.insert_point = ir.InsertionPoint(block)

    def __enter__(self):
        # Redirect subsequent op creation into the wrapped block and expose
        # the block's arguments to the caller.
        self.insert_point.__enter__()
        return self.block.arguments

    def __exit__(self, exc_type, exc_value, traceback):
        # Terminate the region first, then restore the insertion point.
        scf.yield_([], loc=self.loc, ip=self.ip)
        self.insert_point.__exit__(exc_type, exc_value, traceback)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dsl_user_op
def elect_one(*, loc=None, ip=None) -> IfOpRegion:
    """
    Elects one thread within a warp.

    .. code-block:: python

        with elect_one():
            # Only one thread in the warp executes the code in this context
            pass
    """
    # Election is only available on these target architectures.
    supported_archs = ["sm_90", "sm_90a", "sm_100a", "sm_100f"]
    arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(arch, supported_archs, "arch")

    # elect_sync yields a predicate true for exactly one lane; gate an scf.if
    # on it and hand back its then-block as a context manager.
    leader_pred = nvvm.elect_sync(T.bool())
    if_op = scf.IfOp(leader_pred, loc=loc, ip=ip)
    return IfOpRegion(if_op.then_block, loc=loc, ip=ip)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
|
| 14 |
+
|
| 15 |
+
from cutlass._mlir.dialects import nvvm
|
| 16 |
+
from cutlass._mlir import ir
|
| 17 |
+
|
| 18 |
+
from ..typing import Pointer, Int, Boolean, Int32
|
| 19 |
+
from ...impl_utils import check_value_in
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
####################################################################################################
|
| 23 |
+
#
|
| 24 |
+
# Mbarrier management utilities
|
| 25 |
+
#
|
| 26 |
+
####################################################################################################
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dsl_user_op
def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None:
    """
    Initializes a mbarrier with the specified thread arrival count.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param cnt: The arrival count of the mbarrier
    :type cnt: Int
    """
    # Normalize the count to an Int32 IR value before emitting the init op.
    arrival_count = Int32(cnt).ir_value(loc=loc, ip=ip)
    nvvm.mbarrier_init_shared(mbar_ptr.llvm_ptr, arrival_count, loc=loc, ip=ip)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dsl_user_op
def mbarrier_init_fence(*, loc=None, ip=None) -> None:
    """
    A fence operation that applies to the mbarrier initializations.
    """
    # The fence is only valid on these target architectures.
    current_arch = CuTeDSL._get_dsl().envar.arch
    check_value_in(
        current_arch,
        ["sm_90", "sm_90a", "sm_100a", "sm_100f"],
        "arch",
    )
    nvvm.fence_mbarrier_init(loc=loc, ip=ip)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dsl_user_op
def mbarrier_arrive_and_expect_tx(
    mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
    """
    Arrives on a mbarrier and expects a specified number of transaction bytes.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param bytes: The number of transaction bytes
    :type bytes: Int
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer
                                     CTA's SMEM.
    """
    supported = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]
    check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

    ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is None:
        space = nvvm.MBarrierSpaceKind.CTA
    else:
        # Remap the local SMEM address to the peer CTA's address space.
        ptr = nvvm.mapa_shared_cluster(
            ptr.type,
            ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER

    nvvm.mbarrier_txn(
        ptr,
        Int32(bytes).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX,
        space=space,
        loc=loc,
        ip=ip,
    )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dsl_user_op
def mbarrier_expect_tx(
    mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None
) -> None:
    """
    Expects a specified number of transaction bytes without an arrive.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param bytes: The number of transaction bytes
    :type bytes: Int
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer
                                     CTA's SMEM.
    """
    supported = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]
    check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

    ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is None:
        space = nvvm.MBarrierSpaceKind.CTA
    else:
        # NOTE(review): this path uses nvvm.mapa while the sibling
        # mbarrier_arrive_and_expect_tx uses nvvm.mapa_shared_cluster —
        # confirm the intended mapa variant for SMEM pointers.
        ptr = nvvm.mapa(
            ptr.type,
            ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER

    nvvm.mbarrier_txn(
        ptr,
        Int32(bytes).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.EXPECT_TX,
        space=space,
        loc=loc,
        ip=ip,
    )
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
    """
    Waits on a mbarrier with a specified phase.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param phase: The phase to wait for (either 0 or 1)
    :type phase: Int
    """
    supported = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]
    check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

    # This NVVM op lowers to a spin-loop wrapping the
    # mbarrier.try_wait.parity.shared.b64 PTX instruction; the timeout in ns
    # only bounds each individual try, so this call is truly blocking.
    timeout_ns = 10000000
    nvvm.mbarrier_try_wait_parity_shared(
        mbar_ptr.llvm_ptr,
        Int32(phase).ir_value(loc=loc, ip=ip),
        Int32(timeout_ns).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@dsl_user_op
def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean:
    """
    Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param phase: The phase to wait for (either 0 or 1)
    :type phase: Int
    :return: A boolean value indicating whether the wait operation was successful
    :rtype: Boolean
    """
    supported = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]
    check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

    wait_result = nvvm.mbarrier_wait_parity(
        T.bool(),
        mbar_ptr.llvm_ptr,
        Int32(phase).ir_value(loc=loc, ip=ip),
        nvvm.MBarrierWaitKind.TRY,
        loc=loc,
        ip=ip,
    )
    return Boolean(wait_result)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@dsl_user_op
def mbarrier_conditional_try_wait(
    cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None
) -> Boolean:
    """
    Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion.

    :param cond: A boolean predicate
    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param phase: The phase to wait for (either 0 or 1)
    :type phase: Int
    :return: A boolean value indicating whether the wait operation was successful
    :rtype: Boolean
    """
    supported = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]
    check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

    # When the predicate is false the wait is skipped and treated as success.
    return if_generate(
        cond,
        lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
        lambda: Boolean(True).ir_value(loc=loc, ip=ip),
        None,
        [Boolean],
    )
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@dsl_user_op
def mbarrier_arrive(
    mbar_ptr: Pointer,
    peer_cta_rank_in_cluster: Optional[Int] = None,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Arrives on an mbarrier.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to
                                     the mbarrier is converted to a remote address in the peer
                                     CTA's SMEM.
    """
    ptr = mbar_ptr.llvm_ptr
    if peer_cta_rank_in_cluster is None:
        space = nvvm.MBarrierSpaceKind.CTA
    else:
        # The cluster path is architecture-gated; the plain CTA arrive is not.
        supported = [
            "sm_90",
            "sm_90a",
            "sm_100a",
            "sm_100f",
        ]
        check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

        ptr = nvvm.mapa_shared_cluster(
            ptr.type,
            ptr,
            Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        space = nvvm.MBarrierSpaceKind.CLUSTER

    nvvm.mbarrier_txn(
        ptr,
        Int32(1).ir_value(loc=loc, ip=ip),
        kind=nvvm.MBarrierTxnKind.ARRIVE,
        space=space,
        loc=loc,
        ip=ip,
    )
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@dsl_user_op
def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None:
    """
    Arrives on an mbarrier for async load **without incrementing** the arrival count
    (`cp.async.mbarrier.arrive.shared ..., noinc=1`).
    Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same
    as the math/epilogue warp(consumer).

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    """
    supported = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]
    check_value_in(CuTeDSL._get_dsl().envar.arch, supported, "arch")

    nvvm.cp_async_mbarrier_arrive_shared(
        mbar_ptr.llvm_ptr,
        noinc=True,
        loc=loc,
        ip=ip,
    )
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
from typing import Optional, Tuple, Union, Callable
|
| 14 |
+
from typing_extensions import deprecated
|
| 15 |
+
|
| 16 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 17 |
+
|
| 18 |
+
from cutlass._mlir import ir
|
| 19 |
+
from cutlass._mlir.dialects import llvm, nvvm, vector
|
| 20 |
+
|
| 21 |
+
# Forward nvvm enums
|
| 22 |
+
from cutlass._mlir.dialects.nvvm import (
|
| 23 |
+
ProxyKind,
|
| 24 |
+
SharedSpace,
|
| 25 |
+
Tcgen05WaitKind,
|
| 26 |
+
SetMaxRegisterAction,
|
| 27 |
+
RoundingModeKind,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
from ..typing import (
|
| 31 |
+
Int,
|
| 32 |
+
Boolean,
|
| 33 |
+
Int16,
|
| 34 |
+
Uint16,
|
| 35 |
+
Int32,
|
| 36 |
+
Uint32,
|
| 37 |
+
Int64,
|
| 38 |
+
Float32,
|
| 39 |
+
BFloat16,
|
| 40 |
+
Numeric,
|
| 41 |
+
as_numeric,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Number of threads per warp on NVIDIA GPUs.
WARP_SIZE = 32
# All-lanes-active membership mask for warp-wide collectives.
FULL_MASK = 0xFFFFFFFF
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dsl_user_op
def lane_idx(*, loc=None, ip=None) -> Int32:
    """
    Returns the lane index of the current thread within the warp.
    """
    laneid = nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip)
    return Int32(laneid)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dsl_user_op
def warp_idx(*, loc=None, ip=None) -> Int32:
    """
    Returns the warp index within a CTA.

    The 3-D thread id is linearized (x fastest-varying) and divided by the
    warp size.
    """
    tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip))
    tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip))
    tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip))
    ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip))
    ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip))
    # Linear thread index within the CTA.
    tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y
    # Use the module-level constant instead of a duplicated local literal.
    return tid // WARP_SIZE
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dsl_user_op
def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the thread index within a CTA.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_tid_x),
        _read(nvvm.read_ptx_sreg_tid_y),
        _read(nvvm.read_ptx_sreg_tid_z),
    )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dsl_user_op
def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the number of threads in each dimension of the CTA.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_ntid_x),
        _read(nvvm.read_ptx_sreg_ntid_y),
        _read(nvvm.read_ptx_sreg_ntid_z),
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dsl_user_op
def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the CTA identifier within a grid.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_ctaid_x),
        _read(nvvm.read_ptx_sreg_ctaid_y),
        _read(nvvm.read_ptx_sreg_ctaid_z),
    )
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dsl_user_op
def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the number of CTAs in each dimension of the grid.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_nctaid_x),
        _read(nvvm.read_ptx_sreg_nctaid_y),
        _read(nvvm.read_ptx_sreg_nctaid_z),
    )
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@dsl_user_op
def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the cluster identifier within a grid.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_clusterid_x),
        _read(nvvm.read_ptx_sreg_clusterid_y),
        _read(nvvm.read_ptx_sreg_clusterid_z),
    )
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@dsl_user_op
def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the number of clusters in each dimension of the grid.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_nclusterid_x),
        _read(nvvm.read_ptx_sreg_nclusterid_y),
        _read(nvvm.read_ptx_sreg_nclusterid_z),
    )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@dsl_user_op
def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the CTA index within a cluster across all dimensions.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_cluster_ctaid_x),
        _read(nvvm.read_ptx_sreg_cluster_ctaid_y),
        _read(nvvm.read_ptx_sreg_cluster_ctaid_z),
    )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@dsl_user_op
def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]:
    """
    Returns the dimensions of the cluster.
    """

    def _read(sreg):
        # Read one special register as an Int32.
        return Int32(sreg(T.i32(), loc=loc, ip=ip))

    return (
        _read(nvvm.read_ptx_sreg_cluster_nctaid_x),
        _read(nvvm.read_ptx_sreg_cluster_nctaid_y),
        _read(nvvm.read_ptx_sreg_cluster_nctaid_z),
    )
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@dsl_user_op
def block_idx_in_cluster(*, loc=None, ip=None) -> Int32:
    """
    Returns the linearized identifier of the CTA within the cluster.
    """
    ctarank = nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip)
    return Int32(ctarank)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@dsl_user_op
def shuffle_sync_op(
    value: Numeric,
    offset: Int,
    mask: Int = FULL_MASK,
    mask_and_clamp: Int = WARP_SIZE - 1,
    kind: nvvm.ShflKind = nvvm.ShflKind.idx,
    *,
    loc=None,
    ip=None,
) -> Numeric:
    """
    Shuffles a value within the threads of a warp.

    :param value: The value to shuffle
    :type value: Numeric
    :param mask: A mask describing the threads participating in this operation
    :type mask: Int
    :param offset: A source lane or a source lane offset depending on kind
    :type offset: Int
    :param mask_and_clamp: An integer containing two packed values specifying a mask for logically
                           splitting warps into sub-segments and an upper bound for clamping the
                           source lane index.
    :type mask_and_clamp: Int
    :param kind: The kind of shuffle, can be idx, up, down, or bfly
    :type kind: ShflKind
    :return: The shuffled value
    :rtype: Numeric
    :raises ValueError: If the value is wider than 64 bits, or wider than 32
                        bits but not exactly 64 bits.
    """
    if not isinstance(value, Numeric):
        value = as_numeric(value)
    if value.width > 64:
        raise ValueError("shuffle_sync only supports values up to 64 bits")

    orig_type = type(value)
    if value.width < 32:
        # Widen sub-32-bit values so they can ride a single 32-bit shuffle.
        if value.dtype.is_float:
            value = value.to(Float32)
        elif value.signed:
            value = value.to(Int32)
        else:
            value = value.to(Uint32)

    if value.width == 32:
        # Single 32-bit shuffle; this also covers widened sub-32-bit values.
        # (The original duplicated this return verbatim in two branches.)
        return orig_type(
            nvvm.shfl_sync(
                type(value).mlir_type,
                Int32(mask).ir_value(loc=loc, ip=ip),
                value.ir_value(loc=loc, ip=ip),
                Int32(offset).ir_value(loc=loc, ip=ip),
                Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
                kind,
                loc=loc,
                ip=ip,
            )
        )

    if value.width != 64:
        raise ValueError(
            "shuffle_sync only supports 64 bits values when the bit width is larger than 32"
        )
    # 64-bit path: split into two 32-bit halves, shuffle each, and recombine.
    value = llvm.bitcast(T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip)
    # extract low 32 bits
    low_32_bits = llvm.trunc(
        T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
    )
    # extract high 32 bits
    high_32_bits = llvm.lshr(value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
    high_32_bits = llvm.trunc(
        T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip
    )

    low_32_bits_shfl = nvvm.shfl_sync(
        T.i32(),
        Int32(mask).ir_value(loc=loc, ip=ip),
        low_32_bits,
        Int32(offset).ir_value(loc=loc, ip=ip),
        Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
        kind,
        loc=loc,
        ip=ip,
    )
    high_32_bits_shfl = nvvm.shfl_sync(
        T.i32(),
        Int32(mask).ir_value(loc=loc, ip=ip),
        high_32_bits,
        Int32(offset).ir_value(loc=loc, ip=ip),
        Int32(mask_and_clamp).ir_value(loc=loc, ip=ip),
        kind,
        loc=loc,
        ip=ip,
    )

    # combine low and high 32 bits
    low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip)
    high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip)
    shlf_res = llvm.shl(
        high_64_bit,
        Int64(32).ir_value(loc=loc, ip=ip),
        llvm.IntegerOverflowFlags.none,
        loc=loc,
        ip=ip,
    )
    shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip)
    shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip)
    return orig_type(shlf_res)
|
| 297 |
+
|
| 298 |
+
# Convenience entry points that fix the shuffle kind.
shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx)
shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up)
shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down)
shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dsl_user_op
def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None:
    """
    Creates a barrier, optionally named.
    """
    # Convert optional Python values to IR values; None is forwarded as-is.
    bar_id_value = (
        None if barrier_id is None else Int32(barrier_id).ir_value(loc=loc, ip=ip)
    )
    thread_count_value = (
        None
        if number_of_threads is None
        else Int32(number_of_threads).ir_value(loc=loc, ip=ip)
    )
    nvvm.barrier(
        barrier_id=bar_id_value,
        number_of_threads=thread_count_value,
        loc=loc,
        ip=ip,
    )
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@dsl_user_op
def barrier_arrive(
    *, barrier_id=None, number_of_threads=None, loc=None, ip=None
) -> None:
    """
    Arrives on a (optionally named) barrier without waiting.

    :param barrier_id: An optional barrier identifier
    :param number_of_threads: The number of threads expected to arrive; required
    :raises ValueError: If ``number_of_threads`` is not provided
    """
    # Validate the required argument before doing any IR work.
    if number_of_threads is None:
        raise ValueError(
            "barrier_arrive requires number_of_threads to arrive on the barrier"
        )

    if barrier_id is not None:
        barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip)
    number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip)

    nvvm.barrier_arrive(
        barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip
    )
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@dsl_user_op
def sync_threads(*, loc=None, ip=None) -> None:
    """
    Synchronizes all threads within a CTA.
    """
    # Plain CTA-wide barrier (no id, no thread-count override).
    nvvm.barrier(loc=loc, ip=ip)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
@dsl_user_op
def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None:
    """
    Performs a warp-wide sync with an optional mask.
    """
    mask_value = Int32(mask).ir_value(loc=loc, ip=ip)
    nvvm.bar_warp_sync(mask_value, loc=loc, ip=ip)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
@dsl_user_op
def fence_acq_rel_cta(*, loc=None, ip=None) -> None:
    """
    Fence operation with acquire-release semantics at CTA scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_cta(loc=loc, ip=ip)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@dsl_user_op
def fence_acq_rel_cluster(*, loc=None, ip=None) -> None:
    """
    Fence operation with acquire-release semantics at cluster scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_cluster(loc=loc, ip=ip)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
@dsl_user_op
def fence_acq_rel_gpu(*, loc=None, ip=None) -> None:
    """
    Fence operation with acquire-release semantics at GPU scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_gpu(loc=loc, ip=ip)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@dsl_user_op
def fence_acq_rel_sys(*, loc=None, ip=None) -> None:
    """
    Fence operation with acquire-release semantics at system scope.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    nvvm.fence_acq_rel_sys(loc=loc, ip=ip)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
@dsl_user_op
def cp_async_commit_group(*, loc=None, ip=None) -> None:
    """
    Commits all prior initiated but uncommitted cp.async instructions.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-commit-group>`__.
    """
    nvvm.cp_async_commit_group(loc=loc, ip=ip)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
@dsl_user_op
def cp_async_wait_group(n, *, loc=None, ip=None) -> None:
    """
    Waits till only a specified numbers of cp.async groups are pending.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-wait-group-cp-async-wait-all>`__.
    """
    nvvm.cp_async_wait_group(n, loc=loc, ip=ip)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@dsl_user_op
def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None:
    """
    Commits all prior initiated but uncommitted cp.async.bulk instructions.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-commit-group>`__.
    """
    nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
@dsl_user_op
def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None:
    """
    Waits till only a specified numbers of cp.async.bulk groups are pending.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-wait-group>`__.
    """
    nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@dsl_user_op
def cluster_wait(*, loc=None, ip=None) -> None:
    """
    A cluster-wide wait operation.
    """
    nvvm.cluster_wait(loc=loc, ip=ip)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@dsl_user_op
def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None:
    """
    A cluster-wide arrive operation.
    """
    nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@dsl_user_op
def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None:
    """
    A cluster-wide arrive operation with relaxed semantics.
    """
    nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
@dsl_user_op
|
| 459 |
+
def fence_proxy(
|
| 460 |
+
kind: ProxyKind,
|
| 461 |
+
*,
|
| 462 |
+
space: Optional[SharedSpace] = None,
|
| 463 |
+
use_intrinsic=None,
|
| 464 |
+
loc=None,
|
| 465 |
+
ip=None,
|
| 466 |
+
) -> None:
|
| 467 |
+
nvvm.fence_proxy(
|
| 468 |
+
kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
@dsl_user_op
|
| 473 |
+
def vote_ballot_sync(
|
| 474 |
+
pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None
|
| 475 |
+
) -> Int32:
|
| 476 |
+
"""
|
| 477 |
+
Performs a ballot operation across the warp.
|
| 478 |
+
"""
|
| 479 |
+
return Int32(
|
| 480 |
+
nvvm.vote_ballot_sync(
|
| 481 |
+
T.i32(),
|
| 482 |
+
Int32(mask).ir_value(loc=loc, ip=ip),
|
| 483 |
+
Boolean(pred).ir_value(loc=loc, ip=ip),
|
| 484 |
+
loc=loc,
|
| 485 |
+
ip=ip,
|
| 486 |
+
)
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
@dsl_user_op
|
| 491 |
+
def popc(value: Numeric, *, loc=None, ip=None) -> Numeric:
|
| 492 |
+
"""
|
| 493 |
+
Performs a population count operation.
|
| 494 |
+
"""
|
| 495 |
+
if not isinstance(value, Numeric):
|
| 496 |
+
value = as_numeric(value)
|
| 497 |
+
return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip))
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
@dsl_user_op
|
| 501 |
+
def fence_view_async_tmem_op(
|
| 502 |
+
kind: Tcgen05WaitKind,
|
| 503 |
+
*,
|
| 504 |
+
loc=None,
|
| 505 |
+
ip=None,
|
| 506 |
+
) -> None:
|
| 507 |
+
"""
|
| 508 |
+
Perform a fence operation on the async TMEM load or store.
|
| 509 |
+
|
| 510 |
+
.. note::
|
| 511 |
+
This function is only available on sm_100a and above.
|
| 512 |
+
The fence is required to synchronize the TMEM load/store
|
| 513 |
+
and let the pipeline release or commit the buffer.
|
| 514 |
+
|
| 515 |
+
Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM.
|
| 516 |
+
```
|
| 517 |
+
# Start to copy ACC from TMEM to register
|
| 518 |
+
cute.copy(tmem_load, tACC, rACC)
|
| 519 |
+
fence_view_async_tmem_load()
|
| 520 |
+
# After fence, we can ensure the TMEM buffer is consumed totally.
|
| 521 |
+
# Release the buffer to let the MMA know it can overwrite the buffer.
|
| 522 |
+
mma2accum_pipeline.consumer_release(curr_consumer_state)
|
| 523 |
+
```
|
| 524 |
+
Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM.
|
| 525 |
+
```
|
| 526 |
+
# Start to copy A from register to TMEM
|
| 527 |
+
cute.copy(tmem_store, rA, tA)
|
| 528 |
+
fence_view_async_tmem_store()
|
| 529 |
+
# After fence, we can ensure the TMEM buffer is ready.
|
| 530 |
+
# Commit the buffer to let the MMA know it can start to load A.
|
| 531 |
+
tmem_mma_pipeline.producer_commit(curr_producer_state)
|
| 532 |
+
```
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
:param kind: The kind of fence operation to perform including LOAD and STORE.
|
| 536 |
+
:type kind: Tcgen05WaitKind
|
| 537 |
+
"""
|
| 538 |
+
nvvm.tcgen05_wait(kind, loc=loc, ip=ip)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
fence_view_async_tmem_load = partial(
|
| 542 |
+
fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD
|
| 543 |
+
)
|
| 544 |
+
fence_view_async_tmem_store = partial(
|
| 545 |
+
fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
@dsl_user_op
|
| 550 |
+
def warpgroup_reg_realloc_op(
|
| 551 |
+
reg_count: int,
|
| 552 |
+
kind: SetMaxRegisterAction,
|
| 553 |
+
*,
|
| 554 |
+
loc=None,
|
| 555 |
+
ip=None,
|
| 556 |
+
) -> None:
|
| 557 |
+
nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
warpgroup_reg_alloc = partial(
|
| 561 |
+
warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase
|
| 562 |
+
)
|
| 563 |
+
warpgroup_reg_dealloc = partial(
|
| 564 |
+
warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
@dsl_user_op
|
| 569 |
+
def calc_packed_f32x2_op(
|
| 570 |
+
src_a: Tuple[Float32, Float32],
|
| 571 |
+
src_b: Tuple[Float32, Float32],
|
| 572 |
+
src_c: Tuple[Float32, Float32] | None,
|
| 573 |
+
calc_func: Callable,
|
| 574 |
+
*,
|
| 575 |
+
rnd=RoundingModeKind.RZ,
|
| 576 |
+
ftz=True,
|
| 577 |
+
loc=None,
|
| 578 |
+
ip=None,
|
| 579 |
+
) -> Tuple[Float32, Float32]:
|
| 580 |
+
vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc)
|
| 581 |
+
vec_src_a = vector.from_elements(
|
| 582 |
+
vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip
|
| 583 |
+
)
|
| 584 |
+
vec_src_b = vector.from_elements(
|
| 585 |
+
vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip
|
| 586 |
+
)
|
| 587 |
+
if src_c is not None:
|
| 588 |
+
vec_src_c = vector.from_elements(
|
| 589 |
+
vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip
|
| 590 |
+
)
|
| 591 |
+
vec_res = calc_func(
|
| 592 |
+
vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip
|
| 593 |
+
)
|
| 594 |
+
else:
|
| 595 |
+
vec_res = calc_func(
|
| 596 |
+
vec_type, vec_src_a, vec_src_b, rnd=rnd, ftz=ftz, loc=loc, ip=ip
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
res0 = Float32(
|
| 600 |
+
vector.extract(
|
| 601 |
+
vec_res, dynamic_position=[], static_position=[0], loc=loc, ip=ip
|
| 602 |
+
)
|
| 603 |
+
)
|
| 604 |
+
res1 = Float32(
|
| 605 |
+
vector.extract(
|
| 606 |
+
vec_res, dynamic_position=[], static_position=[1], loc=loc, ip=ip
|
| 607 |
+
)
|
| 608 |
+
)
|
| 609 |
+
return res0, res1
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2)
|
| 613 |
+
mul_packed_f32x2 = partial(
|
| 614 |
+
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.mul_packed_f32x2
|
| 615 |
+
)
|
| 616 |
+
add_packed_f32x2 = partial(
|
| 617 |
+
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
@dsl_user_op
|
| 622 |
+
def fmax(
|
| 623 |
+
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
|
| 624 |
+
) -> Float32:
|
| 625 |
+
return Float32(
|
| 626 |
+
nvvm.fmax(
|
| 627 |
+
T.f32(),
|
| 628 |
+
Float32(a).ir_value(loc=loc, ip=ip),
|
| 629 |
+
Float32(b).ir_value(loc=loc, ip=ip),
|
| 630 |
+
loc=loc,
|
| 631 |
+
ip=ip,
|
| 632 |
+
)
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
@dsl_user_op
|
| 637 |
+
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
|
| 638 |
+
return Float32(
|
| 639 |
+
nvvm.rcp_approx_ftz_f(
|
| 640 |
+
T.f32(), Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
|
| 641 |
+
)
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
@dsl_user_op
|
| 646 |
+
@deprecated(
|
| 647 |
+
"cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead"
|
| 648 |
+
)
|
| 649 |
+
def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
| 650 |
+
return Float32(
|
| 651 |
+
llvm.inline_asm(
|
| 652 |
+
T.f32(),
|
| 653 |
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
| 654 |
+
"ex2.approx.ftz.f32 $0, $1;",
|
| 655 |
+
"=f,f",
|
| 656 |
+
has_side_effects=True,
|
| 657 |
+
is_align_stack=False,
|
| 658 |
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
| 659 |
+
)
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
@dsl_user_op
|
| 664 |
+
@deprecated(
|
| 665 |
+
"cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead"
|
| 666 |
+
)
|
| 667 |
+
def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
| 668 |
+
LOG2_E = 1.4426950408889634
|
| 669 |
+
return exp2(a * LOG2_E, loc=loc, ip=ip)
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
@dsl_user_op
|
| 673 |
+
@deprecated(
|
| 674 |
+
"cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead"
|
| 675 |
+
)
|
| 676 |
+
def exp_packed_f32x2(
|
| 677 |
+
a: Tuple[Float32, Float32], *, loc=None, ip=None
|
| 678 |
+
) -> Tuple[Float32, Float32]:
|
| 679 |
+
LOG2_E = Float32(1.4426950408889634)
|
| 680 |
+
b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip)
|
| 681 |
+
return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from typing import Optional, Type
|
| 13 |
+
|
| 14 |
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
| 15 |
+
|
| 16 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 17 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 18 |
+
from cutlass._mlir import ir
|
| 19 |
+
|
| 20 |
+
from ..typing import Pointer, Numeric, NumericMeta
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dsl_user_op
|
| 24 |
+
def alloc_smem(
|
| 25 |
+
element_type: Type[Numeric],
|
| 26 |
+
size_in_elems: int,
|
| 27 |
+
alignment: Optional[int] = None,
|
| 28 |
+
*,
|
| 29 |
+
loc=None,
|
| 30 |
+
ip=None,
|
| 31 |
+
) -> Pointer:
|
| 32 |
+
"""
|
| 33 |
+
Statically allocates SMEM.
|
| 34 |
+
|
| 35 |
+
:param element_type: The pointee type of the pointer.
|
| 36 |
+
:type element_type: Type[Numeric]
|
| 37 |
+
:param size_in_elems: The size of the allocation in terms of number of elements of the
|
| 38 |
+
pointee type
|
| 39 |
+
:type size_in_elems: int
|
| 40 |
+
:param alignment: An optional pointer alignment for the allocation
|
| 41 |
+
:type alignment: int
|
| 42 |
+
:return: A pointer to the start of the allocation
|
| 43 |
+
:rtype: Pointer
|
| 44 |
+
"""
|
| 45 |
+
if not isinstance(element_type, NumericMeta):
|
| 46 |
+
raise TypeError(
|
| 47 |
+
f"element_type must be a type of Numeric, but got {element_type}"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if alignment is None:
|
| 51 |
+
# Default alignment based on the element type's width
|
| 52 |
+
alignment = element_type.width // 8
|
| 53 |
+
ptr_ty = _cute_ir.PtrType.get(
|
| 54 |
+
element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment
|
| 55 |
+
)
|
| 56 |
+
return _cute_nvgpu_ir.arch_alloc_smem(
|
| 57 |
+
ptr=ptr_ty,
|
| 58 |
+
input=ir.IntegerAttr.get(T.i32(), size_in_elems),
|
| 59 |
+
loc=loc,
|
| 60 |
+
ip=ip,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dsl_user_op
|
| 65 |
+
def get_dyn_smem(
|
| 66 |
+
element_type: Type[Numeric],
|
| 67 |
+
alignment: Optional[int] = None,
|
| 68 |
+
*,
|
| 69 |
+
loc=None,
|
| 70 |
+
ip=None,
|
| 71 |
+
) -> Pointer:
|
| 72 |
+
"""
|
| 73 |
+
Retrieves a pointer to a dynamic SMEM allocation.
|
| 74 |
+
|
| 75 |
+
:param element_type: The pointee type of the pointer.
|
| 76 |
+
:type element_type: Type[Numeric]
|
| 77 |
+
:param alignment: An optional pointer alignment, the result pointer is offset appropriately
|
| 78 |
+
:type alignment: int
|
| 79 |
+
:return: A pointer to the start of the dynamic SMEM allocation with a correct
|
| 80 |
+
alignement
|
| 81 |
+
:rtype: Pointer
|
| 82 |
+
"""
|
| 83 |
+
if not isinstance(element_type, NumericMeta):
|
| 84 |
+
raise TypeError(
|
| 85 |
+
f"element_type must be a type of Numeric, but got {element_type}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if alignment is None:
|
| 89 |
+
# Default alignment based on the element type's width
|
| 90 |
+
alignment = element_type.width // 8
|
| 91 |
+
ptr_ty = _cute_ir.PtrType.get(
|
| 92 |
+
element_type.mlir_type,
|
| 93 |
+
_cute_ir.AddressSpace.smem,
|
| 94 |
+
alignment,
|
| 95 |
+
)
|
| 96 |
+
return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dsl_user_op
|
| 100 |
+
def get_dyn_smem_size(*, loc=None, ip=None) -> int:
|
| 101 |
+
"""
|
| 102 |
+
Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time.
|
| 103 |
+
This can be used for bounds checking during shared memory allocation.
|
| 104 |
+
|
| 105 |
+
:return: The size of dynamic shared memory in bytes
|
| 106 |
+
:rtype: int
|
| 107 |
+
"""
|
| 108 |
+
return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from typing import Type
|
| 13 |
+
|
| 14 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 15 |
+
|
| 16 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 17 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 18 |
+
|
| 19 |
+
from ..typing import Pointer, Int, Int32, Numeric, NumericMeta
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
SM100_TMEM_CAPACITY_COLUMNS = 512
|
| 23 |
+
SM100_TMEM_MIN_ALLOC_COLUMNS = 32
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dsl_user_op
|
| 27 |
+
def retrieve_tmem_ptr(
|
| 28 |
+
element_type: Type[Numeric],
|
| 29 |
+
alignment: int,
|
| 30 |
+
ptr_to_buffer_holding_addr: Pointer,
|
| 31 |
+
*,
|
| 32 |
+
loc=None,
|
| 33 |
+
ip=None,
|
| 34 |
+
) -> Pointer:
|
| 35 |
+
"""
|
| 36 |
+
Retrieves a pointer to TMEM with the provided element type and alignment.
|
| 37 |
+
|
| 38 |
+
:param element_type: The pointee type of the pointer.
|
| 39 |
+
:type element_type: Type[Numeric]
|
| 40 |
+
:param alignment: The alignment of the result pointer
|
| 41 |
+
:type alignment: int
|
| 42 |
+
:param ptr_to_buffer_holding_addr: A pointer to a SMEM buffer holding the TMEM address of the
|
| 43 |
+
start of the allocation allocation
|
| 44 |
+
:type ptr_to_buffer_holding_addr: Pointer
|
| 45 |
+
:return: A pointer to TMEM
|
| 46 |
+
:rtype: Pointer
|
| 47 |
+
"""
|
| 48 |
+
if not isinstance(element_type, NumericMeta):
|
| 49 |
+
raise TypeError(
|
| 50 |
+
f"element_type must be a type of Numeric, but got {element_type}"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
res_ty = _cute_ir.PtrType.get(
|
| 54 |
+
element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment
|
| 55 |
+
)
|
| 56 |
+
return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr(
|
| 57 |
+
res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dsl_user_op
|
| 62 |
+
def alloc_tmem(
|
| 63 |
+
num_columns: Int,
|
| 64 |
+
smem_ptr_to_write_address: Pointer,
|
| 65 |
+
is_two_cta=None,
|
| 66 |
+
*,
|
| 67 |
+
loc=None,
|
| 68 |
+
ip=None,
|
| 69 |
+
) -> None:
|
| 70 |
+
"""
|
| 71 |
+
Allocates TMEM.
|
| 72 |
+
|
| 73 |
+
:param num_columns: The number of TMEM columns to allocate
|
| 74 |
+
:type num_columns: Int
|
| 75 |
+
:param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is written
|
| 76 |
+
to
|
| 77 |
+
:type smem_ptr_to_write_address: Pointer
|
| 78 |
+
:param is_two_cta: Optional boolean parameter for 2-CTA MMAs
|
| 79 |
+
"""
|
| 80 |
+
if isinstance(num_columns, int):
|
| 81 |
+
if (
|
| 82 |
+
num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
|
| 83 |
+
or num_columns > SM100_TMEM_CAPACITY_COLUMNS
|
| 84 |
+
or not (num_columns & (num_columns - 1) == 0)
|
| 85 |
+
):
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}"
|
| 88 |
+
)
|
| 89 |
+
_cute_nvgpu_ir.arch_sm100_alloc_tmem(
|
| 90 |
+
Int32(num_columns).ir_value(loc=loc, ip=ip),
|
| 91 |
+
smem_ptr_to_write_address.value,
|
| 92 |
+
is_two_cta=is_two_cta,
|
| 93 |
+
loc=loc,
|
| 94 |
+
ip=ip,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dsl_user_op
|
| 99 |
+
def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None:
|
| 100 |
+
"""
|
| 101 |
+
Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can
|
| 102 |
+
allocate.
|
| 103 |
+
"""
|
| 104 |
+
_cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit(
|
| 105 |
+
is_two_cta=is_two_cta, loc=loc, ip=ip
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dsl_user_op
|
| 110 |
+
def dealloc_tmem(
|
| 111 |
+
tmem_ptr: Pointer,
|
| 112 |
+
num_columns: Int,
|
| 113 |
+
is_two_cta=None,
|
| 114 |
+
*,
|
| 115 |
+
loc=None,
|
| 116 |
+
ip=None,
|
| 117 |
+
) -> None:
|
| 118 |
+
"""
|
| 119 |
+
Deallocates TMEM using the provided pointer and number of columns.
|
| 120 |
+
|
| 121 |
+
:param tmem_ptr: A pointer to the TMEM allocation to de-allocate
|
| 122 |
+
:type tmem_ptr: Pointer
|
| 123 |
+
:param num_columns: The number of columns in the TMEM allocation
|
| 124 |
+
:type num_columns: Int
|
| 125 |
+
:param is_two_cta: Optional boolean parameter for 2-CTA MMAs
|
| 126 |
+
"""
|
| 127 |
+
if isinstance(num_columns, int):
|
| 128 |
+
if (
|
| 129 |
+
num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS
|
| 130 |
+
or num_columns > SM100_TMEM_CAPACITY_COLUMNS
|
| 131 |
+
or not (num_columns & (num_columns - 1) == 0)
|
| 132 |
+
):
|
| 133 |
+
raise ValueError(
|
| 134 |
+
f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}"
|
| 135 |
+
)
|
| 136 |
+
_cute_nvgpu_ir.arch_sm100_dealloc_tmem(
|
| 137 |
+
tmem_ptr.value,
|
| 138 |
+
Int32(num_columns).ir_value(loc=loc, ip=ip),
|
| 139 |
+
is_two_cta=is_two_cta,
|
| 140 |
+
loc=loc,
|
| 141 |
+
ip=ip,
|
| 142 |
+
)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/math.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .core import TensorSSA
|
| 13 |
+
from .typing import Numeric
|
| 14 |
+
from cutlass._mlir.dialects import math, arith
|
| 15 |
+
|
| 16 |
+
from typing import Callable, Union
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _math_op(func: Callable, fastmath: bool, *args, **kwargs):
|
| 20 |
+
"""Dispatch the function to either a TensorSSA or a Numeric(Float).
|
| 21 |
+
|
| 22 |
+
:param func: The function to dispatch
|
| 23 |
+
:param args: The input tensor or scalar
|
| 24 |
+
:param kwargs: The input tensor or scalar
|
| 25 |
+
"""
|
| 26 |
+
arg_type = type(args[0])
|
| 27 |
+
for arg in args:
|
| 28 |
+
if not isinstance(arg, TensorSSA) and (
|
| 29 |
+
not isinstance(arg, Numeric) or not type(arg).is_float
|
| 30 |
+
):
|
| 31 |
+
raise TypeError(
|
| 32 |
+
f"Expected a TensorSSA or Numeric(Float), but got {type(arg)}"
|
| 33 |
+
)
|
| 34 |
+
if not isinstance(arg, arg_type):
|
| 35 |
+
raise TypeError(
|
| 36 |
+
f"Expected all inputs to be of type {arg_type}, but got {type(arg)}"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none
|
| 40 |
+
if isinstance(args[0], TensorSSA):
|
| 41 |
+
return TensorSSA(
|
| 42 |
+
func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype
|
| 43 |
+
)
|
| 44 |
+
else:
|
| 45 |
+
args = [a.ir_value() for a in args]
|
| 46 |
+
return func(*args, fastmath=fastmath_flag)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def acos(
|
| 50 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 51 |
+
) -> Union[TensorSSA, Numeric]:
|
| 52 |
+
"""Compute element-wise arc cosine of the input tensor.
|
| 53 |
+
|
| 54 |
+
:param a: Input tensor
|
| 55 |
+
:type a: Union[TensorSSA, Numeric]
|
| 56 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 57 |
+
:type fastmath: bool, optional
|
| 58 |
+
:return: Tensor containing the arc cosine of each element in input tensor
|
| 59 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 60 |
+
|
| 61 |
+
Example:
|
| 62 |
+
|
| 63 |
+
.. code-block::
|
| 64 |
+
|
| 65 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 66 |
+
y = x.load() # Load values
|
| 67 |
+
z = acos(y) # Compute arc cosine
|
| 68 |
+
"""
|
| 69 |
+
return _math_op(math.acos, fastmath, a)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def asin(
|
| 73 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 74 |
+
) -> Union[TensorSSA, Numeric]:
|
| 75 |
+
"""Compute element-wise arc sine of the input tensor.
|
| 76 |
+
|
| 77 |
+
:param a: Input tensor
|
| 78 |
+
:type a: Union[TensorSSA, Numeric]
|
| 79 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 80 |
+
:type fastmath: bool, optional
|
| 81 |
+
:return: Tensor containing the arc sine of each element in input tensor
|
| 82 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 83 |
+
|
| 84 |
+
Example:
|
| 85 |
+
|
| 86 |
+
.. code-block::
|
| 87 |
+
|
| 88 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 89 |
+
y = x.load() # Load values
|
| 90 |
+
z = asin(y) # Compute arc sine
|
| 91 |
+
"""
|
| 92 |
+
return _math_op(math.asin, fastmath, a)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def atan(
|
| 96 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 97 |
+
) -> Union[TensorSSA, Numeric]:
|
| 98 |
+
"""Compute element-wise arc tangent of the input tensor.
|
| 99 |
+
|
| 100 |
+
:param a: Input tensor
|
| 101 |
+
:type a: Union[TensorSSA, Numeric]
|
| 102 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 103 |
+
:type fastmath: bool, optional
|
| 104 |
+
:return: Tensor containing the arc tangent of each element in input tensor
|
| 105 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 106 |
+
|
| 107 |
+
Example:
|
| 108 |
+
|
| 109 |
+
.. code-block::
|
| 110 |
+
|
| 111 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 112 |
+
y = x.load() # Load values
|
| 113 |
+
z = atan(y) # Compute arc tangent
|
| 114 |
+
"""
|
| 115 |
+
raise NotImplementedError("atan is not implemented")
|
| 116 |
+
return _math_op(math.atan, fastmath, a)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def atan2(
|
| 120 |
+
a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 121 |
+
) -> Union[TensorSSA, Numeric]:
|
| 122 |
+
"""Compute element-wise arc tangent of two tensors.
|
| 123 |
+
|
| 124 |
+
Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians
|
| 125 |
+
between the positive x-axis and the point given by the coordinates (b, a).
|
| 126 |
+
|
| 127 |
+
:param a: First input tensor (y-coordinates)
|
| 128 |
+
:type a: Union[TensorSSA, Numeric]
|
| 129 |
+
:param b: Second input tensor (x-coordinates)
|
| 130 |
+
:type b: Union[TensorSSA, Numeric]
|
| 131 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 132 |
+
:type fastmath: bool, optional
|
| 133 |
+
:return: Tensor containing the arc tangent of a/b element-wise
|
| 134 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 135 |
+
|
| 136 |
+
Example:
|
| 137 |
+
|
| 138 |
+
.. code-block::
|
| 139 |
+
|
| 140 |
+
y = cute.make_fragment(ptr1, layout).load() # y coordinates
|
| 141 |
+
x = cute.make_fragment(ptr2, layout).load() # x coordinates
|
| 142 |
+
theta = atan2(y, x) # Compute angles
|
| 143 |
+
"""
|
| 144 |
+
return _math_op(math.atan2, fastmath, a, b)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def cos(
|
| 148 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 149 |
+
) -> Union[TensorSSA, Numeric]:
|
| 150 |
+
"""Compute element-wise cosine of the input tensor.
|
| 151 |
+
|
| 152 |
+
:param a: Input tensor (in radians)
|
| 153 |
+
:type a: Union[TensorSSA, Numeric]
|
| 154 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 155 |
+
:type fastmath: bool, optional
|
| 156 |
+
:return: Tensor containing the cosine of each element
|
| 157 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 158 |
+
|
| 159 |
+
Example:
|
| 160 |
+
|
| 161 |
+
.. code-block::
|
| 162 |
+
|
| 163 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 164 |
+
y = x.load() # Load values
|
| 165 |
+
z = cos(y) # Compute cosine
|
| 166 |
+
"""
|
| 167 |
+
return _math_op(math.cos, fastmath, a)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def erf(
|
| 171 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 172 |
+
) -> Union[TensorSSA, Numeric]:
|
| 173 |
+
"""Compute element-wise error function of the input tensor.
|
| 174 |
+
|
| 175 |
+
The error function is defined as:
|
| 176 |
+
erf(x) = 2/√π ∫[0 to x] exp(-t²) dt
|
| 177 |
+
|
| 178 |
+
:param a: Input tensor
|
| 179 |
+
:type a: Union[TensorSSA, Numeric]
|
| 180 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 181 |
+
:type fastmath: bool, optional
|
| 182 |
+
:return: Tensor containing the error function value for each element
|
| 183 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 184 |
+
|
| 185 |
+
Example:
|
| 186 |
+
|
| 187 |
+
.. code-block::
|
| 188 |
+
|
| 189 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 190 |
+
y = x.load() # Load values
|
| 191 |
+
z = erf(y) # Compute error function
|
| 192 |
+
"""
|
| 193 |
+
return _math_op(math.erf, fastmath, a)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def exp(
|
| 197 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 198 |
+
) -> Union[TensorSSA, Numeric]:
|
| 199 |
+
"""Compute element-wise exponential of the input tensor.
|
| 200 |
+
|
| 201 |
+
:param a: Input tensor
|
| 202 |
+
:type a: Union[TensorSSA, Numeric]
|
| 203 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 204 |
+
:type fastmath: bool, optional
|
| 205 |
+
:return: Tensor containing the exponential of each element
|
| 206 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 207 |
+
|
| 208 |
+
Example:
|
| 209 |
+
|
| 210 |
+
.. code-block::
|
| 211 |
+
|
| 212 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 213 |
+
y = x.load() # Load values
|
| 214 |
+
z = exp(y) # Compute exponential
|
| 215 |
+
"""
|
| 216 |
+
return _math_op(math.exp, fastmath, a)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def exp2(
|
| 220 |
+
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
| 221 |
+
) -> Union[TensorSSA, Numeric]:
|
| 222 |
+
"""Compute element-wise base-2 exponential of the input tensor.
|
| 223 |
+
|
| 224 |
+
:param a: Input tensor
|
| 225 |
+
:type a: Union[TensorSSA, Numeric]
|
| 226 |
+
:param fastmath: Enable fast math optimizations, defaults to False
|
| 227 |
+
:type fastmath: bool, optional
|
| 228 |
+
:return: Tensor containing 2 raised to the power of each element
|
| 229 |
+
:rtype: Union[TensorSSA, Numeric]
|
| 230 |
+
|
| 231 |
+
Example:
|
| 232 |
+
|
| 233 |
+
.. code-block::
|
| 234 |
+
|
| 235 |
+
x = cute.make_fragment(layout) # Create tensor
|
| 236 |
+
y = x.load() # Load values
|
| 237 |
+
z = exp2(y) # Compute 2^x
|
| 238 |
+
"""
|
| 239 |
+
return _math_op(math.exp2, fastmath, a)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def log(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise natural (base-e) logarithm.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Natural logarithm of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = log(y)                      # Compute natural logarithm
    """
    # Delegate to the shared dispatcher with the log primitive.
    return _math_op(math.log, fastmath, a)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def log2(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise base-2 logarithm.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Base-2 logarithm of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = log2(y)                     # Compute log base 2
    """
    # Delegate to the shared dispatcher with the log2 primitive.
    return _math_op(math.log2, fastmath, a)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def log10(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise base-10 logarithm.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Base-10 logarithm of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = log10(y)                    # Compute log base 10
    """
    # Delegate to the shared dispatcher with the log10 primitive.
    return _math_op(math.log10, fastmath, a)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def rsqrt(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise reciprocal square root, 1/sqrt(x).

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Reciprocal square root of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = rsqrt(y)                    # Compute 1/sqrt(x)
    """
    # Delegate to the shared dispatcher with the rsqrt primitive.
    return _math_op(math.rsqrt, fastmath, a)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def sin(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise sine.

    :param a: Input tensor or scalar, interpreted as radians
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Sine of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = sin(y)                      # Compute sine
    """
    # Delegate to the shared dispatcher with the sin primitive.
    return _math_op(math.sin, fastmath, a)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def sqrt(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise square root.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Square root of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = sqrt(y)                     # Compute square root
    """
    # Delegate to the shared dispatcher with the sqrt primitive.
    return _math_op(math.sqrt, fastmath, a)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def tan(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise tangent.

    :param a: Input tensor or scalar, interpreted as radians
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Tangent of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = tan(y)                      # Compute tangent
    """
    # Delegate to the shared dispatcher with the tan primitive.
    return _math_op(math.tan, fastmath, a)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def tanh(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise hyperbolic tangent.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Allow fast approximate math, defaults to False
    :type fastmath: bool, optional
    :return: Hyperbolic tangent of each input element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()                    # Load values
        z = tanh(y)                     # Compute hyperbolic tangent
    """
    # Delegate to the shared dispatcher with the tanh primitive.
    return _math_op(math.tanh, fastmath, a)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# Public element-wise math API of this module (kept alphabetically sorted);
# also consumed by the documentation generator.
__all__ = [
    "acos",
    "asin",
    "atan",
    "atan2",
    "cos",
    "erf",
    "exp",
    "exp2",
    "log",
    "log10",
    "log2",
    "rsqrt",
    "sin",
    "sqrt",
    "tan",
    "tanh",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from . import warp
|
| 13 |
+
from . import cpasync
|
| 14 |
+
from . import warpgroup
|
| 15 |
+
from . import tcgen05
|
| 16 |
+
|
| 17 |
+
from .common import *
|
| 18 |
+
from .helpers import *
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# __all__ is required here for documentation generation; it lists the public
# names re-exported by this package (from common.py).
__all__ = [
    "OpError",
    "MmaUniversalOp",
    "CopyUniversalOp",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
import enum
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Type, Optional
|
| 14 |
+
|
| 15 |
+
from cutlass.cutlass_dsl import DSLBaseError
|
| 16 |
+
|
| 17 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 18 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 19 |
+
from cutlass._mlir import ir
|
| 20 |
+
|
| 21 |
+
from .. import core
|
| 22 |
+
from ..typing import Float16, Float32, Float64, Numeric
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OpError(DSLBaseError):
    """
    An exception class for Op construction errors.
    """

    def __init__(
        self, op: core.Op, message: str, suggestion: Optional[str] = None
    ) -> None:
        # Fall back to a generic hint when the caller provides no suggestion.
        effective_suggestion = (
            "Check your Op construction code" if suggestion is None else suggestion
        )
        super().__init__(
            message,
            error_code=f"{op.__class__.__name__} error",
            suggestion=effective_suggestion,
        )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
####################################################################################################
|
| 44 |
+
#
|
| 45 |
+
# MMA Ops and Traits
|
| 46 |
+
#
|
| 47 |
+
####################################################################################################
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass(frozen=True)
class MmaUniversalOp(core.MmaOp):
    """
    The universal MMA Operation.

    This Operation currently expects the A/B operands as well as the accumulator to share the same
    data types.

    :param abacc_dtype: The data type for the A/B operands and the accumulator
    :type abacc_dtype: Type[Numeric]
    """

    abacc_dtype: Type[Numeric]

    def __post_init__(self) -> None:
        # The FMA-based atom is only defined for the native floating-point types.
        if self.abacc_dtype not in [Float16, Float32, Float64]:
            # Fix: was an f-string with no placeholders (ruff F541).
            raise OpError(
                self,
                "expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64",
            )

    def __str__(self) -> str:
        return (
            "universal MMA Operation using FMA"
            f"\n A/B/Accumulator data type = {self.abacc_dtype}"
        )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait":
        """Build the IR Atom for this Op and wrap it in a trait.

        :return: The trait wrapping the constructed ``cute.atom`` value
        :rtype: MmaUniversalTrait
        """
        # 1x1x1 shape: the universal FMA atom computes one scalar FMA per execution.
        # Fix: was an f-string with no placeholders (ruff F541).
        shape_mnk_attr = ir.Attribute.parse('#cute.shape<"(1,1,1)">')
        atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get(
            shape_mnk_attr,
            self.abacc_dtype.mlir_type,
            self.abacc_dtype.mlir_type,
            self.abacc_dtype.mlir_type,
        )
        return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))

    def _verify_fragment_A(self, input, *, loc=None, ip=None):
        # The universal FMA atom places no constraint on the A fragment.
        pass

    def _verify_fragment_B(self, input, *, loc=None, ip=None):
        # The universal FMA atom places no constraint on the B fragment.
        pass
|
| 92 |
+
|
| 93 |
+
class MmaUniversalTrait(core.Trait):
    # Trait for the universal FMA-based MMA Atom; no extra state beyond the base Trait.
    pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
####################################################################################################
|
| 98 |
+
#
|
| 99 |
+
# Copy Ops and Traits
|
| 100 |
+
#
|
| 101 |
+
####################################################################################################
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class MemoryOrder(enum.Enum):
    """
    An enumeration of memory-ordering semantics for copy operations.

    Each member wraps the corresponding ``_cute_ir.MemOrderKind`` IR value.
    """

    WEAK = _cute_ir.MemOrderKind.WEAK
    RELAXED = _cute_ir.MemOrderKind.RELAXED
    ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE
    RELEASE = _cute_ir.MemOrderKind.RELEASE
    ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL
    SC = _cute_ir.MemOrderKind.SC
    MMIO = _cute_ir.MemOrderKind.MMIO
    CONSTANT = _cute_ir.MemOrderKind.CONSTANT
    VOLATILE = _cute_ir.MemOrderKind.VOLATILE

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MemOrderKind:
        # Unwrap to the underlying IR enum value.
        return self.value
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class MemoryScope(enum.Enum):
    """
    An enumeration of memory scopes for copy operations.

    Each member wraps the corresponding ``_cute_ir.MemScopeKind`` IR value.
    """

    CTA = _cute_ir.MemScopeKind.CTA
    CLUSTER = _cute_ir.MemScopeKind.CLUSTER
    GPU = _cute_ir.MemScopeKind.GPU
    SYS = _cute_ir.MemScopeKind.SYS

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MemScopeKind:
        # Unwrap to the underlying IR enum value.
        return self.value
|
| 139 |
+
|
| 140 |
+
@dataclass(frozen=True)
class CopyUniversalOp(core.CopyOp):
    """
    The universal Copy Operation.

    When creating a Copy Atom out of this operation, the expected usage pattern is

    .. code-block:: python

        op = cute.nvgpu.CopyUniversalOp()
        atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64)

    - ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \
      or the destination TV Layout) in unit of tensor elements and is used for partitioning by \
      ``TiledCopy`` for example
    - ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \
      execution. This can be larger than the width of the above data type. When not provided, \
      the compiler will do a best effort at auto-vectorizing.
    """

    def __str__(self) -> str:
        return "universal Copy Operation"

    def _make_trait(
        self,
        copy_internal_type: Type[Numeric],
        *,
        loc=None,
        ip=None,
        **kwargs,
    ) -> "CopyUniversalTrait":
        # 0 means "not specified": the compiler auto-vectorizes in that case.
        bits_per_copy = kwargs.get("num_bits_per_copy", 0)
        order = kwargs.get("memory_order", MemoryOrder.WEAK)
        scope = kwargs.get("memory_scope", MemoryScope.CTA)
        if not isinstance(bits_per_copy, int) or (bits_per_copy < 0):
            raise ValueError(
                "expects a 'num_bits_per_copy' kw argument of type int that is non-negative "
                f"when creating a copy Atom for {self.__class__.__name__}"
            )
        atom_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
            copy_internal_type.mlir_type,
            bits_per_copy,
            order._to_ir(),
            scope._to_ir(),
        )
        return CopyUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class CopyUniversalTrait(core.Trait):
    # Trait for the universal SIMT sync-copy Atom; no extra state beyond the base Trait.
    pass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .copy import *
|
| 13 |
+
from .helpers import *
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# __all__ is required here for documentation generation; it lists the public
# cp.async / TMA names re-exported by this package.
__all__ = [
    #
    # copy.py
    #
    "LoadCacheMode",
    "CopyG2SOp",
    "CopyBulkTensorTileG2SOp",
    "CopyBulkTensorTileG2SMulticastOp",
    "CopyBulkTensorTileS2GOp",
    "CopyReduceBulkTensorTileS2GOp",
    #
    # helpers.py
    #
    "make_tiled_tma_atom",
    "tma_partition",
    "create_tma_multicast_mask",
    "prefetch_descriptor",
    "copy_tensormap",
    "update_tma_descriptor",
    "fence_tma_desc_acquire",
    "cp_fence_tma_desc_release",
    "fence_tma_desc_release",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import enum
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Optional, Type
|
| 15 |
+
|
| 16 |
+
from cutlass.cutlass_dsl import CuTeDSL, t
|
| 17 |
+
|
| 18 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 19 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 20 |
+
from cutlass._mlir import ir
|
| 21 |
+
|
| 22 |
+
from ...core import CopyOp, Trait, ReductionOp
|
| 23 |
+
from ...typing import Int16, Pointer, Integer, Numeric
|
| 24 |
+
from ..common import OpError
|
| 25 |
+
from ..tcgen05.mma import CtaGroup
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
####################################################################################################
|
| 29 |
+
#
|
| 30 |
+
# Aynchronous copies
|
| 31 |
+
#
|
| 32 |
+
####################################################################################################
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class LoadCacheMode(enum.Enum):
    """
    An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction.

    Each member wraps the corresponding ``_cute_nvgpu_ir.LoadCacheMode`` IR value.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators>`__.
    """

    ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always
    GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_
    STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming
    LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use
    NONE = _cute_nvgpu_ir.LoadCacheMode.none

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode:
        # Unwrap to the underlying IR enum value.
        return self.value
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass(frozen=True)
class CopyG2SOp(CopyOp):
    """
    Non-bulk asynchronous GMEM to SMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-non-bulk-copy>`__.

    :param cache_mode: Cache operator applied to the load, defaults to LoadCacheMode.ALWAYS
    :type cache_mode: LoadCacheMode
    """

    cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS

    def __str__(self) -> str:
        res = "cp.async GMEM -> SMEM copy Operation"
        if self.cache_mode != LoadCacheMode.ALWAYS:
            res += f"\n with cache mode = {self.cache_mode}"
        return res

    def _make_trait(
        self,
        copy_internal_type: Type[t.Numeric],
        *,
        loc=None,
        ip=None,
        **kwargs,
    ) -> "CopyG2STrait":
        """Build the IR Atom for this Op and wrap it in a trait.

        :param copy_internal_type: Element type used for partitioning
        :raises OpError: When ``cache_mode`` is not a LoadCacheMode instance
        :raises ValueError: When the required ``num_bits_per_copy`` kw argument is
            missing or not a positive int
        """
        num_bits_per_copy = kwargs.get("num_bits_per_copy", None)
        # Verify that the user provided enum values.
        # Fix: this validation block was duplicated verbatim; it now appears once.
        if not isinstance(self.cache_mode, LoadCacheMode):
            raise OpError(
                self,
                "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance",
            )
        # cp.async requires an explicit, positive copy width (no auto-vectorization).
        if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0):
            raise ValueError(
                "expects a 'num_bits_per_copy' kw argument of type int that is positive "
                f"when creating a copy Atom for {self.__class__.__name__}"
            )
        ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get(
            copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy
        )
        return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class CopyG2STrait(Trait):
    # Trait for the non-bulk cp.async GMEM->SMEM Atom; no extra state beyond the base Trait.
    pass
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
####################################################################################################
|
| 111 |
+
#
|
| 112 |
+
# Bulk tensor copies a.k.a TMA copies
|
| 113 |
+
#
|
| 114 |
+
####################################################################################################
|
| 115 |
+
|
| 116 |
+
# Names of the runtime-settable fields on TMA copy Atoms; used to build the
# "#cute_nvgpu.atom_copy_field_*" attributes passed to atom_set_value.
TMA_MBAR_PTR_FIELD_NAME = "tma_bar"
TMA_MASK_FIELD_NAME = "mcast_mask"
TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr"
|
| 120 |
+
#
|
| 121 |
+
# TMA GMEM -> SMEM copies
|
| 122 |
+
#
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass(frozen=True)
class CopyBulkTensorTileG2SOp(CopyOp):
    """
    Bulk tensor asynchronous GMEM to SMEM Copy Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
    This Operation uses TMA in the ``.tile`` mode.

    :param cta_group: Number of CTAs cooperating on the load, defaults to CtaGroup.ONE
    :type cta_group: CtaGroup
    """

    cta_group: CtaGroup = CtaGroup.ONE

    # Architectures on which this TMA load is available.
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self, "expects the 'cta_group' parameter to be a CtaGroup instance"
            )
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # CTA group 2 requires tcgen05 (sm_100+); reject it on sm_90 targets.
        if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
            # Fix: error message previously read "... is not and is not compatible ...".
            raise OpError(
                self,
                f"CTA group of 2 is tcgen05-specific and is not compatible with {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        res = "cp.async GMEM -> SMEM bulk tensor copy Operation"
        if self.cta_group == CtaGroup.TWO:
            # Fix: was an f-string with no placeholders (ruff F541).
            res += "\n CTA group = 2"
        return res

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileG2SNonExecTrait":
        # TMA Atoms carry GMEM tensor state and must be built via the dedicated helper.
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
        # Map the CTA group to the corresponding IR enum variant.
        if self.cta_group == CtaGroup.ONE:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90
        elif self.cta_group == CtaGroup.TWO:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm
        else:
            assert False, "unrecognized self.cta_group"
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class CopyBulkTensorTileG2SNonExecTrait(Trait):
    # We allow kw args to be dropped so that the user can write common code for non-multicast
    # and multicast loads.
    def unpack(
        self,
        *,
        loc=None,
        ip=None,
        tma_bar_ptr: Optional[Pointer] = None,
        tma_desc_ptr: Optional[Pointer] = None,
        **kwargs,
    ):
        """
        Custom implementation of unpack for non-executable TMAs.

        The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when
        using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error.

        :param tma_bar_ptr: Pointer to the mbarrier used to track completion (required)
        :param tma_desc_ptr: Optional pointer to a TMA descriptor overriding the one
            embedded in the Atom
        :raises ValueError: When ``tma_bar_ptr`` is not a Pointer
        :return: The executable atom IR value with its fields populated
        """
        if not isinstance(tma_bar_ptr, Pointer):
            raise ValueError(
                "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
            )
        # Promote the non-executable atom into an executable one, then set its fields.
        exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        # Mandatory field: the mbarrier pointer.
        attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>"
        attr = ir.Attribute.parse(attr_str)
        exec_value = _cute_nvgpu_ir.atom_set_value(
            exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip
        )
        # Optional field: only set the descriptor pointer when one is provided.
        if isinstance(tma_desc_ptr, Pointer):
            attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
            attr = ir.Attribute.parse(attr_str)
            exec_value = _cute_nvgpu_ir.atom_set_value(
                exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_value
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
#
|
| 223 |
+
# TMA GMEM -> SMEM multicast copies
|
| 224 |
+
#
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@dataclass(frozen=True)
class CopyBulkTensorTileG2SMulticastOp(CopyOp):
    """
    Bulk tensor asynchronous multicast GMEM to SMEM Copy Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
    This Operation uses TMA in the ``.tile`` mode.
    """

    # CTA group of 2 is a tcgen05 (sm_100 family) feature; see __post_init__.
    cta_group: CtaGroup = CtaGroup.ONE

    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self):
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self, "expects the 'cta_group' parameter to be a CtaGroup instance"
            )
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90":
            # Fixed: message previously contained a duplicated "and is not".
            raise OpError(
                self,
                f"CTA group of 2 is tcgen05-specific and is not compatible with {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation"
        if self.cta_group == CtaGroup.TWO:
            res += f"\n CTA group = 2"
        return res

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait":
        # TMA atoms are never built through the generic Atom construction path.
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum:
        # Select the IR enumerator matching the CTA group configuration.
        if self.cta_group == CtaGroup.ONE:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast
        elif self.cta_group == CtaGroup.TWO:
            return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast
        else:
            assert False, "unrecognized self.cta_group"
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait):
    def unpack(
        self,
        *,
        loc=None,
        ip=None,
        tma_bar_ptr: Optional[Pointer] = None,
        mcast_mask=None,
        tma_desc_ptr=None,
    ):
        """
        Custom implementation of unpack for non-executable TMAs.

        The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be
        provided when using `cute.copy`.
        """
        if not isinstance(tma_bar_ptr, Pointer):
            raise ValueError(
                "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument"
            )
        if not isinstance(mcast_mask, Integer):
            raise ValueError(
                "expects a multicast mask to be provided via the mcast_mask kw argument"
            )
        # Build the executable atom, then attach the mbarrier pointer, the
        # multicast mask, and (optionally) the descriptor pointer, in order.
        exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        bar_attr = ir.Attribute.parse("#cute_nvgpu.atom_copy_field_tmaload<tma_bar>")
        exec_atom = _cute_nvgpu_ir.atom_set_value(
            exec_atom, bar_attr, tma_bar_ptr.value, loc=loc, ip=ip
        )
        mask_attr = ir.Attribute.parse(
            "#cute_nvgpu.atom_copy_field_tmaload<mcast_mask>"
        )
        exec_atom = _cute_nvgpu_ir.atom_set_value(
            exec_atom,
            mask_attr,
            Int16(mcast_mask).ir_value(loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        if isinstance(tma_desc_ptr, Pointer):
            desc_attr = ir.Attribute.parse(
                f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>"
            )
            exec_atom = _cute_nvgpu_ir.atom_set_value(
                exec_atom, desc_attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_atom
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
#
|
| 332 |
+
# TMA SMEM -> GMEM copies
|
| 333 |
+
#
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@dataclass(frozen=True)
class CopyBulkTensorTileS2GOp(CopyOp):
    """
    Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
    This Operation uses TMA in the ``.tile`` mode.
    """

    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self):
        # Arch verification: reject construction on architectures without TMA.
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        return "cp.async SMEM -> GMEM bulk tensor copy Operation"

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyBulkTensorTileS2GTrait":
        # TMA atoms are never built through the generic Atom construction path.
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class CopyBulkTensorTileS2GTrait(Trait):
    def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
        """
        Custom implementation of unpack for non-executable TMAs.
        """
        # Stores carry no mbarrier; only the optional descriptor pointer is set.
        exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        if isinstance(tma_desc_ptr, Pointer):
            desc_attr = ir.Attribute.parse(
                f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>"
            )
            exec_atom = _cute_nvgpu_ir.atom_set_value(
                exec_atom, desc_attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_atom
|
| 388 |
+
|
| 389 |
+
@dataclass(frozen=True)
class CopyReduceBulkTensorTileS2GOp(CopyOp):
    """
    Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk>`__.
    This Operation uses TMA in the ``.tile`` mode.
    """

    reduction_kind: ReductionOp = ReductionOp.ADD

    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    # Fixed: this hook was previously misspelled `__post__init__`, so the
    # dataclass machinery never invoked it and arch verification was skipped.
    def __post_init__(self):
        # Arch verification
        # Fixed: `CuTeDSL.__get_dsl()` would be name-mangled inside this class
        # and raise AttributeError; sibling Ops use `_get_dsl()`.
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        return "cp.async SMEM -> GMEM bulk tensor reduction Operation"

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyReduceBulkTensorTileS2GTrait":
        # TMA atoms are never built through the generic Atom construction path.
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind:
        # Map the DSL-level reduction operator onto the IR enumerator.
        kind = {
            ReductionOp.ADD: _cute_nvgpu_ir.ReductionKind.ADD,
            ReductionOp.MIN: _cute_nvgpu_ir.ReductionKind.MIN,
            ReductionOp.MAX: _cute_nvgpu_ir.ReductionKind.MAX,
            ReductionOp.INC: _cute_nvgpu_ir.ReductionKind.INC,
            ReductionOp.DEC: _cute_nvgpu_ir.ReductionKind.DEC,
            ReductionOp.AND: _cute_nvgpu_ir.ReductionKind.AND,
            ReductionOp.OR: _cute_nvgpu_ir.ReductionKind.OR,
            ReductionOp.XOR: _cute_nvgpu_ir.ReductionKind.XOR,
        }.get(self.reduction_kind)
        assert kind is not None, "unrecognized self.reduction_kind"
        return kind
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class CopyReduceBulkTensorTileS2GTrait(Trait):
    def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
        """
        Custom implementation of unpack for non-executable TMAs.
        """
        # Reductions carry no mbarrier; only the optional descriptor pointer is set.
        exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        if isinstance(tma_desc_ptr, Pointer):
            desc_attr = ir.Attribute.parse(
                f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>"
            )
            exec_atom = _cute_nvgpu_ir.atom_set_value(
                exec_atom, desc_attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_atom
|
| 463 |
+
|
| 464 |
+
# Names exported by `from ... import *`; the *Trait classes are not listed here.
__all__ = [
    "LoadCacheMode",
    "CopyG2SOp",
    "CopyBulkTensorTileG2SOp",
    "CopyBulkTensorTileG2SMulticastOp",
    "CopyBulkTensorTileS2GOp",
    "CopyReduceBulkTensorTileS2GOp",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from typing import Optional, Tuple, Type, Union
|
| 13 |
+
|
| 14 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 15 |
+
|
| 16 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
+
from cutlass._mlir.dialects import llvm
|
| 18 |
+
|
| 19 |
+
from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta
|
| 20 |
+
from ... import core
|
| 21 |
+
from .copy import (
|
| 22 |
+
CopyBulkTensorTileG2SOp,
|
| 23 |
+
CopyBulkTensorTileG2SMulticastOp,
|
| 24 |
+
CopyBulkTensorTileS2GOp,
|
| 25 |
+
CopyReduceBulkTensorTileS2GOp,
|
| 26 |
+
CopyBulkTensorTileG2SNonExecTrait,
|
| 27 |
+
CopyBulkTensorTileG2SMulticastNonExecTrait,
|
| 28 |
+
CopyBulkTensorTileS2GTrait,
|
| 29 |
+
CopyReduceBulkTensorTileS2GTrait,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dsl_user_op
def make_tiled_tma_atom(
    op: Union[
        CopyBulkTensorTileG2SOp,
        CopyBulkTensorTileG2SMulticastOp,
        CopyBulkTensorTileS2GOp,
        CopyReduceBulkTensorTileS2GOp,
    ],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    cta_tiler: Tiler,
    num_multicast: int = 1,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM
    buffer with the given Layout.

    Given

    - a GMEM tensor
    - a SMEM layout
    - a CTA-level Tiler

    this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
    "TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided
    layout and consistent with the provided Tiler.

    This function returns two results:

    1. the Copy Atom
    2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates \
       that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the \
       associated layout can output coordinates. Otherwise, TMA tensors can be partitioned \
       similarly to any other CuTe tensors using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp]
    :param gmem_tensor: The GMEM tensor involved in the Copy
    :type gmem_tensor: Tensor
    :param smem_layout: The SMEM layout to construct the Copy Atom for
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param cta_tiler: The CTA Tiler to use
    :type cta_tiler: Tiler
    :param num_multicast: The multicast factor
    :type num_multicast: int
    :param internal_type: An optional parameter for the internal data type to use when the actual data type is not supported by the TMA unit
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type

    # Compose the identity layout of the GMEM tensor with the CTA tiler to get
    # the CTA-to-value map consumed by the IR builders below.
    cta_v_map = core.composition(
        core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip),
        cta_tiler,
        loc=loc,
        ip=ip,
    )

    # G2S loads (multicast or not) share the same IR builder; they differ only
    # in the num_multicast validation and the Trait wrapping the result.
    # Note: the two G2S Op classes are siblings, so dispatch order is irrelevant.
    is_g2s = isinstance(op, CopyBulkTensorTileG2SOp)
    is_g2s_mcast = isinstance(op, CopyBulkTensorTileG2SMulticastOp)
    if is_g2s or is_g2s_mcast:
        if is_g2s:
            if num_multicast != 1:
                raise ValueError(
                    f"expects num_multicast to be 1 for non multicast G2S copies, "
                    f"but got {num_multicast}"
                )
            trait_cls = CopyBulkTensorTileG2SNonExecTrait
        else:
            if num_multicast < 1:
                raise ValueError(
                    f"expects num_multicast to be >= 1 for multicast G2S copies, "
                    f"but got {num_multicast}"
                )
            trait_cls = CopyBulkTensorTileG2SMulticastNonExecTrait
        # res[0] = non-executable atom value, res[1] = associated TMA tensor
        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            num_multicast=num_multicast,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, trait_cls(res[0])), res[1]
    if isinstance(op, CopyBulkTensorTileS2GOp):
        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1]
    if isinstance(op, CopyReduceBulkTensorTileS2GOp):
        res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1]
    raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@dsl_user_op
def tma_partition(
    atom: core.CopyAtom,
    cta_coord: Coord,
    cta_layout: Layout,
    smem_tensor: Tensor,
    gmem_tensor: Tensor,
    *,
    loc=None,
    ip=None,
) -> Tuple[Tensor, Tensor]:
    """
    Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom.
    """
    # Pack the Python-level coordinate into an IR value before handing off.
    packed_coord = core._pack_coord(cta_coord, loc=loc, ip=ip)
    src, dst = _cute_nvgpu_ir.atom_tma_partition(
        atom._trait.value,
        cta_coord=packed_coord,
        cta_layout=cta_layout,
        smem_tensor=smem_tensor.value,
        gmem_tensor=gmem_tensor.value,
        loc=loc,
        ip=ip,
    )
    return src, dst
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@dsl_user_op
def create_tma_multicast_mask(
    cta_layout_vmnk: Layout,
    cta_coord_vmnk: Coord,
    mcast_mode: int,
    *,
    loc=None,
    ip=None,
) -> Int16:
    """
    Computes a multicast mask for a TMA load Copy.

    :param cta_layout_vmnk: The VMNK layout of the cluster
    :type cta_layout_vmnk: Layout
    :param cta_coord_vmnk: The VMNK coordinate of the current CTA
    :type cta_coord_vmnk: Coord
    :param mcast_mode: The tensor mode in which to multicast
    :type mcast_mode: int
    :return: The resulting mask
    :rtype: Int16
    """
    # Both the layout and the coordinate must be VMNK, i.e. rank 4.
    for arg_name, arg in (
        ("cta_layout_vmnk", cta_layout_vmnk),
        ("cta_coord_vmnk", cta_coord_vmnk),
    ):
        if core.rank(arg) != 4:
            raise ValueError(
                f"{arg_name} must be rank 4, but got {core.pretty_str(arg)}"
            )
    return core.make_layout_image_mask(
        cta_layout_vmnk, cta_coord_vmnk, mcast_mode, loc=loc, ip=ip
    )
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@dsl_user_op
def prefetch_descriptor(tma_atom: core.CopyAtom, *, loc=None, ip=None) -> None:
    """
    Prefetches the TMA descriptor associated with the TMA Atom.
    """
    # Single IR op; the descriptor lives in the atom's trait value.
    _cute_nvgpu_ir.prefetch_tma_desc(tma_atom._trait.value, loc=loc, ip=ip)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@dsl_user_op
def copy_tensormap(
    tma_atom: core.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None
) -> None:
    """
    Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the
    provided pointer.

    :param tma_atom: The TMA Copy Atom
    :type tma_atom: CopyAtom
    :param tensormap_ptr: The pointer to the memory location to copy the tensormap to
    :type tensormap_ptr: Pointer
    """
    _cute_nvgpu_ir.copy_tma_desc(
        tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip
    )
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@dsl_user_op
def update_tma_descriptor(
    tma_atom: core.CopyAtom,
    gmem_tensor: Tensor,
    tma_desc_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Updates the TMA descriptor in the memory location pointed to by the provided pointer using
    information from a TMA Copy Atom and the provided GMEM tensor.

    Specifically, the following fields of the TMA descriptor will be updated:

    1. the GMEM tensor base address
    2. the GMEM tensor shape
    3. the GMEM tensor stride

    Other fields of the TMA descriptor are left unchanged.

    :param tma_atom: The TMA Copy Atom
    :type tma_atom: CopyAtom
    :param gmem_tensor: The GMEM tensor
    :type gmem_tensor: Tensor
    :param tma_desc_ptr: The pointer to the memory location of the descriptor to update
    :type tma_desc_ptr: Pointer
    """
    _cute_nvgpu_ir.update_tma_desc(
        tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip
    )
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@dsl_user_op
def fence_tma_desc_acquire(
    tma_desc_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Emits an acquire fence on the tensormap proxy for the given descriptor address.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    # The inline asm consumes the descriptor address as a 64-bit integer ("l").
    desc_addr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [desc_addr_i64],
        "fence.proxy.tensormap::generic.acquire.gpu [$0], 128;",
        "l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dsl_user_op
def cp_fence_tma_desc_release(
    tma_desc_global_ptr: Pointer,
    tma_desc_shared_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Copies a tensormap from shared to global memory with release-fence semantics.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy>`__.
    """
    # Constraint string "l,r": global address is 64-bit, shared address 32-bit.
    global_addr = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value()
    shared_addr = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [global_addr, shared_addr],
        "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [$0], [$1], 128;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@dsl_user_op
def fence_tma_desc_release(*, loc=None, ip=None) -> None:
    """
    Emits a GPU-scope release fence on the tensormap proxy.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`__.
    """
    # No operands: the fence takes no address argument.
    llvm.inline_asm(
        None,
        [],
        "fence.proxy.tensormap::generic.release.gpu;",
        "",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from typing import Optional, Tuple, Type, Union
|
| 13 |
+
|
| 14 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 15 |
+
|
| 16 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
+
|
| 18 |
+
from .. import core
|
| 19 |
+
from ..typing import Shape, Layout, Tensor, Numeric, NumericMeta
|
| 20 |
+
from ...impl_utils import check_type_in
|
| 21 |
+
from .cpasync.copy import (
|
| 22 |
+
CopyBulkTensorTileG2SOp,
|
| 23 |
+
CopyBulkTensorTileG2SNonExecTrait,
|
| 24 |
+
CopyBulkTensorTileG2SMulticastOp,
|
| 25 |
+
CopyBulkTensorTileG2SMulticastNonExecTrait,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
####################################################################################################
|
| 30 |
+
#
|
| 31 |
+
# TMA creation helpers for tcgen05 MMAs
|
| 32 |
+
#
|
| 33 |
+
####################################################################################################
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dsl_user_op
|
| 37 |
+
def make_tiled_tma_atom_A(
|
| 38 |
+
op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
| 39 |
+
gmem_tensor: Tensor,
|
| 40 |
+
smem_layout: Union[Layout, core.ComposedLayout],
|
| 41 |
+
mma_tiler_mnk: Shape,
|
| 42 |
+
tiled_mma: core.TiledMma,
|
| 43 |
+
cluster_shape_vmnk: Shape,
|
| 44 |
+
*,
|
| 45 |
+
internal_type: Optional[Type[Numeric]] = None,
|
| 46 |
+
loc=None,
|
| 47 |
+
ip=None,
|
| 48 |
+
) -> Tuple[core.CopyAtom, Tensor]:
|
| 49 |
+
"""
|
| 50 |
+
Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation
|
| 51 |
+
accounting for the MK projections of the TiledMMA for A tensor loads.
|
| 52 |
+
|
| 53 |
+
Given
|
| 54 |
+
|
| 55 |
+
- a GMEM tensor
|
| 56 |
+
- a SMEM layout
|
| 57 |
+
- a MMA Tiler
|
| 58 |
+
- a TiledMma
|
| 59 |
+
- a Cluster-level shape
|
| 60 |
+
|
| 61 |
+
this function figures out the bulk tensor asynchronous copy instruction to use with the maximum
|
| 62 |
+
"TMA vector length" to copy tiles of the GMEM tensor to an SMEM buffer with the provided
|
| 63 |
+
layout and consistent with the provided Tiler & tiled_mma (considering the M-mode & K-mode).
|
| 64 |
+
The Cluster-level shape is used to determine the multicast factor across the N-mode for A tensor loads.
|
| 65 |
+
|
| 66 |
+
This function returns two results:
|
| 67 |
+
|
| 68 |
+
1. the Copy Atom
|
| 69 |
+
2. the so-called TMA tensor used to map logical coordinates of the GMEM tensor to coordinates
|
| 70 |
+
that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the
|
| 71 |
+
associated layout can output coordinates. Otherwise, TMA tensors can be partitioned
|
| 72 |
+
similarly to any other CuTe tensors using the algebra.
|
| 73 |
+
|
| 74 |
+
:param op: The Copy Operation to construct an Atom for
|
| 75 |
+
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
|
| 76 |
+
:param gmem_tensor: The GMEM tensor to be loaded by this copy atom
|
| 77 |
+
:type gmem_tensor: Tensor
|
| 78 |
+
:param smem_layout: Shared memory layout to load the tensor into (PDSL)
|
| 79 |
+
:type smem_layout: Union[Layout, core.ComposedLayout]
|
| 80 |
+
:param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
|
| 81 |
+
:type mma_tiler_mnk: Shape
|
| 82 |
+
:param tiled_mma: The TiledMMA that will consume the load as operands
|
| 83 |
+
:type tiled_mma: core.TiledMma
|
| 84 |
+
:param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
|
| 85 |
+
:type cluster_shape_vmnk: Shape
|
| 86 |
+
:param internal_type: An optional parameter for the internal data type to when element
|
| 87 |
+
type does not match the copy type
|
| 88 |
+
:type internal_type: Type[Numeric]
|
| 89 |
+
:return: A copy atom for this operation and the associated TMA coord tensor
|
| 90 |
+
:rtype: Tuple[core.CopyAtom, Tensor]
|
| 91 |
+
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
if internal_type is not None:
|
| 95 |
+
if not isinstance(internal_type, NumericMeta):
|
| 96 |
+
raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
|
| 97 |
+
internal_type = internal_type.mlir_type
|
| 98 |
+
check_type_in(
|
| 99 |
+
op,
|
| 100 |
+
[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
|
| 101 |
+
"op",
|
| 102 |
+
"make_tiled_tma_atom_A",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
|
| 106 |
+
mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:])
|
| 107 |
+
g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip)
|
| 108 |
+
cta_v_map = tiled_mma._thrfrg_A(g_tile)
|
| 109 |
+
cta_v_map = core.get(cta_v_map, mode=[1])
|
| 110 |
+
cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile)))
|
| 111 |
+
|
| 112 |
+
if isinstance(op, CopyBulkTensorTileG2SOp):
|
| 113 |
+
num_multicast = 1
|
| 114 |
+
else:
|
| 115 |
+
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
|
| 116 |
+
# multicast across the N-mode since those would share the same tile of A
|
| 117 |
+
num_multicast = core.size(cluster_shape_vmnk, mode=[2])
|
| 118 |
+
|
| 119 |
+
# res[0] = the IR Value for the non-executable atom instance
|
| 120 |
+
# res[1] = the IR Value for the associated TMA tensor
|
| 121 |
+
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
|
| 122 |
+
gmem_tensor.value,
|
| 123 |
+
smem_layout,
|
| 124 |
+
cta_v_map,
|
| 125 |
+
op._to_ir(),
|
| 126 |
+
num_multicast=num_multicast,
|
| 127 |
+
internal_type=internal_type,
|
| 128 |
+
loc=loc,
|
| 129 |
+
ip=ip,
|
| 130 |
+
)
|
| 131 |
+
if isinstance(op, CopyBulkTensorTileG2SOp):
|
| 132 |
+
return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1]
|
| 133 |
+
else:
|
| 134 |
+
assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
|
| 135 |
+
return (
|
| 136 |
+
core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])),
|
| 137 |
+
res[1],
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dsl_user_op
def make_tiled_tma_atom_B(
    op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    mma_tiler_mnk: Shape,
    tiled_mma: core.TiledMma,
    cluster_shape_vmnk: Shape,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation
    accounting for the NK projections of the TiledMMA for B tensor loads.

    Given a GMEM tensor, an SMEM layout, an MMA Tiler, a TiledMma, and a
    Cluster-level shape, this function selects the bulk tensor asynchronous copy
    instruction with the maximum "TMA vector length" to copy tiles of the GMEM
    tensor to an SMEM buffer with the provided layout, consistent with the
    provided Tiler & tiled_mma (considering the N-mode & K-mode). The
    Cluster-level shape determines the multicast factor across the M-mode for
    B tensor loads.

    Two results are returned:

    1. the Copy Atom;
    2. the so-called TMA tensor used to map logical coordinates of the GMEM
       tensor to coordinates that the TMA unit can consume. TMA tensors carry
       basis stride elements so that the associated layout can output
       coordinates; otherwise they can be partitioned like any other CuTe
       tensor using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
    :param gmem_tensor: The GMEM tensor to be loaded by this copy atom
    :type gmem_tensor: Tensor
    :param smem_layout: Shared memory layout to load the tensor into (PDSL)
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
    :type mma_tiler_mnk: Shape
    :param tiled_mma: The TiledMMA that will consume the load as operands
    :type tiled_mma: core.TiledMma
    :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
    :type cluster_shape_vmnk: Shape
    :param internal_type: An optional parameter for the internal data type to when element
        type does not match the copy type
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type
    check_type_in(
        op,
        [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
        "op",
        "make_tiled_tma_atom_B",
    )

    # Restrict an identity layout of the GMEM tensor to the (N, K) modes of the
    # MMA Tiler and project it through the TiledMMA's B-operand partitioning.
    identity = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
    tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:])
    gmem_tile = core.composition(identity, tiler_nk, loc=loc, ip=ip)
    cta_v_map = core.dice(
        core.get(tiled_mma._thrfrg_B(gmem_tile), mode=[1]),
        (1, (1,) * core.rank(gmem_tile)),
    )

    if isinstance(op, CopyBulkTensorTileG2SOp):
        num_multicast = 1
        trait_cls = CopyBulkTensorTileG2SNonExecTrait
    else:
        assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
        # multicast across the M-mode since those would share the same tile of B
        num_multicast = core.size(cluster_shape_vmnk, mode=[1])
        trait_cls = CopyBulkTensorTileG2SMulticastNonExecTrait

    # res[0] = the IR Value for the non-executable atom instance
    # res[1] = the IR Value for the associated TMA tensor
    res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
        gmem_tensor.value,
        smem_layout,
        cta_v_map,
        op._to_ir(),
        num_multicast=num_multicast,
        internal_type=internal_type,
        loc=loc,
        ip=ip,
    )
    return core.CopyAtom(op, trait_cls(res[0])), res[1]
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# Public API of this module; consumed by the documentation generator.
__all__ = [
    "make_tiled_tma_atom_A",
    "make_tiled_tma_atom_B",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .copy import *
|
| 13 |
+
from .mma import *
|
| 14 |
+
from .helpers import *
|
| 15 |
+
|
| 16 |
+
# __all__ is required here for documentation generation
__all__ = [
    #
    # copy.py
    #
    "Repetition",
    "Pack",
    "Unpack",
    "Ld16x64bOp",
    "Ld16x128bOp",
    "Ld16x256bOp",
    "Ld16x32bx2Op",
    "Ld32x32bOp",
    "St16x64bOp",
    "St16x128bOp",
    "St16x256bOp",
    "St16x32bx2Op",
    "St32x32bOp",
    # NOTE(review): copy.py also defines SMEM-to-TMEM Cp*Op Operations that are
    # not re-exported here — confirm whether that omission is intentional.
    #
    # mma.py
    #
    "OperandMajorMode",
    "OperandSource",
    "CtaGroup",
    "Field",
    "MmaTF32Op",
    "MmaF16BF16Op",
    "MmaI8Op",
    "MmaFP8Op",
    "MmaMXF8Op",
    "MmaMXF4Op",
    "MmaMXF4NVF4Op",
    "SmemLayoutAtomKind",
    #
    # helpers.py
    #
    "make_smem_layout_atom",
    "tile_to_mma_shape",
    "commit",
    "is_tmem_load",
    "is_tmem_store",
    "get_tmem_copy_properties",
    "find_tmem_tensor_col_offset",
    "make_tmem_copy",
    "make_s2t_copy",
    "get_s2t_smem_desc_tensor",
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import enum
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Type
|
| 15 |
+
|
| 16 |
+
from cutlass.cutlass_dsl import CuTeDSL
|
| 17 |
+
|
| 18 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 19 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 20 |
+
from cutlass._mlir import ir
|
| 21 |
+
|
| 22 |
+
from ..common import OpError
|
| 23 |
+
from ...core import CopyOp, Trait
|
| 24 |
+
from ...typing import Numeric
|
| 25 |
+
|
| 26 |
+
from .mma import CtaGroup
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Repetition(enum.Enum):
    """
    An enumeration for the number of repetitions of a given TMEM copy within the instruction.
    """

    x1 = 1
    x2 = 2
    x4 = 4
    x8 = 8
    x16 = 16
    x32 = 32
    x64 = 64
    x128 = 128

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    @classmethod
    def _missing_(cls, value):
        """Resolve integer values that did not match a member directly.

        The previous hand-written if/elif chain silently omitted the x4
        member; scan the declared members instead so every valid repetition
        count resolves consistently.
        """
        if isinstance(value, int):
            for member in cls:
                if member.value == value:
                    return member
        # Returning None lets enum raise its standard ValueError.
        return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Pack(enum.Enum):
    """
    An enumeration for the possible packing patterns for TMEM to RMEM copies.
    """

    NONE = enum.auto()
    PACK_16b_IN_32b = enum.auto()

    def __str__(self) -> str:
        return type(self).__name__ + "." + self.name

    def __repr__(self) -> str:
        return "<" + type(self).__name__ + "." + self.name + ">"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Unpack(enum.Enum):
    """
    An enumeration for the possible unpacking patterns for RMEM to TMEM copies.
    """

    NONE = enum.auto()
    UNPACK_32b_IN_16b = enum.auto()

    def __str__(self) -> str:
        return type(self).__name__ + "." + self.name

    def __repr__(self) -> str:
        return "<" + type(self).__name__ + "." + self.name + ">"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass(frozen=True)
class _LdBase(CopyOp):
    """
    Base class for tcgen05 TMEM load copy Operations.

    Validates the target architecture and Op parameters at construction time.
    """

    repeat: Repetition = Repetition.x1
    pack: Pack = Pack.NONE

    # Architectures on which these Operations are available.
    admissible_archs = [
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        # Architecture check comes first so users get the most actionable error.
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # Parameter type checks.
        if not isinstance(self.repeat, Repetition):
            raise OpError(
                self,
                "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
            )
        if not isinstance(self.pack, Pack):
            raise OpError(
                self,
                "expects the 'pack' Op parameter to be a tcgen05.Pack instance",
            )

    def __str__(self) -> str:
        parts = [
            f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation",
            f"\n number of repetitions = {self.repeat.value}",
        ]
        if self.pack == Pack.PACK_16b_IN_32b:
            parts.append("\n with 2x 16-bit to 32b packing")
        return "".join(parts)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass(frozen=True)
class Ld16x64bOp(_LdBase):
    """
    16x64b TMEM load Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
    This Operation corresponds to the ``.16x64b`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld16x64bTrait":
        # Build the MLIR copy-atom type for a 16x64b load and wrap the
        # materialized atom in its trait.
        pack_attr = ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            copy_internal_type.mlir_type, 16, 64, self.repeat.value, pack_attr
        )
        return Ld16x64bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Trait wrapping the MLIR atom value produced by Ld16x64bOp._make_trait.
class Ld16x64bTrait(Trait):
    pass
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclass(frozen=True)
class Ld16x128bOp(_LdBase):
    """
    16x128b TMEM load Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
    This Operation corresponds to the ``.16x128b`` qualifier.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # This shape does not admit the widest repetition factor.
        if self.repeat is Repetition.x128:
            raise OpError(
                self,
                "x128 repetition is not supported",
                suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld16x128bTrait":
        # Build the MLIR copy-atom type for a 16x128b load.
        pack_attr = ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            copy_internal_type.mlir_type, 16, 128, self.repeat.value, pack_attr
        )
        return Ld16x128bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# Trait wrapping the MLIR atom value produced by Ld16x128bOp._make_trait.
class Ld16x128bTrait(Trait):
    pass
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@dataclass(frozen=True)
class Ld16x256bOp(_LdBase):
    """
    16x256b TMEM load Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
    This Operation corresponds to the ``.16x256b`` qualifier.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # The widest two repetition factors are not available for this shape.
        if self.repeat in (Repetition.x128, Repetition.x64):
            raise OpError(
                self,
                "x64 and x128 repetition is not supported",
                suggestion="choose one of x1, x2, x4, x8, x16, x32",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld16x256bTrait":
        # Build the MLIR copy-atom type for a 16x256b load.
        pack_attr = ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            copy_internal_type.mlir_type, 16, 256, self.repeat.value, pack_attr
        )
        return Ld16x256bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Trait wrapping the MLIR atom value produced by Ld16x256bOp._make_trait.
class Ld16x256bTrait(Trait):
    pass
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@dataclass(frozen=True)
class Ld16x32bx2Op(_LdBase):
    """
    16x32bx2 TMEM load Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
    This Operation corresponds to the ``.16x32bx2`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld16x32bx2Trait":
        # Build the MLIR copy-atom type for a 16x32b (x2) load.
        pack_attr = ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            copy_internal_type.mlir_type, 16, 32, self.repeat.value, pack_attr
        )
        return Ld16x32bx2Trait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# Trait wrapping the MLIR atom value produced by Ld16x32bx2Op._make_trait.
class Ld16x32bx2Trait(Trait):
    pass
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
@dataclass(frozen=True)
class Ld32x32bOp(_LdBase):
    """
    32x32b TMEM load Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
    This Operation corresponds to the ``.32x32b`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld32x32bTrait":
        # Build the MLIR copy-atom type for a 32x32b load.
        pack_attr = ir.UnitAttr.get() if self.pack == Pack.PACK_16b_IN_32b else None
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            copy_internal_type.mlir_type, 32, 32, self.repeat.value, pack_attr
        )
        return Ld32x32bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# Trait wrapping the MLIR atom value produced by Ld32x32bOp._make_trait.
class Ld32x32bTrait(Trait):
    pass
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@dataclass(frozen=True)
class _StBase(CopyOp):
    """
    Base class for tcgen05 TMEM store copy Operations.

    Validates the target architecture and Op parameters at construction time.
    """

    repeat: Repetition
    unpack: Unpack = Unpack.NONE

    # Architectures on which these Operations are available.
    admissible_archs = [
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

        if not isinstance(self.repeat, Repetition):
            raise OpError(
                self,
                "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
            )
        if not isinstance(self.unpack, Unpack):
            # Fixed: this message previously referred to the 'pack' parameter,
            # but the parameter being validated here is named 'unpack'.
            raise OpError(
                self,
                "expects the 'unpack' Op parameter to be a tcgen05.Unpack instance",
            )

    def __str__(self) -> str:
        res = (
            f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
            + f"\n number of repetitions = {self.repeat.value}"
        )
        if self.unpack == Unpack.UNPACK_32b_IN_16b:
            res += "\n with 32-bit to 2x 16b unpacking"
        return res
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@dataclass(frozen=True)
class St16x64bOp(_StBase):
    """
    16x64b TMEM store Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
    This Operation corresponds to the ``.16x64b`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "St16x64bTrait":
        # Build the MLIR copy-atom type for a 16x64b store.
        unpack_attr = (
            ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None
        )
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
            copy_internal_type.mlir_type, 16, 64, self.repeat.value, unpack_attr
        )
        return St16x64bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# Trait wrapping the MLIR atom value produced by St16x64bOp._make_trait.
class St16x64bTrait(Trait):
    pass
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
@dataclass(frozen=True)
class St16x128bOp(_StBase):
    """
    16x128b TMEM store Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
    This Operation corresponds to the ``.16x128b`` qualifier.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # This shape does not admit the widest repetition factor.
        if self.repeat is Repetition.x128:
            raise OpError(
                self,
                "x128 repetition is not supported",
                suggestion="choose one of x1, x2, x4, x8, x16, x32, x64",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "St16x128bTrait":
        # Build the MLIR copy-atom type for a 16x128b store.
        unpack_attr = (
            ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None
        )
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
            copy_internal_type.mlir_type, 16, 128, self.repeat.value, unpack_attr
        )
        return St16x128bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# Trait wrapping the MLIR atom value produced by St16x128bOp._make_trait.
class St16x128bTrait(Trait):
    pass
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
@dataclass(frozen=True)
class St16x256bOp(_StBase):
    """
    16x256b TMEM store Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
    This Operation corresponds to the ``.16x256b`` qualifier.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # The widest two repetition factors are not available for this shape.
        if self.repeat in (Repetition.x128, Repetition.x64):
            raise OpError(
                self,
                "x64 and x128 repetition is not supported",
                suggestion="choose one of x1, x2, x4, x8, x16, x32",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "St16x256bTrait":
        # Build the MLIR copy-atom type for a 16x256b store.
        unpack_attr = (
            ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None
        )
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
            copy_internal_type.mlir_type, 16, 256, self.repeat.value, unpack_attr
        )
        return St16x256bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
# Trait wrapping the MLIR atom value produced by St16x256bOp._make_trait.
class St16x256bTrait(Trait):
    pass
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
@dataclass(frozen=True)
class St16x32bx2Op(_StBase):
    """
    16x32bx2 TMEM store Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
    This Operation corresponds to the ``.16x32bx2`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "St16x32bx2Trait":
        # Build the MLIR copy-atom type for a 16x32b (x2) store.
        unpack_attr = (
            ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None
        )
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
            copy_internal_type.mlir_type, 16, 32, self.repeat.value, unpack_attr
        )
        return St16x32bx2Trait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
# Trait wrapping the MLIR atom value produced by St16x32bx2Op._make_trait.
class St16x32bx2Trait(Trait):
    pass
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@dataclass(frozen=True)
class St32x32bOp(_StBase):
    """
    32x32b TMEM store Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
    This Operation corresponds to the ``.32x32b`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "St32x32bTrait":
        # Build the MLIR copy-atom type for a 32x32b store.
        unpack_attr = (
            ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None
        )
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
            copy_internal_type.mlir_type, 32, 32, self.repeat.value, unpack_attr
        )
        return St32x32bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# Trait wrapping the MLIR atom value produced by St32x32bOp._make_trait.
class St32x32bTrait(Trait):
    pass
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
@dataclass(frozen=True)
class _S2TCopyBase(CopyOp):
    """
    Base class for tcgen05 SMEM-to-TMEM copy Operations.

    Validates the target architecture and the 'cta_group' parameter at
    construction time.
    """

    cta_group: CtaGroup

    # Architectures on which these Operations are available.
    admissible_archs = [
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        # Architecture check comes first so users get the most actionable error.
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # Verify that the user provided enum values
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self,
                "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
            )

    def __str__(self) -> str:
        return (
            f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
            f"\n CTA group = {self.cta_group}"
        )
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@dataclass(frozen=True)
class Cp128x256bOp(_S2TCopyBase):
    """
    128x256b SMEM to TMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
    This Operation corresponds to the ``.128x256b`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Cp128x256bTrait":
        # 128 data paths x 256 bits, no broadcast qualifier.
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
            copy_internal_type.mlir_type,
            128,
            256,
            self.cta_group.value,
            _cute_nvgpu_ir.CopyS2TBroadcast.none,
        )
        return Cp128x256bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class Cp128x256bTrait(Trait):
    """Trait holding the copy atom value produced by :class:`Cp128x256bOp`."""
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
@dataclass(frozen=True)
class Cp128x128bOp(_S2TCopyBase):
    """
    128x128b SMEM to TMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
    This Operation corresponds to the ``.128x128b`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Cp128x128bTrait":
        # 128 data paths x 128 bits, no broadcast qualifier.
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
            copy_internal_type.mlir_type,
            128,
            128,
            self.cta_group.value,
            _cute_nvgpu_ir.CopyS2TBroadcast.none,
        )
        return Cp128x128bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class Cp128x128bTrait(Trait):
    """Trait holding the copy atom value produced by :class:`Cp128x128bOp`."""
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
@dataclass(frozen=True)
|
| 563 |
+
class Cp4x256bOp(_S2TCopyBase):
|
| 564 |
+
"""
|
| 565 |
+
4x256b SMEM to TMEM Copy Operation.
|
| 566 |
+
|
| 567 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
|
| 568 |
+
This Operation corresponds to the ``.4x256b`` qualifier.
|
| 569 |
+
"""
|
| 570 |
+
|
| 571 |
+
def _make_trait(
|
| 572 |
+
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
| 573 |
+
) -> "Cp4x256bTrait":
|
| 574 |
+
ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
|
| 575 |
+
copy_internal_type.mlir_type,
|
| 576 |
+
4,
|
| 577 |
+
256,
|
| 578 |
+
self.cta_group.value,
|
| 579 |
+
_cute_nvgpu_ir.CopyS2TBroadcast.none,
|
| 580 |
+
)
|
| 581 |
+
return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class Cp4x256bTrait(Trait):
    """Trait holding the copy atom value produced by :class:`Cp4x256bOp`."""
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
@dataclass(frozen=True)
class Cp4x32x128bOp(_S2TCopyBase):
    """
    32x128b SMEM to TMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
    This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Cp4x32x128bTrait":
        # 32 data paths x 128 bits with the warpx4 (.x4) broadcast qualifier.
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
            copy_internal_type.mlir_type,
            32,
            128,
            self.cta_group.value,
            _cute_nvgpu_ir.CopyS2TBroadcast.x4,
        )
        return Cp4x32x128bTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class Cp4x32x128bTrait(Trait):
    """Trait holding the copy atom value produced by :class:`Cp4x32x128bOp`."""
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
@dataclass(frozen=True)
class Cp2x64x128b0213Op(_S2TCopyBase):
    """
    64x128b SMEM to TMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
    This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Cp2x64x128b0213Trait":
        # 64 data paths x 128 bits with the .warpx2::02_13 broadcast qualifier.
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
            copy_internal_type.mlir_type,
            64,
            128,
            self.cta_group.value,
            _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213,
        )
        return Cp2x64x128b0213Trait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
class Cp2x64x128b0213Trait(Trait):
    """Trait holding the copy atom value produced by :class:`Cp2x64x128b0213Op`."""
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
@dataclass(frozen=True)
class Cp2x64x128b0123Op(_S2TCopyBase):
    """
    64x128b SMEM to TMEM Copy Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=tcgen05#tcgen05-instructions-tcgen05-cp>`__.
    This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Cp2x64x128b0123Trait":
        # 64 data paths x 128 bits with the .warpx2::01_23 broadcast qualifier.
        atom_ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get(
            copy_internal_type.mlir_type,
            64,
            128,
            self.cta_group.value,
            _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123,
        )
        return Cp2x64x128b0123Trait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
class Cp2x64x128b0123Trait(Trait):
    """Trait holding the copy atom value produced by :class:`Cp2x64x128b0123Op`."""
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from typing import overload, Type, Tuple, Union
|
| 13 |
+
|
| 14 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 15 |
+
|
| 16 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 17 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 18 |
+
from cutlass._mlir.dialects import nvvm
|
| 19 |
+
|
| 20 |
+
from ...typing import (
|
| 21 |
+
Shape,
|
| 22 |
+
IntTuple,
|
| 23 |
+
Layout,
|
| 24 |
+
Tensor,
|
| 25 |
+
Int,
|
| 26 |
+
Numeric,
|
| 27 |
+
NumericMeta,
|
| 28 |
+
Int16,
|
| 29 |
+
Int32,
|
| 30 |
+
)
|
| 31 |
+
from ... import core
|
| 32 |
+
from .mma import SmemLayoutAtomKind, CtaGroup
|
| 33 |
+
from .copy import (
|
| 34 |
+
Pack,
|
| 35 |
+
Unpack,
|
| 36 |
+
Ld16x64bOp,
|
| 37 |
+
Ld16x128bOp,
|
| 38 |
+
Ld16x256bOp,
|
| 39 |
+
Ld16x32bx2Op,
|
| 40 |
+
Ld32x32bOp,
|
| 41 |
+
St16x64bOp,
|
| 42 |
+
St16x128bOp,
|
| 43 |
+
St16x256bOp,
|
| 44 |
+
St16x32bx2Op,
|
| 45 |
+
St32x32bOp,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
####################################################################################################
|
| 50 |
+
#
|
| 51 |
+
# Helper functions for MMA
|
| 52 |
+
#
|
| 53 |
+
####################################################################################################
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dsl_user_op
def make_smem_layout_atom(
    kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
) -> core.ComposedLayout:
    """
    Makes a SMEM layout Atom.

    This function creates a composed layout in unit of elements consistent with the requested layout
    Atom kind and element data type.

    :param kind: The kind of layout Atom
    :type kind: SmemLayoutAtomKind
    :param element_type: The element data type to construct the layout for
    :type element_type: Type[Numeric]
    :return: The SMEM layout atom
    :rtype: core.ComposedLayout
    """
    if not isinstance(element_type, NumericMeta):
        raise TypeError(f"element_type must be a Numeric, but got {element_type}")

    # Per-kind (contiguous bits, make_swizzle arguments).
    atom_params = {
        SmemLayoutAtomKind.MN_INTER: (128, (0, 4, 3)),
        SmemLayoutAtomKind.K_INTER: (128, (0, 4, 3)),
        SmemLayoutAtomKind.MN_SW32: (256, (1, 4, 3)),
        SmemLayoutAtomKind.K_SW32: (256, (1, 4, 3)),
        SmemLayoutAtomKind.MN_SW64: (512, (2, 4, 3)),
        SmemLayoutAtomKind.K_SW64: (512, (2, 4, 3)),
        SmemLayoutAtomKind.MN_SW128: (1024, (3, 4, 3)),
        SmemLayoutAtomKind.K_SW128: (1024, (3, 4, 3)),
        SmemLayoutAtomKind.MN_SW128_32B: (1024, (2, 5, 2)),
    }
    if kind not in atom_params:
        raise ValueError("unrecognized SMEM layout atom kind")
    num_contiguous_bits, swizzle_args = atom_params[kind]
    sw = core.make_swizzle(*swizzle_args)
    # Convert the bit count into an element count for the requested dtype.
    num_contiguous_elems = num_contiguous_bits // element_type.width

    mn_major_kinds = (
        SmemLayoutAtomKind.MN_INTER,
        SmemLayoutAtomKind.MN_SW32,
        SmemLayoutAtomKind.MN_SW64,
        SmemLayoutAtomKind.MN_SW128,
        SmemLayoutAtomKind.MN_SW128_32B,
    )
    if kind in mn_major_kinds:
        # M/N-major layout
        shape = (num_contiguous_elems, 8)
        stride = (1, num_contiguous_elems)
    else:
        # K-major layout
        shape = (8, num_contiguous_elems)
        stride = (num_contiguous_elems, 1)
    return core.make_composed_layout(
        sw,
        0,
        core.make_layout(shape, stride=stride),
        loc=loc,
        ip=ip,
    )
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@overload
def tile_to_mma_shape(
    atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
) -> Layout: ...


@overload
def tile_to_mma_shape(
    atom: core.ComposedLayout,
    mma_tile_shape: Shape,
    order: IntTuple = None,
    *,
    loc=None,
    ip=None,
) -> core.ComposedLayout: ...


@dsl_user_op
def tile_to_mma_shape(
    atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None
):
    """
    Tiles a layout to an MMA shape.
    """
    tile_rank = core.rank(mma_tile_shape)
    # Default order is colexicographical
    if order is None:
        order = tuple(range(tile_rank - 1))
    if core.rank(order) != tile_rank - 1:
        raise ValueError(
            f"rank(order)={core.rank(order)} must be equal to "
            f"rank(mma_tile_shape)-1={tile_rank - 1}"
        )
    order_val = core._pack_int_tuple(order, loc=loc, ip=ip)
    mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip)

    # Dynamic values cannot be tiled at trace time.
    all_static = (
        core.is_static(atom)
        and core.is_static(mma_tile_shape_val)
        and core.is_static(order_val)
    )
    if not all_static:
        raise ValueError("tile_to_mma_shape only supports static inputs")

    res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val)
    return _cute_ir.static(res_ty, loc=loc, ip=ip)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@dsl_user_op
def commit(
    mbar_ptr: core.Pointer,
    mask=None,
    cta_group: CtaGroup = CtaGroup.ONE,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Perform an arrive operation on a mbarrier upon completion of previous MMA operations.

    :param mbar_ptr: A pointer to the mbarrier in SMEM
    :type mbar_ptr: Pointer
    :param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to
    :type mask: Int
    """
    if cta_group == CtaGroup.ONE:
        group = nvvm.Tcgen05GroupKind.CTA_1
    else:
        assert cta_group == CtaGroup.TWO
        group = nvvm.Tcgen05GroupKind.CTA_2

    llvm_ptr = mbar_ptr.llvm_ptr
    if mask is None:
        nvvm.tcgen05_commit_arrive(llvm_ptr, group=group, loc=loc, ip=ip)
    else:
        # The multicast mask is a 16-bit value at the PTX level.
        multicast = Int16(mask).ir_value(loc=loc, ip=ip)
        nvvm.tcgen05_commit_arrive(
            llvm_ptr, multicast_mask=multicast, group=group, loc=loc, ip=ip
        )
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
####################################################################################################
|
| 206 |
+
#
|
| 207 |
+
# Helper functions for Copies
|
| 208 |
+
#
|
| 209 |
+
####################################################################################################
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def is_tmem_load(atom: core.CopyAtom) -> bool:
    """
    Returns whether a CopyAtom instance is a TMEM load.
    """
    tmem_load_ops = (
        Ld16x64bOp,
        Ld16x128bOp,
        Ld16x256bOp,
        Ld16x32bx2Op,
        Ld32x32bOp,
    )
    return isinstance(atom.op, tmem_load_ops)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def is_tmem_store(atom: core.CopyAtom) -> bool:
    """
    Returns whether a CopyAtom instance is a TMEM store.
    """
    tmem_store_ops = (
        St16x64bOp,
        St16x128bOp,
        St16x256bOp,
        St16x32bx2Op,
        St32x32bOp,
    )
    return isinstance(atom.op, tmem_store_ops)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def get_tmem_copy_properties(
    atom: core.CopyAtom,
) -> Tuple[int, int, int, Union[Pack, Unpack]]:
    """
    Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions,
    and whether packing/unpacking is used).
    """
    # (load/store Op pair, data paths, bits) per TMEM copy shape.
    shape_table = [
        ((Ld16x64bOp, St16x64bOp), 16, 64),
        ((Ld16x128bOp, St16x128bOp), 16, 128),
        ((Ld16x256bOp, St16x256bOp), 16, 256),
        ((Ld16x32bx2Op, St16x32bx2Op), 16, 32),
        ((Ld32x32bOp, St32x32bOp), 32, 32),
    ]
    for op_classes, num_dp, num_bits in shape_table:
        if isinstance(atom.op, op_classes):
            break
    else:
        raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}")
    if is_tmem_load(atom):
        return num_dp, num_bits, atom.op.repeat.value, atom.op.pack
    assert is_tmem_store(atom), "atom must be a TMEM store"
    return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@dsl_user_op
def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int:
    """
    Computes the TMEM column offset given a TMEM tensor.

    :param tmem_tensor: The TMEM tensor to use to compute the columns offset
    :type tmem_tensor: Tensor
    :return: The columns offset
    :rtype: Int
    """
    # Only the low 16 bits of the cosize are kept as the column offset.
    TMEM_COL_MASK = 0x0000FFFF
    layout_i32 = core.recast_tensor(tmem_tensor, Int32).layout
    offset = core.cosize(layout_i32, loc=loc, ip=ip) & TMEM_COL_MASK
    if isinstance(offset, int):
        return offset
    return Int32(offset, loc=loc, ip=ip)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@dsl_user_op
def make_tmem_copy(
    atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
) -> core.TiledCopy:
    """
    Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
    """
    tiled_value = _cute_nvgpu_ir.atom_make_tmem_copy(
        atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
    )
    # Re-wrap the tiled value in the same trait class as the input atom.
    tiled_trait = type(atom._trait)(tiled_value)
    return core.TiledCopy(atom.op, tiled_trait)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dsl_user_op
def make_s2t_copy(
    atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None
) -> core.TiledCopy:
    """
    Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor.
    """
    tiled_value = _cute_nvgpu_ir.atom_make_s2t_copy(
        atom._trait.value, tmem_tensor.value, loc=loc, ip=ip
    )
    # Re-wrap the tiled value in the same trait class as the input atom.
    tiled_trait = type(atom._trait)(tiled_value)
    return core.TiledCopy(atom.op, tiled_trait)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@dsl_user_op
def get_s2t_smem_desc_tensor(
    atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None
) -> Tensor:
    """
    Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor.
    """
    return _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view(
        atom._trait.value, smem_tensor.value, loc=loc, ip=ip
    )
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py
ADDED
|
@@ -0,0 +1,1041 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import enum
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Type
|
| 15 |
+
|
| 16 |
+
from cutlass.cutlass_dsl import CuTeDSL, T
|
| 17 |
+
|
| 18 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 19 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 20 |
+
from cutlass._mlir import ir
|
| 21 |
+
|
| 22 |
+
from ..common import OpError
|
| 23 |
+
from ... import core
|
| 24 |
+
from ...core import Trait, _pack_shape, rank, depth, _Tensor
|
| 25 |
+
from ...typing import (
|
| 26 |
+
Shape,
|
| 27 |
+
Float4E2M1FN,
|
| 28 |
+
Float8E8M0FNU,
|
| 29 |
+
Float8E5M2,
|
| 30 |
+
Float8E4M3FN,
|
| 31 |
+
Float16,
|
| 32 |
+
BFloat16,
|
| 33 |
+
Float32,
|
| 34 |
+
TFloat32,
|
| 35 |
+
Boolean,
|
| 36 |
+
Int8,
|
| 37 |
+
Uint8,
|
| 38 |
+
Int32,
|
| 39 |
+
Numeric,
|
| 40 |
+
AddressSpace,
|
| 41 |
+
Pointer,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
####################################################################################################
|
| 46 |
+
#
|
| 47 |
+
# MMA Ops and Traits
|
| 48 |
+
#
|
| 49 |
+
####################################################################################################
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class OperandMajorMode(enum.Enum):
    """
    An enumeration for the majorness of the input operands of the MMA.
    """

    MN = _cute_ir.MajorMode.mn
    K = _cute_ir.MajorMode.k

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    @classmethod
    def _missing_(cls, value):
        # Accept case-insensitive string aliases ("mn"/"k").
        if isinstance(value, str):
            normalized = value.upper()
            if normalized == "MN":
                return OperandMajorMode.MN
            if normalized == "K":
                return OperandMajorMode.K
        return None

    def _to_ir(self) -> _cute_ir.MajorMode:
        return self.value
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class OperandSource(enum.Enum):
    """
    An enumeration for the source memory location of the A input operand of the MMA.
    """

    TMEM = _cute_ir.MmaFragKind.tmem
    SMEM = _cute_ir.MmaFragKind.smem_desc

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MmaFragKind:
        return self.value
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CtaGroup(enum.Enum):
    """
    An enumeration for the ``cta_group`` qualifier of the MMA.
    """

    ONE = 1
    TWO = 2

    def __str__(self) -> str:
        return f"{type(self).__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{type(self).__name__}.{self.name}>"
|
| 110 |
+
|
| 111 |
+
class Field(enum.Enum):
    """
    An enumeration for the fields of the MMA Atom that can be modified at runtime.
    """

    NEGATE_A = "neg_a"
    NEGATE_B = "neg_b"
    ACCUMULATE = "accum_c"
    SFA = "sf_a"
    SFB = "sf_b"

    def __str__(self) -> str:
        return f"{type(self).__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{type(self).__name__}.{self.name}>"

    def _to_ir_field_name(self) -> str:
        # The enum value doubles as the IR-level field name.
        return self.value
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code
@dataclass(frozen=True)
class MmaOp(core.MmaOp):
    """Common base class for tcgen05 MMA Operations.

    Stores the operand/accumulator element types, the MMA instruction shape,
    and the tcgen05-specific qualifiers, and performs the parameter
    validation shared by all concrete tcgen05 MMA Ops.
    """

    a_dtype: Type[Numeric]          # element type of operand A
    b_dtype: Type[Numeric]          # element type of operand B
    acc_dtype: Type[Numeric]        # element type of the accumulator
    shape_mnk: Shape                # instruction shape (M, N[, K])
    cta_group: CtaGroup             # number of CTAs cooperating in one MMA
    a_src: OperandSource            # memory space operand A is sourced from
    a_major_mode: OperandMajorMode  # major mode of operand A
    b_major_mode: OperandMajorMode  # major mode of operand B

    # Architectures this Op is allowed to compile for.
    admissible_archs = [
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        """Validate arch, qualifier enums, and the instruction shape.

        :raises OpError: if any Op parameter is inadmissible
        """
        # Verify arch
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # Verify that the user provided enum values
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self,
                "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
            )
        if not isinstance(self.a_src, OperandSource):
            raise OpError(
                self,
                "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
            )
        if not isinstance(self.a_major_mode, OperandMajorMode):
            raise OpError(
                self,
                "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
            )
        if not isinstance(self.b_major_mode, OperandMajorMode):
            raise OpError(
                self,
                "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
            )
        # Verify the instruction shape
        if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
            raise OpError(
                self,
                f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
                f"but got {self.shape_mnk}",
            )
        m, n = self.shape_mnk[0], self.shape_mnk[1]
        if self.cta_group == CtaGroup.ONE:
            if m not in [64, 128]:
                raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}")
            if m == 64:
                if (n < 8) or (n > 256) or (n % 8 != 0):
                    raise OpError(
                        self,
                        f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
                    )
            elif m == 128:
                if (n < 16) or (n > 256) or (n % 16 != 0):
                    # BUGFIX: the check enforces N >= 16, but the previous
                    # message incorrectly claimed "8 <= N".
                    raise OpError(
                        self,
                        f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
                    )
        else:
            if m not in [128, 256]:
                raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
            if (n < 32) or (n > 256) or (n % 32 != 0):
                raise OpError(
                    self,
                    f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}",
                )

    def __str__(self) -> str:
        return (
            self.__class__.descriptive_name  # type: ignore
            + f"\n A data type = {self.a_dtype}"
            + f"\n B data type = {self.b_dtype}"
            + f"\n Accumulator data type = {self.acc_dtype}"
            + f"\n CTA group = {self.cta_group}"
            + f"\n A source location = {self.a_src}"
            + f"\n A major mode = {self.a_major_mode}"
            + f"\n B major mode = {self.b_major_mode}"
            + f"\n Instruction shape MNK = {self.shape_mnk}"
        )

    def _reject_composed_smem_layout(self, input: _Tensor, operand: str) -> bool:
        # Shared check for both operands: an SMEM operand must have an affine
        # (non-composed) layout; any swizzle has to be carried by the pointer.
        if input.memspace == AddressSpace.smem and isinstance(
            input.layout.type, _cute_ir.ComposedLayoutType
        ):
            raise OpError(
                self,
                f"Expected affine layout for {self._make_trait()}'s operand {operand}, "
                f"but got composed layout instead: {input.layout}"
                f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
            )
        return True

    def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
        """Verify operand A's layout; raises OpError on a composed SMEM layout."""
        return self._reject_composed_smem_layout(input, "A")

    def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
        """Verify operand B's layout; raises OpError on a composed SMEM layout."""
        return self._reject_composed_smem_layout(input, "B")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class MmaTrait(Trait):
    """Trait for non-block-scaled SM100 MMA Atoms with runtime-settable fields."""

    # Only the three boolean fields may be modified on this Atom kind.
    admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]

    def set(self, field, value, *, loc=None, ip=None) -> None:
        """Set a runtime-modifiable field on the underlying MMA Atom value.

        :param field: one of ``admissible_fields``
        :param value: converted to a Boolean IR value
        :raises ValueError: if ``field`` is not admissible
        """
        if field not in self.admissible_fields:
            raise ValueError(
                f"expects field to be one of {self.admissible_fields}, but got {field}"
            )
        # Build and parse the MLIR attribute naming the Atom field to modify.
        field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>"
        attr = ir.Attribute.parse(field_name)
        # atom_set_value returns a new Atom value; rebind it on the trait.
        self.value = _cute_nvgpu_ir.atom_set_value(
            self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
        )
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code
@dataclass(frozen=True)
class BlockScaledMmaOp(core.MmaOp):
    """Common base class for tcgen05 block-scaled MMA Operations.

    In addition to the regular MMA parameters, block-scaled Ops carry a scale
    factor element type and a scale factor vector size.
    """

    a_dtype: Type[Numeric]          # element type of operand A
    b_dtype: Type[Numeric]          # element type of operand B
    # Accumulator type; all concrete subclasses pass Float32 here.
    acc_dtype: Type[Numeric]
    sf_dtype: Type[Numeric]         # scale factor element type
    sf_vec_size: int                # scale factor vector size (validated: 16 or 32)
    shape_mnk: Shape                # instruction shape (M, N[, K])
    cta_group: CtaGroup             # number of CTAs cooperating in one MMA
    a_src: OperandSource            # memory space operand A is sourced from
    a_major_mode: OperandMajorMode  # major mode of operand A
    b_major_mode: OperandMajorMode  # major mode of operand B

    # Block-scaled MMAs are only admissible on sm_100a (no sm_100f here).
    admissible_archs = [
        "sm_100a",
    ]

    def __post_init__(self) -> None:
        """Validate arch, qualifier enums, instruction shape, and sf_vec_size.

        :raises OpError: if any Op parameter is inadmissible
        """
        # Verify arch
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
        # Verify that the user provided enum values
        if not isinstance(self.cta_group, CtaGroup):
            raise OpError(
                self,
                "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance",
            )
        if not isinstance(self.a_src, OperandSource):
            raise OpError(
                self,
                "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance",
            )
        if not isinstance(self.a_major_mode, OperandMajorMode):
            raise OpError(
                self,
                "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
            )
        if not isinstance(self.b_major_mode, OperandMajorMode):
            raise OpError(
                self,
                "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance",
            )
        # Verify the instruction shape
        if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
            raise OpError(
                self,
                f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
                f"but got {self.shape_mnk}",
            )
        m, n = self.shape_mnk[0], self.shape_mnk[1]
        if self.cta_group == CtaGroup.ONE:
            # 1-CTA block-scaled MMA only supports M == 128.
            if m != 128:
                raise OpError(self, f"expects the M-mode to be 128, but got {m}")

            if (n < 8) or (n > 256) or (n % 8 != 0):
                raise OpError(
                    self,
                    f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
                )
        else:
            if m not in [128, 256]:
                raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
            if (n < 16) or (n > 256) or (n % 16 != 0):
                raise OpError(
                    self,
                    f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
                )
        if self.sf_vec_size not in [16, 32]:
            raise OpError(
                self,
                f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}",
            )

    def __str__(self) -> str:
        return (
            self.__class__.descriptive_name  # type: ignore
            + f"\n A data type = {self.a_dtype}"
            + f"\n B data type = {self.b_dtype}"
            + f"\n Accumulator data type = {self.acc_dtype}"
            + f"\n Scale factor data type = {self.sf_dtype}"
            + f"\n Scale factor vector size = {self.sf_vec_size}"
            + f"\n CTA group = {self.cta_group}"
            + f"\n A source location = {self.a_src}"
            + f"\n A major mode = {self.a_major_mode}"
            + f"\n B major mode = {self.b_major_mode}"
            + f"\n Instruction shape MNK = {self.shape_mnk}"
        )

    def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
        """Verify operand A's layout; raises OpError on a composed SMEM layout."""
        # An SMEM operand must be affine; swizzles must be carried by the ptr.
        if input.memspace == AddressSpace.smem and isinstance(
            input.layout.type, _cute_ir.ComposedLayoutType
        ):
            raise OpError(
                self,
                f"Expected affine layout for {self._make_trait()}'s operand A, "
                f"but got composed layout instead: {input.layout}"
                f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
            )
        return True

    def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
        """Verify operand B's layout; raises OpError on a composed SMEM layout."""
        if input.memspace == AddressSpace.smem and isinstance(
            input.layout.type, _cute_ir.ComposedLayoutType
        ):
            raise OpError(
                self,
                f"Expected affine layout for {self._make_trait()}'s operand B, "
                f"but got composed layout instead: {input.layout}"
                f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
            )
        return True
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class BlockScaledMmaTraits(Trait):
    """Trait for block-scaled SM100 MMA Atoms with runtime-settable fields."""

    # Three boolean fields plus the two scale-factor pointer fields.
    admissible_fields = [
        Field.ACCUMULATE,
        Field.NEGATE_A,
        Field.NEGATE_B,
        Field.SFA,
        Field.SFB,
    ]

    def set(self, field, value, *, loc=None, ip=None) -> None:
        """Set a runtime-modifiable field on the underlying block-scaled MMA Atom.

        Boolean fields accept any value convertible to ``Boolean``; ``SFA`` and
        ``SFB`` require a ``Pointer``.

        :raises ValueError: if ``field`` is not admissible, or a scale-factor
            field receives a non-Pointer value
        """
        if field not in self.admissible_fields:
            raise ValueError(
                f"expects field to be one of {self.admissible_fields}, but got {field}"
            )
        if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]:
            # Boolean fields: lower to a Boolean IR value.
            value = Boolean(value).ir_value(loc=loc, ip=ip)
        elif field in [Field.SFA, Field.SFB]:
            # Scale factor fields: must already be a Pointer; use its IR value.
            if not isinstance(value, Pointer):
                raise ValueError(
                    f"expects value to be a pointer for {field}, but got {type(value).__name__}"
                )
            value = value.value

        field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>"
        attr = ir.Attribute.parse(field_name)
        # atom_set_value returns a new Atom value; rebind it on the trait.
        self.value = _cute_nvgpu_ir.atom_set_value(
            self.value, attr, value, loc=loc, ip=ip
        )
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
#
|
| 414 |
+
# TF32 MMA
|
| 415 |
+
#
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
@dataclass(frozen=True)
class MmaTF32Op(MmaOp):
    """
    TF32 tcgen05 MMA Operation (``.kind::tf32``).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    """

    descriptive_name = "tcgen05 TF32 MMA Operation"

    def __init__(
        self,
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # A/B are fixed to TFloat32 and the accumulator to Float32.
        super().__init__(
            TFloat32,
            TFloat32,
            Float32,
            instruction_shape,
            cta_group,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        # K is fixed at 8 for this kind; append it if only (M, N) was given.
        expected_k = 8
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, expected_k))
        if self.shape_mnk[2] != expected_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {expected_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait":
        # Build the MLIR Atom type for this Op, then wrap the created Atom.
        packed = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
            packed.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
            0,
        )

        def new_false():
            # One fresh `false` constant per runtime field of make_sm100_mma.
            return Boolean(False).ir_value(loc=loc, ip=ip)

        atom = _cute_nvgpu_ir.make_sm100_mma(
            atom_ty, new_false(), new_false(), new_false(), loc=loc, ip=ip
        )
        return MmaTF32Trait(atom)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class MmaTF32Trait(MmaTrait):
    """Trait for TF32 MMA Atoms; inherits runtime-settable fields from MmaTrait."""

    pass
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
#
|
| 491 |
+
# F16/BF16 MMA
|
| 492 |
+
#
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
    """
    F16/BF16 tcgen05 MMA Operation (``.kind::f16``).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    """

    descriptive_name = "tcgen05 F16/BF16 MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # A and B share a single element type for this kind.
        super().__init__(
            ab_dtype,
            ab_dtype,
            acc_dtype,
            instruction_shape,
            cta_group,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        # Input data type verification
        if self.a_dtype not in (Float16, BFloat16):
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
            )
        assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
        # Accumulator data type verification
        if self.acc_dtype not in (Float16, Float32):
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        # K is fixed at 16 for this kind; append it if only (M, N) was given.
        expected_k = 16
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, expected_k))
        if self.shape_mnk[2] != expected_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {expected_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
        # Build the MLIR Atom type for this Op, then wrap the created Atom.
        packed = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
            packed.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
            0,
        )

        def new_false():
            # One fresh `false` constant per runtime field of make_sm100_mma.
            return Boolean(False).ir_value(loc=loc, ip=ip)

        atom = _cute_nvgpu_ir.make_sm100_mma(
            atom_ty, new_false(), new_false(), new_false(), loc=loc, ip=ip
        )
        return MmaF16BF16Trait(atom)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class MmaF16BF16Trait(MmaTrait):
    """Trait for F16/BF16 MMA Atoms; inherits runtime-settable fields from MmaTrait."""

    pass
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
#
|
| 583 |
+
# I8 MMA
|
| 584 |
+
#
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
@dataclass(frozen=True)
class MmaI8Op(MmaOp):
    """
    I8 tcgen05 MMA Operation (``.kind::i8``).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    """

    descriptive_name = "tcgen05 I8 MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # A and B share a single 8-bit integer type; accumulator is Int32.
        super().__init__(
            ab_dtype,
            ab_dtype,
            Int32,
            instruction_shape,
            cta_group,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        # Input data type verification
        if self.a_dtype not in (Int8, Uint8):
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8",
            )
        assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
        # K is fixed at 32 for this kind; append it if only (M, N) was given.
        expected_k = 32
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, expected_k))
        if self.shape_mnk[2] != expected_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {expected_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait":
        # Map Python-side int types to signed/unsigned MLIR 8-bit int types.
        a_ir_ty = T.si8() if self.a_dtype.signed else T.ui8()
        b_ir_ty = T.si8() if self.b_dtype.signed else T.ui8()
        packed = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
            packed.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            a_ir_ty,
            b_ir_ty,
            T.si32(),
            self.a_src._to_ir(),
            0,
        )

        def new_false():
            # One fresh `false` constant per runtime field of make_sm100_mma.
            return Boolean(False).ir_value(loc=loc, ip=ip)

        atom = _cute_nvgpu_ir.make_sm100_mma(
            atom_ty, new_false(), new_false(), new_false(), loc=loc, ip=ip
        )
        return MmaI8Trait(atom)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
class MmaI8Trait(MmaTrait):
    """Trait for I8 MMA Atoms; inherits runtime-settable fields from MmaTrait."""

    pass
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
#
|
| 668 |
+
# F8F6F4 MMA
|
| 669 |
+
#
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
@dataclass(frozen=True)
class MmaFP8Op(MmaOp):
    """
    F8 tcgen05 MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    """

    descriptive_name = "tcgen05 F8 MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # A and B share a single FP8 element type.
        super().__init__(
            ab_dtype,
            ab_dtype,
            acc_dtype,
            instruction_shape,
            cta_group,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        # Input data type verification
        if self.a_dtype not in (Float8E5M2, Float8E4M3FN):
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
        # Accumulator data type verification
        if self.acc_dtype not in (Float16, Float32):
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        # K is fixed at 32 for this kind; append it if only (M, N) was given.
        expected_k = 32
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, expected_k))
        if self.shape_mnk[2] != expected_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {expected_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait":
        # Build the MLIR Atom type for this Op, then wrap the created Atom.
        packed = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get(
            packed.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
            0,
        )

        def new_false():
            # One fresh `false` constant per runtime field of make_sm100_mma.
            return Boolean(False).ir_value(loc=loc, ip=ip)

        atom = _cute_nvgpu_ir.make_sm100_mma(
            atom_ty, new_false(), new_false(), new_false(), loc=loc, ip=ip
        )
        return MmaFP8Trait(atom)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
class MmaFP8Trait(MmaTrait):
    """Trait for F8 MMA Atoms; inherits runtime-settable fields from MmaTrait."""

    pass
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
#
|
| 760 |
+
# MXF8F6F4 MMA
|
| 761 |
+
#
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
@dataclass(frozen=True)
class MmaMXF8Op(BlockScaledMmaOp):
    """
    MXF8 tcgen05 BlockScaled MMA Operation (``.kind::mxf8f6f4``).

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    """

    descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # Fixed for MXF8: Float32 accumulator, Float8E8M0FNU scale factors
        # with vector size 32.
        super().__init__(
            ab_dtype,
            ab_dtype,
            Float32,
            Float8E8M0FNU,
            32,
            instruction_shape,
            cta_group,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        # Input data type verification
        if self.a_dtype not in (Float8E5M2, Float8E4M3FN):
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
        # K is fixed at 32 for this kind; append it if only (M, N) was given.
        expected_k = 32
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, expected_k))
        if self.shape_mnk[2] != expected_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {expected_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait":
        # Build the block-scaled MLIR Atom type, then wrap the created Atom.
        packed = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
            packed.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.sf_dtype.mlir_type,
            self.a_src._to_ir(),
            self.sf_vec_size,
        )

        def new_false():
            # One fresh `false` constant per runtime boolean field.
            return Boolean(False).ir_value(loc=loc, ip=ip)

        def null_sf_ptr():
            # Null tmem scale-factor pointer; replaceable at runtime via the
            # trait's Field.SFA / Field.SFB setters.
            return core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value

        atom = _cute_nvgpu_ir.make_sm100_mma_bs(
            atom_ty,
            new_false(),
            new_false(),
            new_false(),
            null_sf_ptr(),
            null_sf_ptr(),
            loc=loc,
            ip=ip,
        )
        return MmaMXF8Trait(atom)
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
class MmaMXF8Trait(BlockScaledMmaTraits):
    """Trait for MXF8 block-scaled MMA Atoms; fields inherited from BlockScaledMmaTraits."""

    pass
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
#
|
| 850 |
+
# MXF4 MMA
|
| 851 |
+
#
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
@dataclass(frozen=True)
class MmaMXF4Op(BlockScaledMmaOp):
    """
    MXF4 tcgen05 BlockScaled MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    This Operation corresponds to the ``.kind::mxf4`` qualifier.
    """

    descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation"

    def __init__(
        self,
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
    ) -> None:
        # MXF4 fixes A/B to Float4E2M1FN, the accumulator to Float32, the
        # scale factor type to Float8E8M0FNU with vector size 32, and both
        # operands to K-major.
        super().__init__(
            Float4E2M1FN,
            Float4E2M1FN,
            Float32,
            Float8E8M0FNU,
            32,
            instruction_shape,
            cta_group,
            a_src,
            OperandMajorMode.K,
            OperandMajorMode.K,
        )
        self._verify()

    def _verify(self) -> None:
        # Instruction shape verification: K is fixed at 64; append it when
        # only (M, N) was provided.
        instruction_k = 64
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
        if self.shape_mnk[2] != instruction_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {instruction_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    # BUGFIX: the return annotation previously said "MmaMXF8Trait" although
    # this method constructs and returns an MmaMXF4Trait.
    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4Trait":
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
            shape_mnk.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.sf_dtype.mlir_type,
            self.a_src._to_ir(),
            self.sf_vec_size,
        )
        return MmaMXF4Trait(
            _cute_nvgpu_ir.make_sm100_mma_bs(
                ty,
                Boolean(False).ir_value(loc=loc, ip=ip),
                Boolean(False).ir_value(loc=loc, ip=ip),
                Boolean(False).ir_value(loc=loc, ip=ip),
                # Null tmem scale-factor pointers; set at runtime via the
                # trait's Field.SFA / Field.SFB setters.
                core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
                core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
                loc=loc,
                ip=ip,
            )
        )
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
class MmaMXF4Trait(BlockScaledMmaTraits):
    """Trait for MXF4 block-scaled MMA Atoms; fields inherited from BlockScaledMmaTraits."""

    pass
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
#
|
| 930 |
+
# MXF4NVF4 MMA
|
| 931 |
+
#
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
@dataclass(frozen=True)
class MmaMXF4NVF4Op(BlockScaledMmaOp):
    """
    MXF4NVF4 tcgen05 BlockScaled MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
    This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier.
    """

    descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation"

    def __init__(
        self,
        sf_dtype: Type[Numeric],
        instruction_shape: Shape,
        cta_group: CtaGroup,
        a_src: OperandSource,
    ) -> None:
        # MXF4NVF4 fixes A/B to Float4E2M1FN, the accumulator to Float32, the
        # scale factor vector size to 16, and both operands to K-major; the
        # scale factor element type is user-selectable.
        super().__init__(
            Float4E2M1FN,
            Float4E2M1FN,
            Float32,
            sf_dtype,
            16,
            instruction_shape,
            cta_group,
            a_src,
            OperandMajorMode.K,
            OperandMajorMode.K,
        )
        self._verify()

    def _verify(self) -> None:
        # Scale Factor data type verification
        if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]:
            # BUGFIX: the message previously listed only Float8E8M0FNU even
            # though Float8E4M3FN is also admissible.
            raise OpError(
                self,
                "expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU or Float8E4M3FN",
            )
        # Instruction shape verification: K is fixed at 64; append it when
        # only (M, N) was provided.
        instruction_k = 64
        if rank(self.shape_mnk) == 2:
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
        if self.shape_mnk[2] != instruction_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {instruction_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    # BUGFIX: the return annotation previously said "MmaMXF8Trait" although
    # this method constructs and returns an MmaMXF4NVF4Trait.
    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4NVF4Trait":
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get(
            shape_mnk.type.attribute,
            self.cta_group.value,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.sf_dtype.mlir_type,
            self.a_src._to_ir(),
            self.sf_vec_size,
        )
        return MmaMXF4NVF4Trait(
            _cute_nvgpu_ir.make_sm100_mma_bs(
                ty,
                Boolean(False).ir_value(loc=loc, ip=ip),
                Boolean(False).ir_value(loc=loc, ip=ip),
                Boolean(False).ir_value(loc=loc, ip=ip),
                # Null tmem scale-factor pointers; set at runtime via the
                # trait's Field.SFA / Field.SFB setters.
                core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
                core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value,
                loc=loc,
                ip=ip,
            )
        )
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
class MmaMXF4NVF4Trait(BlockScaledMmaTraits):
|
| 1013 |
+
pass
|
| 1014 |
+
|
| 1015 |
+
####################################################################################################
|
| 1016 |
+
#
|
| 1017 |
+
# SMEM layout atoms
|
| 1018 |
+
#
|
| 1019 |
+
####################################################################################################
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
class SmemLayoutAtomKind(enum.Enum):
|
| 1023 |
+
"""
|
| 1024 |
+
Enum class for the kinds of SMEM layout atoms for SM100.
|
| 1025 |
+
|
| 1026 |
+
Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be
|
| 1027 |
+
used to construct an SMEM layout using blocked product for operand A or B such that the
|
| 1028 |
+
resulting layout is legal for both TMA and UMMA.
|
| 1029 |
+
|
| 1030 |
+
Note that there are other ways of creating legal layouts for operand A and B.
|
| 1031 |
+
"""
|
| 1032 |
+
|
| 1033 |
+
MN_INTER = enum.auto()
|
| 1034 |
+
MN_SW32 = enum.auto()
|
| 1035 |
+
MN_SW64 = enum.auto()
|
| 1036 |
+
MN_SW128 = enum.auto()
|
| 1037 |
+
MN_SW128_32B = enum.auto()
|
| 1038 |
+
K_INTER = enum.auto()
|
| 1039 |
+
K_SW32 = enum.auto()
|
| 1040 |
+
K_SW64 = enum.auto()
|
| 1041 |
+
K_SW128 = enum.auto()
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .copy import *
|
| 13 |
+
from .mma import *
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# __all__ is required here for documentation generation
|
| 17 |
+
__all__ = [
|
| 18 |
+
# mma.py
|
| 19 |
+
"MmaF16BF16Op",
|
| 20 |
+
# copy.py
|
| 21 |
+
"LdMatrix8x8x16bOp",
|
| 22 |
+
"LdMatrix16x16x8bOp",
|
| 23 |
+
"StMatrix8x8x16bOp",
|
| 24 |
+
"StMatrix16x8x8bOp",
|
| 25 |
+
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Type
|
| 14 |
+
|
| 15 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 16 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
+
from cutlass._mlir import ir
|
| 18 |
+
|
| 19 |
+
from ..common import OpError
|
| 20 |
+
from ...core import CopyOp, Trait, _pack_shape
|
| 21 |
+
from ...typing import Numeric
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
|
| 25 |
+
class BaseOp(CopyOp):
|
| 26 |
+
transpose: bool = False
|
| 27 |
+
num_matrices: int = 1
|
| 28 |
+
|
| 29 |
+
def __post_init__(self) -> None:
|
| 30 |
+
if not isinstance(self.transpose, bool):
|
| 31 |
+
raise OpError(
|
| 32 |
+
self,
|
| 33 |
+
"expects the 'transpose' Op parameter to be a bool instance",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def __str__(self) -> str:
|
| 37 |
+
res = (
|
| 38 |
+
f"{self.__class__.__name__[:-2]} Copy Operation"
|
| 39 |
+
+ f"\n number of matrices = {self.num_matrices}"
|
| 40 |
+
)
|
| 41 |
+
if self.transpose:
|
| 42 |
+
res += f"\n transposed"
|
| 43 |
+
return res
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass(frozen=True)
|
| 47 |
+
class LdMatrix8x8x16bOp(BaseOp):
|
| 48 |
+
"""
|
| 49 |
+
8x8 ``ldmatrix`` Operation.
|
| 50 |
+
|
| 51 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
|
| 52 |
+
This operation corresponds to the ``.m8n8`` qualifier.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __post_init__(self) -> None:
|
| 56 |
+
super().__post_init__()
|
| 57 |
+
if self.num_matrices not in [1, 2, 4]:
|
| 58 |
+
raise OpError(
|
| 59 |
+
self,
|
| 60 |
+
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def _make_trait(
|
| 64 |
+
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
| 65 |
+
) -> "LdMatrix8x8x16bTrait":
|
| 66 |
+
mode = _pack_shape((8, 8), loc=loc, ip=ip)
|
| 67 |
+
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
|
| 68 |
+
copy_internal_type.mlir_type,
|
| 69 |
+
mode.type.attribute,
|
| 70 |
+
_cute_nvgpu_ir.LdsmSzPattern.u16,
|
| 71 |
+
self.num_matrices,
|
| 72 |
+
ir.UnitAttr.get() if self.transpose else None,
|
| 73 |
+
)
|
| 74 |
+
return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class LdMatrix8x8x16bTrait(Trait):
|
| 78 |
+
pass
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dataclass(frozen=True)
|
| 82 |
+
class LdMatrix16x16x8bOp(BaseOp):
|
| 83 |
+
"""
|
| 84 |
+
16x16 8-bit ``ldmatrix`` Operation.
|
| 85 |
+
|
| 86 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
|
| 87 |
+
This operation corresponds to the ``.m16n16`` and the ``.b16`` qualifiers.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, num_matrices: int) -> None:
|
| 91 |
+
super().__init__(transpose=True, num_matrices=num_matrices)
|
| 92 |
+
self._verify()
|
| 93 |
+
|
| 94 |
+
def _verify(self):
|
| 95 |
+
assert self.transpose, "transpose must be True"
|
| 96 |
+
if self.num_matrices not in [1, 2]:
|
| 97 |
+
raise OpError(
|
| 98 |
+
self,
|
| 99 |
+
"expects the 'num_matrices' Op parameter to be one of [1,2]",
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def _make_trait(
|
| 103 |
+
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
| 104 |
+
) -> "LdMatrix16x16x8bTrait":
|
| 105 |
+
mode = _pack_shape((16, 16), loc=loc, ip=ip)
|
| 106 |
+
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
|
| 107 |
+
copy_internal_type.mlir_type,
|
| 108 |
+
mode.type.attribute,
|
| 109 |
+
_cute_nvgpu_ir.LdsmSzPattern.u8,
|
| 110 |
+
self.num_matrices,
|
| 111 |
+
ir.UnitAttr.get(),
|
| 112 |
+
)
|
| 113 |
+
return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class LdMatrix16x16x8bTrait(Trait):
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass(frozen=True)
|
| 121 |
+
class StMatrix8x8x16bOp(BaseOp):
|
| 122 |
+
"""
|
| 123 |
+
8x8 ``stmatrix`` Operation.
|
| 124 |
+
|
| 125 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
|
| 126 |
+
This operation corresponds to the ``m8n8`` qualifier.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __post_init__(self) -> None:
|
| 130 |
+
super().__post_init__()
|
| 131 |
+
if self.num_matrices not in [1, 2, 4]:
|
| 132 |
+
raise OpError(
|
| 133 |
+
self,
|
| 134 |
+
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def _make_trait(
|
| 138 |
+
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
| 139 |
+
) -> "StMatrix8x8x16bTrait":
|
| 140 |
+
mode = _pack_shape((8, 8), loc=loc, ip=ip)
|
| 141 |
+
ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
|
| 142 |
+
copy_internal_type.mlir_type,
|
| 143 |
+
mode.type.attribute,
|
| 144 |
+
self.num_matrices,
|
| 145 |
+
ir.UnitAttr.get() if self.transpose else None,
|
| 146 |
+
)
|
| 147 |
+
return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class StMatrix8x8x16bTrait(Trait):
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@dataclass(frozen=True)
|
| 155 |
+
class StMatrix16x8x8bOp(BaseOp):
|
| 156 |
+
"""
|
| 157 |
+
16x8 ``stmatrix`` Operation.
|
| 158 |
+
|
| 159 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-stmatrix>`__.
|
| 160 |
+
This operation corresponds to the ``m16n8`` qualifier.
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, num_matrices: int) -> None:
|
| 164 |
+
super().__init__(transpose=True, num_matrices=num_matrices)
|
| 165 |
+
self._verify()
|
| 166 |
+
|
| 167 |
+
def _verify(self):
|
| 168 |
+
if self.num_matrices not in [1, 2, 4]:
|
| 169 |
+
assert self.transpose, "transpose must be True"
|
| 170 |
+
raise OpError(
|
| 171 |
+
self,
|
| 172 |
+
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def _make_trait(
|
| 176 |
+
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
| 177 |
+
) -> "StMatrix16x8x8bTrait":
|
| 178 |
+
mode = _pack_shape((16, 8), loc=loc, ip=ip)
|
| 179 |
+
ty = _cute_nvgpu_ir.CopyAtomStsmType.get(
|
| 180 |
+
copy_internal_type.mlir_type,
|
| 181 |
+
mode.type.attribute,
|
| 182 |
+
self.num_matrices,
|
| 183 |
+
ir.UnitAttr.get(),
|
| 184 |
+
)
|
| 185 |
+
return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class StMatrix16x8x8bTrait(Trait):
|
| 189 |
+
pass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Type
|
| 14 |
+
|
| 15 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 16 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 17 |
+
|
| 18 |
+
from ..common import OpError
|
| 19 |
+
from ...core import MmaOp, Trait, _pack_shape, _Tensor
|
| 20 |
+
from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass(frozen=True)
|
| 24 |
+
class MmaF16BF16Op(MmaOp):
|
| 25 |
+
"""
|
| 26 |
+
F16/BF16 tcgen05 MMA Operation.
|
| 27 |
+
|
| 28 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
|
| 29 |
+
This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
ab_dtype: Type[Numeric]
|
| 33 |
+
acc_dtype: Type[Numeric]
|
| 34 |
+
shape_mnk: Shape
|
| 35 |
+
|
| 36 |
+
def __post_init__(self) -> None:
|
| 37 |
+
if self.ab_dtype not in [Float16, BFloat16]:
|
| 38 |
+
raise OpError(
|
| 39 |
+
self,
|
| 40 |
+
"expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
|
| 41 |
+
)
|
| 42 |
+
if self.acc_dtype not in [Float16, Float32]:
|
| 43 |
+
raise OpError(
|
| 44 |
+
self,
|
| 45 |
+
"expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
|
| 46 |
+
)
|
| 47 |
+
if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32):
|
| 48 |
+
raise OpError(
|
| 49 |
+
self,
|
| 50 |
+
"expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
|
| 51 |
+
)
|
| 52 |
+
if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]:
|
| 53 |
+
raise OpError(
|
| 54 |
+
self,
|
| 55 |
+
"expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)",
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
|
| 59 |
+
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
|
| 60 |
+
ty = _cute_nvgpu_ir.MmaAtomSM80Type.get(
|
| 61 |
+
shape_mnk.type.attribute,
|
| 62 |
+
self.ab_dtype.mlir_type,
|
| 63 |
+
self.ab_dtype.mlir_type,
|
| 64 |
+
self.acc_dtype.mlir_type,
|
| 65 |
+
)
|
| 66 |
+
return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
| 67 |
+
|
| 68 |
+
def __str__(self) -> str:
|
| 69 |
+
return (
|
| 70 |
+
"warp-level F16/BF16 MMA Operation"
|
| 71 |
+
+ f"\n A/B data type = {self.ab_dtype}"
|
| 72 |
+
+ f"\n Accumulator data type = {self.acc_dtype}"
|
| 73 |
+
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
class MmaF16BF16Trait(Trait):
|
| 83 |
+
pass
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from .mma import *
|
| 13 |
+
from .helpers import *
|
| 14 |
+
|
| 15 |
+
# __all__ is required here for documentation generation
|
| 16 |
+
__all__ = [
|
| 17 |
+
# mma.py
|
| 18 |
+
"OperandMajorMode",
|
| 19 |
+
"OperandSource",
|
| 20 |
+
"Field",
|
| 21 |
+
"MmaF16BF16Op",
|
| 22 |
+
"MmaF8Op",
|
| 23 |
+
"SmemLayoutAtomKind",
|
| 24 |
+
# helpers.py
|
| 25 |
+
"make_smem_layout_atom",
|
| 26 |
+
"fence",
|
| 27 |
+
"commit_group",
|
| 28 |
+
"wait_group",
|
| 29 |
+
]
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from typing import Type
|
| 13 |
+
|
| 14 |
+
from cutlass.cutlass_dsl import dsl_user_op
|
| 15 |
+
|
| 16 |
+
from cutlass._mlir.dialects import nvvm
|
| 17 |
+
|
| 18 |
+
from ...typing import Numeric, NumericMeta
|
| 19 |
+
from ... import core
|
| 20 |
+
from .mma import SmemLayoutAtomKind
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dsl_user_op
|
| 24 |
+
def make_smem_layout_atom(
|
| 25 |
+
kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None
|
| 26 |
+
) -> core.ComposedLayout:
|
| 27 |
+
"""
|
| 28 |
+
Makes a SMEM layout Atom.
|
| 29 |
+
|
| 30 |
+
This function creates a composed layout in unit of elements consistent with the requested layout
|
| 31 |
+
Atom kind and element data type.
|
| 32 |
+
|
| 33 |
+
:param kind: The kind of layout Atom
|
| 34 |
+
:type kind: SmemLayoutAtomKind
|
| 35 |
+
:param element_type: The element data type to construct the layout for
|
| 36 |
+
:type element_type: Type[Numeric]
|
| 37 |
+
:return: The SMEM layout atom
|
| 38 |
+
:rtype: core.ComposedLayout
|
| 39 |
+
"""
|
| 40 |
+
if not isinstance(element_type, NumericMeta):
|
| 41 |
+
raise TypeError(f"element_type must be a Numeric, but got {element_type}")
|
| 42 |
+
|
| 43 |
+
if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER):
|
| 44 |
+
num_contiguous_bits = 128
|
| 45 |
+
sw = core.make_swizzle(0, 4, 3)
|
| 46 |
+
elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32):
|
| 47 |
+
num_contiguous_bits = 256
|
| 48 |
+
sw = core.make_swizzle(1, 4, 3)
|
| 49 |
+
elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64):
|
| 50 |
+
num_contiguous_bits = 512
|
| 51 |
+
sw = core.make_swizzle(2, 4, 3)
|
| 52 |
+
elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128):
|
| 53 |
+
num_contiguous_bits = 1024
|
| 54 |
+
sw = core.make_swizzle(3, 4, 3)
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError("unrecognized SMEM layout atom kind")
|
| 57 |
+
num_contiguous_elems = num_contiguous_bits // element_type.width
|
| 58 |
+
|
| 59 |
+
if kind in (
|
| 60 |
+
SmemLayoutAtomKind.MN_INTER,
|
| 61 |
+
SmemLayoutAtomKind.MN_SW32,
|
| 62 |
+
SmemLayoutAtomKind.MN_SW64,
|
| 63 |
+
SmemLayoutAtomKind.MN_SW128,
|
| 64 |
+
):
|
| 65 |
+
# M/N-major layout
|
| 66 |
+
return core.make_composed_layout(
|
| 67 |
+
sw,
|
| 68 |
+
0,
|
| 69 |
+
core.make_layout(
|
| 70 |
+
(num_contiguous_elems, 8), stride=(1, num_contiguous_elems)
|
| 71 |
+
),
|
| 72 |
+
loc=loc,
|
| 73 |
+
ip=ip,
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
# K-major layout
|
| 77 |
+
return core.make_composed_layout(
|
| 78 |
+
sw,
|
| 79 |
+
0,
|
| 80 |
+
core.make_layout(
|
| 81 |
+
(8, num_contiguous_elems), stride=(num_contiguous_elems, 1)
|
| 82 |
+
),
|
| 83 |
+
loc=loc,
|
| 84 |
+
ip=ip,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dsl_user_op
|
| 89 |
+
def fence(*, loc=None, ip=None) -> None:
|
| 90 |
+
"""
|
| 91 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-fence>`__.
|
| 92 |
+
"""
|
| 93 |
+
nvvm.wgmma_fence_aligned(loc=None, ip=None)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dsl_user_op
|
| 97 |
+
def commit_group(*, loc=None, ip=None) -> None:
|
| 98 |
+
"""
|
| 99 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group>`__.
|
| 100 |
+
"""
|
| 101 |
+
nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dsl_user_op
|
| 105 |
+
def wait_group(group, *, loc=None, ip=None) -> None:
|
| 106 |
+
"""
|
| 107 |
+
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-wait-group>`__.
|
| 108 |
+
"""
|
| 109 |
+
nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import enum
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Type
|
| 15 |
+
|
| 16 |
+
from cutlass.cutlass_dsl import CuTeDSL
|
| 17 |
+
|
| 18 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 19 |
+
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
| 20 |
+
from cutlass._mlir import ir
|
| 21 |
+
|
| 22 |
+
from ..common import OpError
|
| 23 |
+
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
|
| 24 |
+
from ...typing import (
|
| 25 |
+
Shape,
|
| 26 |
+
Float16,
|
| 27 |
+
BFloat16,
|
| 28 |
+
Float32,
|
| 29 |
+
Boolean,
|
| 30 |
+
Float8E5M2,
|
| 31 |
+
Float8E4M3FN,
|
| 32 |
+
Numeric,
|
| 33 |
+
AddressSpace,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
####################################################################################################
|
| 38 |
+
#
|
| 39 |
+
# MMA Ops and Traits
|
| 40 |
+
#
|
| 41 |
+
####################################################################################################
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class OperandMajorMode(enum.Enum):
|
| 45 |
+
"""
|
| 46 |
+
An enumeration for the majorness of the input operands of the MMA.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
MN = _cute_ir.MajorMode.mn
|
| 50 |
+
K = _cute_ir.MajorMode.k
|
| 51 |
+
|
| 52 |
+
def __str__(self) -> str:
|
| 53 |
+
return f"{self.__class__.__name__}.{self.name}"
|
| 54 |
+
|
| 55 |
+
def __repr__(self) -> str:
|
| 56 |
+
return f"<{self.__class__.__name__}.{self.name}>"
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
def _missing_(cls, value):
|
| 60 |
+
if isinstance(value, str):
|
| 61 |
+
value = value.upper()
|
| 62 |
+
if value == "MN":
|
| 63 |
+
return OperandMajorMode.MN
|
| 64 |
+
elif value == "K":
|
| 65 |
+
return OperandMajorMode.K
|
| 66 |
+
|
| 67 |
+
def _to_ir(self) -> _cute_ir.MajorMode:
|
| 68 |
+
return self.value
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class OperandSource(enum.Enum):
|
| 72 |
+
"""
|
| 73 |
+
An enumeration for the source memory location of the A input operand of the MMA.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
RMEM = _cute_ir.MmaFragKind.rmem
|
| 77 |
+
SMEM = _cute_ir.MmaFragKind.smem_desc
|
| 78 |
+
|
| 79 |
+
def __str__(self) -> str:
|
| 80 |
+
return f"{self.__class__.__name__}.{self.name}"
|
| 81 |
+
|
| 82 |
+
def __repr__(self) -> str:
|
| 83 |
+
return f"<{self.__class__.__name__}.{self.name}>"
|
| 84 |
+
|
| 85 |
+
def _to_ir(self) -> _cute_ir.MmaFragKind:
|
| 86 |
+
return self.value
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Field(enum.Enum):
|
| 90 |
+
"""
|
| 91 |
+
An enumeration for the fields of the MMA Atom that can be modified at runtime.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
ACCUMULATE = "accum_c"
|
| 95 |
+
|
| 96 |
+
def __str__(self) -> str:
|
| 97 |
+
return f"{self.__class__.__name__}.{self.name}"
|
| 98 |
+
|
| 99 |
+
def __repr__(self) -> str:
|
| 100 |
+
return f"<{self.__class__.__name__}.{self.name}>"
|
| 101 |
+
|
| 102 |
+
def _to_ir_field_name(self) -> str:
|
| 103 |
+
return self.value
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass(frozen=True)
|
| 107 |
+
class MmaOp(MmaOp):
|
| 108 |
+
a_dtype: Type[Numeric]
|
| 109 |
+
b_dtype: Type[Numeric]
|
| 110 |
+
acc_dtype: Type[Numeric]
|
| 111 |
+
shape_mnk: Shape
|
| 112 |
+
a_src: OperandSource
|
| 113 |
+
a_major_mode: OperandMajorMode
|
| 114 |
+
b_major_mode: OperandMajorMode
|
| 115 |
+
|
| 116 |
+
admissible_archs = ["sm_90a"]
|
| 117 |
+
|
| 118 |
+
def __post_init__(self) -> None:
|
| 119 |
+
# Verify arch
|
| 120 |
+
arch = CuTeDSL._get_dsl().envar.arch
|
| 121 |
+
if arch not in self.admissible_archs:
|
| 122 |
+
raise OpError(
|
| 123 |
+
self,
|
| 124 |
+
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
| 125 |
+
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
| 126 |
+
)
|
| 127 |
+
# Verify that the user provided enum values
|
| 128 |
+
if not isinstance(self.a_src, OperandSource):
|
| 129 |
+
raise OpError(
|
| 130 |
+
self,
|
| 131 |
+
"expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance",
|
| 132 |
+
)
|
| 133 |
+
if not isinstance(self.a_major_mode, OperandMajorMode):
|
| 134 |
+
raise OpError(
|
| 135 |
+
self,
|
| 136 |
+
"expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
|
| 137 |
+
)
|
| 138 |
+
if not isinstance(self.b_major_mode, OperandMajorMode):
|
| 139 |
+
raise OpError(
|
| 140 |
+
self,
|
| 141 |
+
"expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance",
|
| 142 |
+
)
|
| 143 |
+
# Verify instruction shape
|
| 144 |
+
if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1):
|
| 145 |
+
raise OpError(
|
| 146 |
+
self,
|
| 147 |
+
f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, "
|
| 148 |
+
f"but got {self.shape_mnk}",
|
| 149 |
+
)
|
| 150 |
+
m, n = self.shape_mnk[0], self.shape_mnk[1]
|
| 151 |
+
if m != 64:
|
| 152 |
+
raise OpError(self, f"expects the M-mode to be 64, but got {m}")
|
| 153 |
+
if (n < 8) or (n > 256) or (n % 8 != 0):
|
| 154 |
+
raise OpError(
|
| 155 |
+
self,
|
| 156 |
+
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. but got {n}",
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def __str__(self) -> str:
|
| 160 |
+
return (
|
| 161 |
+
self.__class__.descriptive_name # type: ignore
|
| 162 |
+
+ f"\n A data type = {self.a_dtype}"
|
| 163 |
+
+ f"\n B data type = {self.b_dtype}"
|
| 164 |
+
+ f"\n Accumulator data type = {self.acc_dtype}"
|
| 165 |
+
+ f"\n A source location = {self.a_src}"
|
| 166 |
+
+ f"\n A major mode = {self.a_major_mode}"
|
| 167 |
+
+ f"\n B major mode = {self.b_major_mode}"
|
| 168 |
+
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
|
| 172 |
+
if input.memspace == AddressSpace.smem and isinstance(
|
| 173 |
+
input.layout.type, _cute_ir.ComposedLayoutType
|
| 174 |
+
):
|
| 175 |
+
raise OpError(
|
| 176 |
+
self,
|
| 177 |
+
f"Expected affine layout for {self._make_trait()}'s operand A, "
|
| 178 |
+
f"but got composed layout instead: {input.layout}"
|
| 179 |
+
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
| 180 |
+
)
|
| 181 |
+
return True
|
| 182 |
+
|
| 183 |
+
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
|
| 184 |
+
if input.memspace == AddressSpace.smem and isinstance(
|
| 185 |
+
input.layout.type, _cute_ir.ComposedLayoutType
|
| 186 |
+
):
|
| 187 |
+
raise OpError(
|
| 188 |
+
self,
|
| 189 |
+
f"Expected affine layout for {self._make_trait()}'s operand B, "
|
| 190 |
+
f"but got composed layout instead: {input.layout}"
|
| 191 |
+
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
| 192 |
+
)
|
| 193 |
+
return True
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class MmaTrait(Trait):
    """Trait wrapping an SM90 MMA atom value.

    Only the ACCUMULATE field may be modified after construction.
    """

    admissible_fields = [Field.ACCUMULATE]

    def set(self, field, value, *, loc=None, ip=None) -> None:
        """Set an admissible field on the underlying atom value.

        :param field: the field to set; must be in ``admissible_fields``
        :param value: the new value, converted to a Boolean IR value
        :raises ValueError: when ``field`` is not an admissible field
        """
        if field not in self.admissible_fields:
            raise ValueError(
                f"invalid field, must be {Field.ACCUMULATE}, but got {field}"
            )
        # Build the SM90 MMA field attribute and rewrite the atom value in IR.
        attr = ir.Attribute.parse(
            f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>"
        )
        flag = Boolean(value).ir_value(loc=loc, ip=ip)
        self.value = _cute_nvgpu_ir.atom_set_value(
            self.value, attr, flag, loc=loc, ip=ip
        )
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@dataclass(frozen=True)
class MmaF16BF16Op(MmaOp):
    """
    F16/BF16 warpgroup MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
    This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands.
    """

    descriptive_name = "warpgroup F16/BF16 MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        instruction_shape: Shape,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        # A and B share one data type for this Op, so ab_dtype is used for both.
        super().__init__(
            ab_dtype,
            ab_dtype,
            acc_dtype,
            instruction_shape,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self) -> None:
        """Validate dtypes and instruction shape; raise OpError on violations."""
        # Input data type verification
        if self.a_dtype not in (Float16, BFloat16):
            raise OpError(
                self,
                "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16",
            )
        assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same"
        # Accumulator data type verification
        if self.acc_dtype not in (Float16, Float32):
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32):
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16",
            )
        # Instruction shape verification: the K extent is fixed for this Op.
        expected_k = 16
        if rank(self.shape_mnk) == 2:
            # Frozen dataclass: extend shape_mnk in place via object.__setattr__.
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, expected_k))
        if self.shape_mnk[2] != expected_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {expected_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait":
        """Build the SM90 MMA atom in IR and wrap it in a MmaF16BF16Trait."""
        packed_shape = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
            packed_shape.type.attribute,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
        )
        flag = Boolean(False).ir_value(loc=loc, ip=ip)
        return MmaF16BF16Trait(
            _cute_nvgpu_ir.make_sm90_mma(atom_ty, flag, loc=loc, ip=ip)
        )
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class MmaF16BF16Trait(MmaTrait):
    """Trait for F16/BF16 SM90 MMA atoms; behavior is inherited from MmaTrait."""
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@dataclass(frozen=True)
class MmaF8Op(MmaOp):
    """
    F8 warpgroup MMA Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-multiply-and-accumulate-instruction-wgmma-mma-async>`__.
    This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.
    """

    descriptive_name = "warpgroup F8 MMA Operation"

    def __init__(
        self,
        a_dtype: Type[Numeric],
        b_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        instruction_shape: Shape,
        a_src: OperandSource,
        a_major_mode: OperandMajorMode,
        b_major_mode: OperandMajorMode,
    ) -> None:
        """Construct the Op and validate its parameters (see ``_verify``)."""
        super().__init__(
            a_dtype,
            b_dtype,
            acc_dtype,
            instruction_shape,
            a_src,
            a_major_mode,
            b_major_mode,
        )
        self._verify()

    def _verify(self):
        """Validate dtypes and instruction shape; raises OpError on violations."""
        # Input data type verification: A and B may mix the two F8 formats.
        if self.a_dtype not in [Float8E5M2, Float8E4M3FN]:
            raise OpError(
                self,
                "expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        if self.b_dtype not in [Float8E5M2, Float8E4M3FN]:
            raise OpError(
                self,
                "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN",
            )
        # Accumulator data type verification
        if self.acc_dtype not in [Float16, Float32]:
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32",
            )
        # Verify the instruction shape: K is fixed at 32 for this Op.
        instruction_k = 32
        if rank(self.shape_mnk) == 2:
            # Frozen dataclass: append the implicit K extent via object.__setattr__.
            object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k))
        if self.shape_mnk[2] != instruction_k:
            raise OpError(
                self,
                f"expects the instruction extent in the K-mode to be {instruction_k}, "
                f"but got {self.shape_mnk[2]}",
            )

    def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait":
        """Build the SM90 MMA atom in IR and wrap it in a MmaF8Trait."""
        shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        ty = _cute_nvgpu_ir.MmaAtomSM90Type.get(
            shape_mnk.type.attribute,
            self.a_major_mode._to_ir(),
            self.b_major_mode._to_ir(),
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.a_src._to_ir(),
        )
        return MmaF8Trait(
            _cute_nvgpu_ir.make_sm90_mma(
                ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip
            )
        )
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class MmaF8Trait(MmaTrait):
    """Trait for F8 SM90 MMA atoms; behavior is inherited from MmaTrait."""
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
####################################################################################################
|
| 381 |
+
#
|
| 382 |
+
# SMEM layout atoms
|
| 383 |
+
#
|
| 384 |
+
####################################################################################################
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class SmemLayoutAtomKind(enum.Enum):
    """
    Enum class for the kinds of SMEM layout atoms for SM90.

    Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can
    be used to construct an SMEM layout using blocked product for operand A or B such that the
    resulting layout is legal for both TMA and UMMA.

    Note that there are other ways of creating legal layouts for operand A and B.
    """

    # M/N-major variants
    MN_INTER = enum.auto()
    MN_SW32 = enum.auto()
    MN_SW64 = enum.auto()
    MN_SW128 = enum.auto()
    # K-major variants
    K_INTER = enum.auto()
    K_SW32 = enum.auto()
    K_SW64 = enum.auto()
    K_SW128 = enum.auto()
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import ctypes
|
| 13 |
+
from functools import lru_cache
|
| 14 |
+
import itertools
|
| 15 |
+
import operator
|
| 16 |
+
from time import time
|
| 17 |
+
from typing import Union
|
| 18 |
+
|
| 19 |
+
# MLIR modules imports
|
| 20 |
+
from cutlass._mlir import ir
|
| 21 |
+
import cutlass._mlir.dialects.cute as _cute_ir
|
| 22 |
+
|
| 23 |
+
from cutlass.base_dsl.dsl import is_dynamic_expression
|
| 24 |
+
from cutlass.cutlass_dsl import JitArgAdapterRegistry
|
| 25 |
+
|
| 26 |
+
# Local modules imports
|
| 27 |
+
from .typing import (
|
| 28 |
+
AddressSpace,
|
| 29 |
+
Tensor,
|
| 30 |
+
Type,
|
| 31 |
+
Pointer,
|
| 32 |
+
Boolean,
|
| 33 |
+
Numeric,
|
| 34 |
+
Float4E2M1FN,
|
| 35 |
+
Int64,
|
| 36 |
+
Int32,
|
| 37 |
+
Int16,
|
| 38 |
+
Int8,
|
| 39 |
+
Uint64,
|
| 40 |
+
Uint32,
|
| 41 |
+
Uint16,
|
| 42 |
+
Uint8,
|
| 43 |
+
Float64,
|
| 44 |
+
Float32,
|
| 45 |
+
Float16,
|
| 46 |
+
BFloat16,
|
| 47 |
+
Float8E5M2,
|
| 48 |
+
)
|
| 49 |
+
from . import core
|
| 50 |
+
from .core import _Tensor as CoreTensor
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class _Pointer(Pointer):
    """Runtime representation of a pointer that can inter-operate with various data structures,
    including numpy arrays and device memory.

    :param pointer: The pointer to the data
    :type pointer: int or pointer-like object
    :param dtype: Data type of the elements pointed to
    :type dtype: Type
    :param mem_space: Memory space where the pointer resides, defaults to generic
    :type mem_space: _cute_ir.AddressSpace, optional
    :param assumed_align: Assumed alignment of input pointer in bytes, defaults to None
    :type assumed_align: int, optional
    :raises ValueError: if the pointer is not aligned to ``assumed_align`` bytes

    :ivar _pointer: The underlying pointer
    :ivar _dtype: Data type of the elements
    :ivar _addr_space: Memory space of the pointer
    :ivar _assumed_align: Alignment of the pointer in bytes
    :ivar _desc: C-type descriptor for the pointer
    :ivar _c_pointer: C-compatible pointer representation
    """

    def __init__(
        self,
        pointer,
        dtype,
        mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic,
        assumed_align=None,
    ):
        self._pointer = pointer
        self._dtype = dtype
        self._addr_space = mem_space

        if assumed_align is None:
            # Default to the element's natural alignment in bytes.
            self._assumed_align = dtype.width // 8
        else:
            self._assumed_align = assumed_align

        self._c_pointer = None
        # Fix: validate with an explicit exception instead of `assert` — asserts
        # are stripped under `python -O`, silently skipping this check.
        if int(self._pointer) % self._assumed_align != 0:
            raise ValueError(f"pointer must be {self._assumed_align} bytes aligned")

    def size_in_bytes(self) -> int:
        """Return the size of the C pointer descriptor (also rebuilds ``_desc``)."""
        self._desc = ctypes.c_void_p(int(self._pointer))
        return ctypes.sizeof(self._desc)

    def __get_mlir_types__(self):
        """Return the MLIR types this object lowers to (a single pointer type)."""
        return [self.mlir_type]

    def __c_pointers__(self):
        """Return a C-compatible address of the pointer descriptor, building it lazily."""
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        """Reconstruct from MLIR values; expects exactly one value."""
        assert len(values) == 1
        return values[0]

    def __extract_mlir_values__(self):
        # NOTE(review): returns [None] if __c_pointers__ has not been called yet —
        # confirm callers always invoke __c_pointers__ first.
        return [self._c_pointer]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        """MLIR pointer type for this dtype/address-space/alignment combination."""
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        """Element data type pointed to."""
        return self._dtype

    @property
    def memspace(self):
        """Address space this pointer resides in."""
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        """Not supported for runtime pointers."""
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        """Return True if this object satisfies the expected Python-level type."""
        if expected_py_type is Pointer:
            return True
        elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer:
            return True

        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class _Tensor(Tensor):
    """Runtime CuTe Tensor backed by a DLPack-capable tensor/array.

    Holds the DLPack data of ``tensor`` and lazily materializes a
    ``_cute_ir.DLTensorWrapper`` the first time wrapper-backed state is needed,
    keeping the critical path of JIT function calls cheap.
    """

    def __init__(
        self,
        tensor,
        assumed_align=None,
    ):
        # If tensor is already a DLPack object, use it directly
        # NOTE(review): objects exposing __dlpack_device__ but not __dlpack__ are
        # treated as pre-extracted DLPack data — confirm this matches all callers.
        if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
            self._dlpack_data = tensor
        else:
            self._dlpack_data = tensor.__dlpack__()
        self._dltensor_wrapper = None  # built lazily by lazily_load_dltensor
        self._assumed_align = assumed_align
        self._is_dynamic = False
        self._memref_desc = None  # C memref descriptor, built in __c_pointers__
        self._dtype = None  # element-type override; resolved lazily from DLPack

    @property
    def __class__(self) -> Type[Tensor]:
        # Cheat to let `type(_Tensor())` to return cute.Tensor
        return Tensor

    @staticmethod
    def lazily_load_dltensor(func):
        """Decorator to lazily load the DLTensorWrapper.

        This decorator loads the DLTensorWrapper when needed,
        avoiding overhead in the critical path of calling JIT functions.
        """
        # NOTE(review): applying this staticmethod object as a decorator inside
        # the class body relies on staticmethod being callable (Python >= 3.10).

        def wrapper(self, *args, **kwargs):
            if self._dltensor_wrapper is None:
                self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data)
            return func(self, *args, **kwargs)

        return wrapper

    @lazily_load_dltensor
    def mark_layout_dynamic(self, leading_dim: int | None = None):
        """Marks the tensor layout as dynamic based on the leading dimension.

        :param leading_dim: The leading dimension of the layout, defaults to None
        :type leading_dim: int, optional

        When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
        The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
        if the layout cannot be automatically deduced.

        When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
        stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
        with the existing layout by checking that the corresponding stride of that dimension is 1.

        Limitation: only support flat layout for now. Will work on supporting nested layout in the future.

        :return: The tensor with dynamic layout
        :rtype: _Tensor
        """
        self._dltensor_wrapper.mark_layout_dynamic(leading_dim)
        return self

    @lazily_load_dltensor
    def mark_compact_shape_dynamic(
        self,
        mode: int,
        stride_order: tuple[int, ...] | None = None,
        divisibility: int = 1,
    ):
        """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides.

        :param mode: The mode of the compact shape, defaults to 0
        :type mode: int
        :param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None.
            Indicates the order of the modes (dimensions) if the current layout were converted to row-major order.
            It starts from the outermost to the innermost dimension.
        :type stride_order: tuple[int, ...], optional
        :param divisibility: The divisibility constraint for the compact shape, defaults to 1
        :type divisibility: int, optional
        :return: The tensor with dynamic compact shape
        :rtype: _Tensor

        If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout.
        Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout).
        An error is raised if automatic deduction fails.

        If ``stride_order`` is explicitly specified, it does the consistency check with the layout.

        For example:
        - Layout: (4,2):(1,4) has stride_order: (1,0) indicates the innermost dimension is 0(`4:1`), the outermost dimension is 1(`2:4`)
        - Layout: (5,3,2,4):(3,1,15,30) has stride_order: (3,2,0,1) indicates the innermost dimension is 1(`3:1`), the outermost dimension is 3(`4:30`).

        Using `torch.Tensor.dim_order()` to get the stride order of the torch tensor.
        .. code-block:: python
            a = torch.empty(3, 4)
            t = cute.runtime.from_dlpack(a)
            t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order())
        """
        self._dltensor_wrapper.mark_compact_shape_dynamic(
            mode, stride_order, divisibility
        )
        return self

    @property
    @lazily_load_dltensor
    def element_type(self) -> Type[Numeric]:
        """Element type of the tensor; resolved from the DLPack dtype on first access."""
        if self._dtype is None:
            self._dtype = self._dltensor_wrapper.dtype
        return self._dtype

    @element_type.setter
    def element_type(self, new_type):
        """Set the element type of the tensor.

        :warning: This API is added for narrow precision before we have a clean `recast_tensor` story.

        :note: It is only used for the case that frameworks don't natively support narrow precision but we get tensor
            from frameworks with storage type like uint8.

        **Example**:

        .. code-block:: python

            # Create a tensor from a numpy array
            import numpy as np
            from cutlass.cute import from_dlpack

            # Create a tensor with Float32 elements
            a = np.zeros(shape, dtype=np.uint8)
            tensor = from_dlpack(a)

            # Change the element type to Float4E2M1FN even storage type is uint8
            tensor.element_type = cutlass.Float4E2M1FN

            src = from_dlpack(... data tensor ...)
            # convert and initialize narrow precision tensor
            cute.testing.convert(src, tensor)
        """
        self._dtype = new_type

    @property
    @lazily_load_dltensor
    def memspace(self):
        """Address space the tensor data resides in, as reported by the wrapper."""
        return self._dltensor_wrapper.address_space

    @property
    @lazily_load_dltensor
    def size_in_bytes(self) -> int:
        """Total size of the tensor data in bytes, as reported by the wrapper."""
        return self._dltensor_wrapper.size_in_bytes()

    @property
    @lazily_load_dltensor
    def mlir_type(self) -> ir.Type:
        """MLIR type for this tensor, derived from element type and assumed alignment."""
        return self._dltensor_wrapper.get_type(
            self.element_type.mlir_type, self._assumed_align
        )

    @lazily_load_dltensor
    def __str__(self) -> str:
        return f"Tensor<0x{self._dltensor_wrapper.str}>"

    def __repr__(self):
        return self.__str__()

    def __setitem__(self, crd, value):
        # Runtime tensors are opaque handles; element access is not supported.
        raise TypeError(f"runtime._Tensor is not indexable")

    def __getitem__(self, crd):
        # Runtime tensors are opaque handles; element access is not supported.
        raise TypeError(f"runtime._Tensor is not indexable")

    @property
    @lazily_load_dltensor
    def iterator(self):
        """A _Pointer view of the tensor's data pointer (dtype/memspace/alignment preserved)."""
        return _Pointer(
            self._dltensor_wrapper.data_ptr,
            self.element_type,
            self.memspace,
            self._assumed_align,
        )

    @property
    def layout(self):
        raise NotImplementedError(
            f"layout property is not supported in runtime, support in future"
        )

    @property
    @lazily_load_dltensor
    def shape(self):
        """Shape tuple as reported by the DLPack wrapper."""
        return self._dltensor_wrapper.shape

    @property
    @lazily_load_dltensor
    def stride(self):
        """Strides as reported by the wrapper; when absent, derive compact row-major strides."""
        strides = self._dltensor_wrapper.stride
        if strides is None:
            # DLPack reports no strides: compute row-major (C-contiguous) strides
            # as the reversed running product of the reversed shape.
            strides = itertools.accumulate(
                reversed(self.shape), func=operator.mul, initial=1
            )
            strides = tuple(reversed(list(strides)[:-1]))

        return strides

    @property
    @lru_cache(maxsize=128, typed=True)
    def leading_dim(self):
        """Get the leading dimension of this Tensor.

        :return: The leading dimension index or indices
        :rtype: int or tuple or None

        The return value depends on the tensor's stride pattern:

        * If a single leading dimension is found, returns an integer index
        * If nested leading dimensions are found, returns a tuple of indices
        * If no leading dimension is found, returns None
        """
        # NOTE(review): lru_cache on an instance method keys on `self` and keeps
        # instances alive for the cache's lifetime (ruff B019); consider
        # functools.cached_property — confirm intended lifetime before changing.
        return core.leading_dim(self.shape, self.stride)

    def fill(self, value: Numeric):
        # Runtime tensors cannot be written from Python.
        raise TypeError(f"fill function is not supported in runtime")

    @property
    @lazily_load_dltensor
    def data_ptr(self):
        """Raw data pointer as reported by the DLPack wrapper."""
        return self._dltensor_wrapper.data_ptr

    @lazily_load_dltensor
    def __c_pointers__(self):
        """Build the C memref descriptor and return its raw address for JIT calls."""
        self._memref_desc = self._dltensor_wrapper.build_memref_desc(
            self._assumed_align
        )
        return [_cute_ir.pycapsule_get_pointer(self._memref_desc)]

    def __get_mlir_types__(self):
        """Return the MLIR types this object lowers to (a single tensor type)."""
        return [self.mlir_type]

    def __new_from_mlir_values__(self, values):
        """Rebuild a core tensor from MLIR values, carrying over the element-type override."""
        assert len(values) == 1
        assert isinstance(values[0], CoreTensor)
        return CoreTensor(values[0].value, self._dtype)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def from_dlpack(
    tensor_dlpack,
    assumed_align=None,
) -> Tensor:
    """Convert from tensor object supporting __dlpack__() to a CuTe Tensor.

    :param tensor_dlpack: Tensor object that supports the DLPack protocol
    :type tensor_dlpack: object
    :param assumed_align: Assumed alignment of the tensor (bytes), defaults to None,
        if None, will use the element size bytes as the assumed alignment.
    :type assumed_align: int, optional
    :return: A CuTe Tensor object
    :rtype: Tensor

    Examples:
    .. code-block:: python

        import torch
        from cutlass.cute.runtime import from_dlpack
        x = torch.randn(100, 100)
        y = from_dlpack(x)
        y.shape
        # (100, 100)
        type(y)
        # <class 'cutlass.cute.Tensor'>
    """
    # Thin factory: all DLPack handling lives in the _Tensor constructor.
    return _Tensor(tensor_dlpack, assumed_align=assumed_align)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Create a pointer from a memory address

    :param dtype: Data type of the pointer elements
    :type dtype: Type[Numeric]
    :param value: Memory address as integer or ctypes pointer
    :type value: Union[int, ctypes._Pointer]
    :param mem_space: Memory address space, defaults to AddressSpace.generic
    :type mem_space: AddressSpace, optional
    :param assumed_align: Assumed alignment in bytes, defaults to None
    :type assumed_align: int, optional
    :raises TypeError: if ``value`` is neither an int nor a ctypes pointer
    :return: A pointer object
    :rtype: Pointer

    .. code-block:: python

        import numpy as np
        import ctypes

        from cutlass import Float32
        from cutlass.cute.runtime import make_ptr

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)

        # Get pointer address as integer
        ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

        # Create pointer from address
        y = make_ptr(Float32, ptr_address)

        # Check properties
        print(y.element_type)
        print(type(y))  # <class 'cutlass.cute.Pointer'>
    """
    # check if value is int or ctypes.POINTER
    if isinstance(value, int):
        address_value = value
    elif isinstance(value, ctypes._Pointer):
        # get address value by casting the typed pointer to a void pointer
        address_value = ctypes.cast(value, ctypes.c_void_p).value
        assert address_value is not None, "Pointer address is None"
    else:
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )

    return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class TensorAdapter:
    """
    Convert a DLPack protocol supported tensor/array to a cute tensor.

    The wrapped tensor is created once at construction time with a dynamic
    layout; all JIT-argument hooks simply delegate to it.
    """

    def __init__(self, arg):
        cute_tensor = from_dlpack(arg)
        self._arg = cute_tensor.mark_layout_dynamic()

    def __c_pointers__(self):
        return self._arg.__c_pointers__()

    def __get_mlir_types__(self):
        return self._arg.__get_mlir_types__()

    def __new_from_mlir_values__(self, values):
        return self._arg.__new_from_mlir_values__(values)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# -------------------------------------------------------------------------
|
| 495 |
+
# Try to register_jit_arg_adapter for TensorAdapter
|
| 496 |
+
# -------------------------------------------------------------------------
|
| 497 |
+
|
| 498 |
+
try:  # Register for numpy.ndarray — numpy may be absent; registration is best-effort
    from numpy import ndarray as _numpy_ndarray
except ImportError:
    pass  # silent attempt, suppress error
else:
    JitArgAdapterRegistry.register_jit_arg_adapter(_numpy_ndarray)(TensorAdapter)

try:  # Register for torch.Tensor — torch may be absent; registration is best-effort
    from torch import Tensor as _torch_Tensor
except ImportError:
    pass  # silent attempt, suppress error
else:
    JitArgAdapterRegistry.register_jit_arg_adapter(_torch_Tensor)(TensorAdapter)
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
import functools
|
| 13 |
+
import inspect
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
from enum import Enum
|
| 17 |
+
from inspect import isclass
|
| 18 |
+
from itertools import product
|
| 19 |
+
from time import time
|
| 20 |
+
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
| 21 |
+
|
| 22 |
+
import cuda.bindings.driver as cuda_driver
|
| 23 |
+
import cuda.bindings.runtime as cuda_runtime
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
import cutlass._mlir.ir as ir
|
| 27 |
+
import cutlass.base_dsl.jit_executor
|
| 28 |
+
import cutlass.cute as cute
|
| 29 |
+
from cutlass._mlir.dialects import builtin, cf, nvvm, vector
|
| 30 |
+
from cutlass.cute import core, nvgpu
|
| 31 |
+
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dsl_user_op
def assert_(cond, msg=None, *, loc=None, ip=None):
    """Emit a device-side assertion through the ``cf`` dialect.

    :param cond: Condition; converted to a DSL Boolean IR value.
    :param msg: Optional message reported when the assertion fires
        (an empty string is emitted when omitted).
    """
    cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
    """Recast a 4-bit tensor to Int8, adjusting the TV layout, when needed.

    Sub-byte elements cannot be copied directly, so a 4-bit tensor is
    reinterpreted as Int8 (halving the element count) and the thread-value
    layout is recast 8 <- 4 to keep the logical mapping consistent.

    :return: ``(src, tv_layout)`` — recast when ``src`` is 4-bit, otherwise
        the inputs unchanged.
    """
    if src.element_type.width == 4:
        tv_layout = core.recast_layout(8, 4, tv_layout)
        src = core.recast_tensor(src, dtype=t.Int8)
    return src, tv_layout
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]):
    """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit.

    :param input: The input tensor to recast.
    :param dtype: The target numeric type to potentially recast to.
    :raises TypeError: If dtype is not a subclass of Numeric.
    :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged.
    """
    if not isclass(dtype) or not issubclass(dtype, core.Numeric):
        raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}")

    if dtype.width == 4:
        # 4 <- 8 recast: each 8-bit element splits into two 4-bit ones, so the
        # logical shape doubles along the recast mode.
        recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape
        i4_vec = vector.bitcast(
            T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast()
        )
        # Reinterpret the raw i4 vector as the target 4-bit numeric dtype.
        res_vect = builtin.unrealized_conversion_cast(
            [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec]
        )
        return core.TensorSSA(res_vect, recast_shape, dtype)
    return input
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]):
    """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit.

    :param input: The input tensor to recast.
    :param src_dtype: The source numeric type to potentially recast from.
    :raises TypeError: If src_dtype is not a subclass of Numeric.
    :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged.
    """
    if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric):
        raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}")

    if src_dtype.width == 4:
        # 8 <- 4 recast: two 4-bit elements pack into one Int8, halving the
        # logical shape along the recast mode.
        recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape
        # Reinterpret the 4-bit values as a raw i4 vector first...
        i4_vec = builtin.unrealized_conversion_cast(
            [T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()]
        )
        # ...then bitcast pairs of i4 into i8 storage.
        res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec)
        return core.TensorSSA(res_vect, recast_shape, core.Int8)
    return input
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@CuTeDSL.kernel
def _convert_kernel(
    gSrc: core.Tensor,
    gDst: core.Tensor,
    cSrc: core.Tensor,
    src_tv_layout: core.Layout,
    dst_tv_layout: core.Layout,
    src_shape: core.Shape,
    src_ty,
    dst_ty,
):
    """Device kernel: copy one CTA tile of gSrc to gDst, converting dtype.

    ``cSrc`` is an identity (coordinate) tensor partitioned like ``gSrc``;
    its first per-thread coordinate is compared against ``src_shape`` to
    predicate out-of-bounds threads.
    """
    tidx = nvvm.read_ptx_sreg_tid_x(T.i32())
    bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32())

    # Select this CTA's tile along the rest-mode (mode 1 of the zipped divide).
    cta_coord = (None, bidx)
    # logical idx -> address
    ctaSrc = gSrc[cta_coord]  # (...,TileV,...)
    ctaDst = gDst[cta_coord]  # (...,TileV,...)
    ctaCSrc = cSrc[cta_coord]  # (...,TileV,...)

    # compose with CTA TV layout
    # tid, vid -> address
    tidfrgSrc = core.composition(ctaSrc, src_tv_layout)  # (T,V)
    tidfrgDst = core.composition(ctaDst, dst_tv_layout)  # (T,V)
    tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout)  # (T,V)

    # slice for threads
    thr_coord = (tidx, None)
    thrSrc = tidfrgSrc[thr_coord]  # (V)
    thrDst = tidfrgDst[thr_coord]  # (V)
    thrCSrc = tidfrgCSrc[thr_coord]  # (V)

    # predicate: only threads whose first element is in-bounds participate
    if core.elem_less(thrCSrc[0], src_shape):
        # allocate fragments for gmem->rmem
        frgSrc = core.make_fragment(
            core.get(src_tv_layout, mode=[1]), gSrc.element_type
        )  # (V)
        frgDst = core.make_fragment(
            core.get(dst_tv_layout, mode=[1]), gDst.element_type
        )  # (V)

        # Move data to reg address space
        copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type)
        core.copy(copy_atom_load, thrSrc, frgSrc)

        # Convert in registers; the f4 recasts are no-ops unless the
        # corresponding dtype is 4-bit (see _maybe_recast_to/from_f4).
        vec_src = frgSrc.load()
        vec_src = _maybe_recast_to_f4(vec_src, src_ty)
        vec_dst = vec_src.to(dst_ty)
        vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty)
        frgDst.store(vec_dst)

        # Copy the results back to global memory
        copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type)
        core.copy(copy_atom_stg, frgDst, thrDst)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@CuTeDSL.jit(preprocess=False)
def _convert(
    src: core.Tensor,
    dst: core.Tensor,
    leading_mode: Constexpr,
    elem_per_copy: Constexpr,
):
    """Host-side JIT entry: tile src/dst along ``leading_mode`` and launch
    ``_convert_kernel`` with one CTA per tile (128 threads, ``elem_per_copy``
    elements per thread)."""

    # Step 1. figure proper tv_layout
    src_ty = src.element_type
    dst_ty = dst.element_type

    tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1))

    # Step 2. maybe recast from f4 tensor (4-bit tensors become Int8 views)
    src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout)
    dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout)
    src_shape = src.shape
    # predicate tensor: identity tensor carrying logical coordinates
    idA = core.make_identity_tensor(src.shape)

    # Step 3. select a proper tiling pattern as (...,TileV, ...)
    # Tile only the leading (stride-1) mode; all other modes keep extent 1.
    src_cta_tiler = [
        1,
    ] * core.rank(src.layout)
    src_cta_tiler[leading_mode] = core.size(src_tv_layout)  # (...,TileV,...)
    dst_cta_tiler = [
        1,
    ] * core.rank(dst.layout)
    dst_cta_tiler[leading_mode] = core.size(dst_tv_layout)  # (...,TileV,...)

    # Step 4. partition input and output tensor by cta tiler.
    gS = core.zipped_divide(
        src, tuple(src_cta_tiler)
    )  # ((...,TileV,...),(...,RestV,...))
    cS = core.zipped_divide(
        idA, tuple(src_cta_tiler)
    )  # ((...,TileV,...),(...,RestV,...))
    gD = core.zipped_divide(
        dst, tuple(dst_cta_tiler)
    )  # ((...,TileV,...),(...,RestV,...))

    # One CTA per rest-mode tile; one thread per mode-0 entry of the TV layout.
    _convert_kernel(
        gS,
        gD,
        cS,
        src_tv_layout,
        dst_tv_layout,
        src_shape,
        src_ty,
        dst_ty,
    ).launch(
        grid=[core.size(gS, mode=[1]), 1, 1],
        block=[core.size(src_tv_layout, mode=[0]), 1, 1],
    )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Converts from src tensor to dst tensor, their logical shape are required to be the same.
# And when src or dst dtype is narrow precision(Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the shape of
# their leading dimension should be 4(fp8)/8(fp4) element align. (nvgpu.cvt_fptrunc/cvt_fpext
# needs 32-bits aligned input/output)
def convert(src: core.Tensor, dst: core.Tensor):
    """Elementwise dtype conversion of ``src`` into ``dst`` on the GPU.

    Both tensors must have the same rank and exactly one mode with extent > 1
    and stride 1 (the "leading" mode), which must be divisible by the number
    of elements converted per thread.
    """
    assert len(src.shape) == len(
        dst.shape
    ), "Shape of src and dst tensors should be the same rank."
    # Find the unique contiguous mode of the source layout.
    candidates = []
    for mode, (extent, step) in enumerate(zip(src.shape, src.stride)):
        if extent > 1 and step == 1:
            candidates.append(mode)
    if len(candidates) != 1:
        raise ValueError(f"Leading mode should be unique, but got {candidates}")
    leading_mode = candidates[0]

    # Pick a vector width so each thread converts 32 bits at a time
    # (required by nvgpu.cvt_fptrunc/cvt_fpext for sub-byte/byte types).
    widths = (src.element_type.width, dst.element_type.width)
    if 4 in widths:
        elem_per_copy = 8
    elif 8 in widths:
        elem_per_copy = 4
    else:
        elem_per_copy = 2

    assert (
        src.shape[leading_mode] % elem_per_copy == 0
        and dst.shape[leading_mode] % elem_per_copy == 0
    )
    _convert(src, dst, leading_mode, elem_per_copy)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
#########################################
|
| 241 |
+
# Testing utilities
|
| 242 |
+
#########################################
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def sample_pytest(rand_cfg=None):
    """
    Decorator to randomly sample pytest parametrized tests.
    rand_cfg: Tuple[int, float] - (random_seed, sample_ratio)
    Sampling is disabled when:
        - rand_cfg is None
        - A specific test is selected (via -k or direct test path)
        - Not running under pytest
    """
    import functools
    import os
    import random
    import sys

    import pytest

    # Bug fix: the original unpacked rand_cfg unconditionally, so the
    # documented default rand_cfg=None raised TypeError. Treat None as
    # "no sampling" and return an identity decorator.
    if rand_cfg is None:
        return lambda func: func

    seed, sample_ratio = rand_cfg
    # Seed once at decoration time so the sampled subset is reproducible.
    random.seed(seed)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if "PYTEST_CURRENT_TEST" in os.environ:
                # Check if test was explicitly selected like ::test_name[param1-param2-...]
                if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv):
                    # Test was explicitly selected, don't skip
                    return func(*args, **kwargs)

                if random.uniform(0.0, 1.0) > sample_ratio:
                    pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
#########################################
|
| 282 |
+
# Benchmarking utilities
|
| 283 |
+
#########################################
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class JitArguments:
    """
    A type to hold both args and kwargs for passing to a kernel while benchmarking.

    Positional and keyword arguments are captured verbatim at construction
    time and replayed later via ``callable(*ja.args, **ja.kwargs)``.
    """

    def __init__(self, *args, **kwargs):
        # Capture keyword arguments first, then positionals — order of these
        # independent assignments is irrelevant; both are stored untouched.
        self.kwargs = kwargs
        self.args = args
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _cuda_success(
    err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str
):
    """Raise ``RuntimeError`` if *err* reports a failed CUDA API call.

    Accepts the raw tuple returned by the cuda-python bindings (status
    first), a runtime ``cudaError_t``, or a driver ``CUresult``; any other
    type raises ``TypeError``.
    """
    # cuda-python calls return (status, results...); recurse on the status.
    if isinstance(err, tuple):
        _cuda_success(err[0], message)
        return
    if isinstance(err, cuda_runtime.cudaError_t):
        # NOTE: the error string is fetched before the success check,
        # matching the original behavior.
        error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8")
        if err != cuda_runtime.cudaError_t.cudaSuccess:
            raise RuntimeError(f"{message} : {error_message}")
        return
    if isinstance(err, cuda_driver.CUresult):
        if err != cuda_driver.CUresult.CUDA_SUCCESS:
            error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8")
            raise RuntimeError(f"{message} : {error_message}")
        return
    raise TypeError(
        f"{err} is an unexpected type : it should be a cudaError_t or CUresult"
    )
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _does_kernel_use_stream(
    kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs
):
    """
    This function checks if the kernel uses the provided non-default stream.
    It does this by capturing the stream and then checking if any kernels were launched.

    Side effect: ``kernel`` is invoked once under stream capture.

    :param kernel: The kernel to check
    :type kernel: Callable
    :param stream: The stream to check
    :type stream: cuda_driver.CUstream
    :return: True if the kernel uses the stream, False otherwise
    :rtype: bool
    """

    # Capture only works on non-default streams.
    assert int(stream) != int(
        cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
    ), "Stream must be a non-default stream"

    err = cuda_runtime.cudaStreamBeginCapture(
        stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
    )
    _cuda_success(err, "Error on stream capture")

    # Run the kernel while the stream is being captured; work launched in
    # this stream becomes graph nodes instead of executing.
    kernel(*args, **kwargs)

    err, graph = cuda_runtime.cudaStreamEndCapture(stream)
    _cuda_success(err, "Error on stream capture")

    # Get number of nodes in warmup graph to check it matches what is expected
    err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph)
    _cuda_success(err, "Error on querying graph")
    return num_nodes > 0
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def benchmark(
    callable: Callable,
    *,
    warmup_iterations: int = 10,
    iterations: int = 100,
    stream: Optional[cuda_driver.CUstream] = None,
    kernel_arguments: Optional[JitArguments] = None,
    workspace_generator: Optional[Callable[[], JitArguments]] = None,
    workspace_count: int = 1,
    use_cuda_graphs: bool = False,
) -> float:
    """Benchmarks a callable function with the specified parameters.

    For example,
    .. code-block:: python

        from cutlass.cute.testing import benchmark

        @cute.jit
        def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream):
            # contents of the function
            pass

        time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream)
                            warmup_iterations=10, iterations=100
                            stream=stream)

    To prevent skewing results by repeatedly accessing the L2 cache, use the workspace_count and workspace_generator
    parameters to cycle through a number of different workspaces.

    .. code-block:: python

        from cutlass.cute.testing import benchmark

        @cute.jit
        def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
            # contents of the function
            pass

        def workspace_generator():
            # create a, b, and c
            return JitArguments(a, b, c)

        time_us = benchmark(user_function,
                            workspace_generator=workspace_generator,
                            workspace_count=10,
                            warmup_iterations=10000,
                            iterations=1000)

    Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter.

    To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
    When using CUDA graphs, the kernel must be launched in a non-default stream.

    :param callable: The function to benchmark
    :type callable: Callable
    :param warmup_iterations: Number of warmup iterations, defaults to 10
    :type warmup_iterations: int, optional
    :param iterations: Number of benchmark iterations, defaults to 100
    :type iterations: int, optional
    :param stream: Stream kernel is launched in, defaults to CUDA stream default
    :type stream: CUstream, None
    :param kernel_arguments: Kernel arguments to launch callable with, defaults to None
    :type kernel_arguments: JitArguments, None
    :param workspace_generator: Function that returns kernel arguments, defaults to None
    :type workspace_generator: Callable
    :param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold
    :type workspace_count: int, optional
    :param use_cuda_graphs: Whether to use cuda graphs, defaults to False
    :type use_cuda_graphs: bool, optional

    :return: The benchmark time in microseconds
    :rtype: float
    """

    if stream is None:
        stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT)

    if workspace_count < 1:
        raise ValueError("workspace_count must be at least 1")

    if workspace_generator is None:
        # If no workspace generator is provided, we need a single workspace
        if workspace_count != 1:
            raise ValueError("Need a single workspace if not providing a generator")

        # If no workspace generator is provided, we need a kernel_argument
        if kernel_arguments is None:
            raise ValueError(
                "Please pass a kernel argument if not providing a generator"
            )
        workspace_generator = lambda: kernel_arguments

    workspaces = [workspace_generator() for _ in range(workspace_count)]

    for workspace in workspaces:
        if not isinstance(workspace, JitArguments):
            raise TypeError(
                "workspace_generator and/or kernel_arguments should use JitArguments type"
            )

    def _loop_and_call_kernel(iterations: int, workspace_index: int = 0):
        # Invoke the callable `iterations` times, cycling through workspaces
        # round-robin; returns the next workspace index so a later loop can
        # continue the rotation where this one stopped.
        for _ in range(iterations):
            current_workspace = workspaces[workspace_index]
            callable(*current_workspace.args, **current_workspace.kwargs)
            workspace_index = (workspace_index + 1) % workspace_count
        return workspace_index

    # Create CUDA events for timing
    err, start_event = cuda_driver.cuEventCreate(
        cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
    )
    _cuda_success(err, "Error on creating event")
    err, end_event = cuda_driver.cuEventCreate(
        cuda_driver.CUevent_flags.CU_EVENT_DEFAULT
    )
    _cuda_success(err, "Error on creating event")

    elapsed_time = float("nan")

    if use_cuda_graphs:
        # Check if the callable is a JitExecutor
        if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor):
            raise TypeError("Function must be precompiled to be used with CUDA Graphs")

        # Check if the stream is a non-default stream
        if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
            raise ValueError(
                "Measuring with CUDA Graphs requires executing in a non-default stream"
            )

        workspace_index = 0

        # Capture warmup graph
        err = cuda_runtime.cudaStreamBeginCapture(
            stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
        )
        _cuda_success(err, "Error on stream capture")

        workspace_index = _loop_and_call_kernel(warmup_iterations)
        err, gwarm = cuda_runtime.cudaStreamEndCapture(stream)
        _cuda_success(err, "Error on stream capture")

        # Get number of nodes in warmup graph to check it matches what is expected
        err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm)
        _cuda_success(err, "Error on querying graph")
        # Check is >= since we may launch multiple kernels in one host function
        if num_nodes < warmup_iterations:
            raise ValueError(
                "CUDA stream passed to benchmark does not match the stream the kernel was launched in"
            )

        # Capture profiling graph
        err = cuda_runtime.cudaStreamBeginCapture(
            stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal
        )
        _cuda_success(err, "Error on stream capture")
        _loop_and_call_kernel(iterations, workspace_index)
        err, gprofile = cuda_runtime.cudaStreamEndCapture(stream)
        _cuda_success(err, "Error on stream capture")

        # Instantiate graphs
        err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0)
        _cuda_success(err, "Error on graph instantiation")
        err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0)
        _cuda_success(err, "Error on graph instantiation")

        # Launch warmup graph
        err = cuda_runtime.cudaGraphLaunch(gwarm, stream)
        _cuda_success(err, "Error on graph launch")

        # Record start time
        err = cuda_driver.cuEventRecord(start_event, stream)
        _cuda_success(err, "Error on recording event")

        # Launch profiling graph
        err = cuda_runtime.cudaGraphLaunch(gprofile, stream)
        _cuda_success(err, "Error on graph launch")

        # Record end time
        err = cuda_driver.cuEventRecord(end_event, stream)
        _cuda_success(err, "Error on recording event")
        err = cuda_driver.cuEventSynchronize(end_event)
        _cuda_success(err, "Error on synchronizing event")

        # Get elapsed time
        err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
        _cuda_success(err, "Error on querying event")

        # Destroy graphs
        err = cuda_runtime.cudaGraphExecDestroy(gwarm)
        _cuda_success(err, "Error on destroying graph")
        err = cuda_runtime.cudaGraphExecDestroy(gprofile)
        _cuda_success(err, "Error on destroying graph")

    else:
        # Validate the stream by capturing one call (side effect: runs the
        # kernel once) when a non-default stream was supplied.
        if int(stream) != int(
            cuda_driver.CUstream_flags.CU_STREAM_DEFAULT
        ) and not _does_kernel_use_stream(
            callable, stream, *workspaces[0].args, **workspaces[0].kwargs
        ):
            raise ValueError(
                "CUDA stream passed to benchmark does not match the stream the kernel was launched in"
            )

        # Not using graphs
        # Warmup
        workspace_index = _loop_and_call_kernel(warmup_iterations)
        # Record start event
        err = cuda_driver.cuEventRecord(start_event, stream)
        _cuda_success(err, "Error on recording event")
        _loop_and_call_kernel(iterations, workspace_index)
        # Record end event
        err = cuda_driver.cuEventRecord(end_event, stream)
        _cuda_success(err, "Error on recording event")
        # Synchronize end event
        err = cuda_driver.cuEventSynchronize(end_event)
        _cuda_success(err, "Error on synchronizing event")
        err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event)
        _cuda_success(err, "Error on querying event")

    # Destroy events
    err = cuda_driver.cuEventDestroy(start_event)
    _cuda_success(err, "Error on destroying event")
    err = cuda_driver.cuEventDestroy(end_event)
    _cuda_success(err, "Error on destroying event")

    # cuEventElapsedTime reports milliseconds; convert to microseconds.
    return elapsed_time / iterations * 1e3
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def get_workspace_count(
    one_workspace_bytes: int, warmup_iterations: int, iterations: int
) -> int:
    """Calculate the number of workspaces needed to fill L2 cache.

    :param one_workspace_bytes: Size of one workspace in bytes
    :type one_workspace_bytes: int
    :param warmup_iterations: Number of warmup iterations
    :type warmup_iterations: int
    :param iterations: Number of iterations
    :type iterations: int
    :return: Number of workspaces needed
    :rtype: int
    """
    num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes()
    # Enough workspaces to cover L2 (ceiling division), but never more than
    # one per benchmark invocation, and at least one.
    total_invocations = warmup_iterations + iterations
    covers_l2 = (num_l2_cache_bytes + one_workspace_bytes - 1) // one_workspace_bytes
    count = min(total_invocations, covers_l2)
    return max(1, count)
|
| 610 |
+
|
build/torch210-cxx11-cu126-aarch64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
| 3 |
+
#
|
| 4 |
+
# Use of this software is governed by the terms and conditions of the
|
| 5 |
+
# NVIDIA End User License Agreement (EULA), available at:
|
| 6 |
+
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
| 7 |
+
#
|
| 8 |
+
# Any use, reproduction, disclosure, or distribution of this software
|
| 9 |
+
# and related documentation outside the scope permitted by the EULA
|
| 10 |
+
# is strictly prohibited.
|
| 11 |
+
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
from typing import ForwardRef, Tuple, Union, Any, Type, List
|
| 14 |
+
|
| 15 |
+
from cutlass.base_dsl.typing import *
|
| 16 |
+
|
| 17 |
+
from cutlass._mlir import ir
|
| 18 |
+
import cutlass._mlir.extras.types as T
|
| 19 |
+
from cutlass._mlir.dialects.cute import AddressSpace
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Scalar integer: a plain Python int or a DSL Integer value.
Int = Union[int, Integer]


# Forward reference: ScaledBasis is declared elsewhere in the package.
ScaledBasis = ForwardRef("ScaledBasis")


# CuTe hierarchical tuple types: each is a scalar or an arbitrarily
# nested tuple of the same kind.
IntTuple = Union[Int, Tuple["IntTuple", ...]]
Shape = Union[Int, Tuple["Shape", ...]]
# Strides may additionally contain ScaledBasis terms (presumably for
# basis-scaled strides -- confirm against ScaledBasis's definition).
Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]]
# Coordinates also admit None (commonly a wildcard/keep marker in CuTe --
# TODO confirm against the runtime coordinate handling).
Coord = Union[Int, None, Tuple["Coord", ...]]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Layout(ir.Value):
    """An MLIR value wrapping a CuTe layout.

    This is a typing-level shim: method bodies here are stubs and the
    concrete behavior is supplied by the runtime implementation.
    """

    def __init__(self, op_result):
        # Hand the wrapped MLIR op result straight to the ir.Value base.
        super().__init__(op_result)

    def __str__(self):
        ...

    def get_hier_coord(self, idx) -> Coord:
        """Return the (hierarchical) ND logical coordinate corresponding to the linear index"""
        ...

    @property
    def stride(self, *, loc=None, ip=None) -> Stride:
        # NOTE(review): extra keyword args cannot be passed through normal
        # property access; kept only to mirror the runtime signature.
        ...

    @property
    def shape(self, *, loc=None, ip=None) -> Shape:
        # NOTE(review): same signature caveat as `stride` above.
        ...
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# A tile may be a scalar extent, None (underspecified mode), a Layout, or a
# nested tuple of tiles.
Tile = Union[Int, None, Layout, Tuple["Tile", ...]]

# XTuple is super set of above types
XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]

# Anything acceptable as a tiling specification.
Tiler = Union[Shape, Layout, Tile]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Pointer(ABC):
    """
    Abstract base class for CuTe jit function and runtime _Pointer
    """

    @property
    def value_type(self) -> Type[Numeric]:
        # Plain alias: delegates to `dtype`.
        return self.dtype

    @property
    def dtype(self) -> Type[Numeric]:
        """Stub: the Numeric type of the pointee; provided by subclasses."""
        ...

    def align(self, min_align: int) -> "Pointer":
        """Stub: presumably returns a pointer with alignment asserted to at
        least ``min_align`` -- confirm against the runtime _Pointer."""
        ...

    def __get_mlir_types__(self) -> List[ir.Type]:
        """Stub: DSL marshaling hook -- the MLIR types backing this object
        (inferred from the signature; confirm against the DSL framework)."""
        ...

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Stub: DSL marshaling hook -- the MLIR values backing this object
        (inferred from the signature; confirm against the DSL framework)."""
        ...

    def __new_from_mlir_values__(self, values) -> "Pointer":
        """Stub: DSL marshaling hook -- rebuild a Pointer from MLIR values
        (inferred from the signature; confirm against the DSL framework)."""
        ...
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Tensor(ABC):
    """
    Abstract base class for CuTe jit function and runtime _Tensor

    A CuTe Tensor is an iterator paired with a layout.

    :Examples:

    Create tensor from torch.tensor with Host Runtime:

    .. code-block:: python

        >>> import torch
        >>> from cutlass.cute.runtime import from_dlpack
        >>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32))
        >>> mA.shape
        (3,)
        >>> mA.stride
        (1,)
        >>> mA.layout
        (3,):(1,)

    Define JIT function:

    .. code-block:: python

        @cute.jit
        def add(a: Tensor, b: Tensor, res: Tensor): ...

    Call JIT function from python:

    .. code-block:: python

        >>> import torch
        >>> a = torch.tensor([1, 3, 5], dtype=torch.int32)
        >>> b = torch.tensor([2, 4, 6], dtype=torch.int32)
        >>> c = torch.zeros([3], dtype=torch.int32)
        >>> mA = from_dlpack(a)
        >>> mB = from_dlpack(b)
        >>> mC = from_dlpack(c)
        >>> add(mA, mB, mC)
        >>> c
        tensor([3, 7, 11], dtype=torch.int32)
    """

    def __str__(self): ...

    @abstractmethod
    def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]:
        """Stub: element/sub-tensor access at `idx`; semantics are defined
        by the concrete runtime implementation."""
        ...

    @abstractmethod
    def __setitem__(self, idx, value):
        """Stub: element assignment at `idx`; defined by subclasses."""
        ...

    @property
    @abstractmethod
    def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]:
        """Stub: the element type of the tensor."""
        ...

    # Paired setter for `element_type`; intentionally not abstract, so
    # subclasses may leave assignment unsupported.
    @element_type.setter
    def element_type(self, new_type): ...

    @property
    @abstractmethod
    def memspace(self) -> AddressSpace:
        """Stub: the address space this tensor's data lives in."""
        ...

    @property
    @abstractmethod
    def iterator(self):
        """Stub: the tensor's iterator (engine) component."""
        ...

    @property
    def layout(self) -> Union[Layout, "ComposedLayout"]: ...

    @property
    def shape(self) -> Shape: ...

    def load(self, *, loc=None, ip=None) -> "TensorSSA":
        """Stub: presumably loads the tensor into an SSA value form --
        confirm against the runtime _Tensor."""
        ...

    def store(self, data: "TensorSSA", *, loc=None, ip=None):
        """Stub: presumably stores `data` back into this tensor --
        confirm against the runtime _Tensor."""
        ...

    def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ...

    def mark_compact_shape_dynamic(
        self,
        mode: int,
        stride_order: tuple[int, ...] | None = None,
        divisibility: int = 1,
    ) -> "Tensor": ...

    @abstractmethod
    def fill(self, value: Numeric) -> None:
        """Stub: fill the tensor with `value`; defined by subclasses."""
        ...
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# Public API of this module: the typing aliases defined here plus the
# Numeric hierarchy re-exported from cutlass.base_dsl.typing via the
# star-import above.
__all__ = [
    "Coord",
    "Numeric",
    "Integer",
    "Boolean",
    "Int8",
    "Int16",
    "Int32",
    "Int64",
    "Uint8",
    "Uint16",
    "Uint32",
    "Uint64",
    "Float",
    "Float16",
    "BFloat16",
    "TFloat32",
    "Float32",
    "Float64",
    "Float8E5M2",
    "Float8E4M3FN",
    "Float8E4M3B11FNUZ",
    "Float8E4M3",
    "Float8E8M0FNU",
    "Float4E2M1FN",
    "Float6E2M3FN",
    "Float6E3M2FN",
    "IntTuple",
    "Layout",
    "Pointer",
    "Shape",
    "Stride",
    "Tensor",
    "Tile",
    "Tiler",
    "XTuple",
]
|