File size: 46,043 Bytes
fca4fc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
This module provides a DSL for the Cutlass dialects. It also includes utilities
related to those dialects.
"""
# Local module imports
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef
from inspect import isclass
import functools
import pkgutil
from dataclasses import is_dataclass
from ..base_dsl import *
from ..base_dsl import compiler
from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values
from ..base_dsl.typing import *
from ..base_dsl.typing import DynamicExpression, get_mlir_types
from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr
from ..base_dsl.ast_helpers import const_expr
# MLIR Imports
from cutlass._mlir import ir, execution_engine, passmanager
from cutlass._mlir.dialects import arith, func, gpu, scf, cute, gpu as cutlass_gpu
from cutlass._mlir.dialects._ods_common import (
get_op_result_or_op_results as _get_op_result_or_op_results,
)
from cutlass._mlir.extras import types as T
# Helpers
from ..base_dsl._mlir_helpers import arith as cutlass_arith
from ..base_dsl._mlir_helpers import lru_cache_ir
from ..base_dsl.ast_helpers import (
loop_selector,
executor,
if_selector,
if_executor,
while_selector,
while_executor,
assert_executor,
bool_cast,
)
from ..base_dsl.runtime.dlpack_runtime import (
get_cute_tensor_c_pointer,
get_tensor_desc_shape_all,
get_tensor_desc_stride_all,
get_tensor_desc_element_type,
get_tensor_desc_is_in_device,
get_tensor_desc_assumed_align,
)
from .cutlass_ast_decorators import (
_loop_execute_range_dynamic,
_if_execute_dynamic,
_while_execute_dynamic,
)
# =============================================================================
# Set the AST decorator
# =============================================================================
# Set the DSL specific functions
# Register this DSL's hooks with the shared AST ``executor``: the
# ``is_dynamic_expression`` predicate plus the dynamic range-loop, ``if`` and
# ``while`` handlers imported from ``cutlass_ast_decorators``.
# This is a module-level call, so it runs once when this module is imported.
# NOTE(review): presumably the executor uses the predicate to choose between
# Python evaluation and these dynamic handlers — confirm in ast_helpers.
executor.set_functions(
    is_dynamic_expression,
    _loop_execute_range_dynamic,
    _if_execute_dynamic,
    _while_execute_dynamic,
)
# =============================================================================
# Cutlass DSL Base Abstract Class
# =============================================================================
# Return a ctype class that represents the in-memory layout expected
# for a CuTe hierarchical tuple type.
def get_sparse_tuple_ctype(dyn):
    """
    Build the ctypes type describing a sparse CuTe tuple in memory.

    :param dyn: Either a single ``int`` (exactly one dynamic value) or a sized
        collection of dynamic values.
    :return: ``ctypes.c_int32`` for a single dynamic value; otherwise a freshly
        created ``ctypes.Structure`` subclass named ``TupleDescriptor`` with one
        ``c_int32`` field (``x0``, ``x1``, ...) per dynamic value.
    """
    # A lone dynamic value is represented directly as a 32-bit integer.
    if isinstance(dyn, int):
        return ctypes.c_int32

    # Zero or several dynamic values: a flat struct with one int32 per value.
    # Hierarchical profiles are flattened (only depth-1 inputs are handled).
    int32_fields = [(f"x{pos}", ctypes.c_int32) for pos in range(len(dyn))]

    class TupleDescriptor(ctypes.Structure):
        _fields_ = int32_fields

        def __str__(self):
            return f"struct<{self._fields_!s}>"

    return TupleDescriptor
def is_cute_algebra_type(arg_spec):
    """
    Return True if ``arg_spec`` is a CuTe algebra type alias.

    An annotation qualifies when it is a ``typing.Union`` with at least one
    tuple member whose first element type is a forward reference to one of
    ``Shape``, ``Stride``, ``Coord``, ``Tile`` or ``IntTuple``.

    :param arg_spec: A type annotation.
    :return: True if the annotation matches, otherwise False.
    """
    # Walk through the arg_spec to check if it's a cute algebra type
    _cute_algebra_type_aliases = (
        "Shape",
        "Stride",
        "Coord",
        "Tile",
        "IntTuple",
    )
    origin = get_origin(arg_spec)
    if origin is Union:
        for sub_ty in get_args(arg_spec):
            sub_origin = get_origin(sub_ty)
            is_tuple = sub_origin is Tuple or (
                type(sub_origin) is type and issubclass(sub_origin, tuple)
            )
            if not is_tuple:
                continue
            tuple_args = get_args(sub_ty)
            if not tuple_args:
                # A bare ``Tuple``/``tuple`` has no element types to inspect.
                # Indexing [0] here used to raise IndexError.
                continue
            tuple_arg0 = tuple_args[0]
            if (
                isinstance(tuple_arg0, ForwardRef)
                and tuple_arg0.__forward_arg__ in _cute_algebra_type_aliases
            ):
                return True
    return False
class CutlassBaseDSL(BaseDSL):
    """This abstract class provides a DSL for Cutlass.

    It specialises ``BaseDSL`` with the following pieces:

    * GPU-module construction;
    * the default ``cute-to-nvvm`` pipeline;
    * NVVM kernel attributes derived from a ``LaunchConfig``;
    * lowering of kernel launches to ``gpu.launch_func``;
    * argument handling for CuTe algebra types.
    """

    def __init__(
        self,
        name: str,
        compiler_provider: Any,
        pass_sm_arch_name: str,
        device_compilation_only: bool = False,
        preprocess: bool = False,
    ):
        super().__init__(
            name,
            compiler_provider,
            pass_sm_arch_name,
            device_compilation_only,
            preprocess,
        )

    def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
        # No tensor-descriptor types are recognised at this level.
        return False

    def _build_gpu_module(self, attrs):
        """Create the ``kernels`` GPU module and attach ``attrs`` to it.

        :param attrs: Mapping of attribute name to the textual MLIR form of its
            value; each value is parsed with ``ir.Attribute.parse``.
        """
        self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"))
        # Evaluating the argument appends the (initially empty) body block;
        # nothing is inserted into it here.
        with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
            pass
        for attr_name in attrs:
            self.gpu_module.attributes[attr_name] = ir.Attribute.parse(attrs[attr_name])

    def _get_pipeline(self, pipeline):
        """Return the pass pipeline, defaulting to ``cute-to-nvvm`` when unset."""
        pipeline = super()._get_pipeline(pipeline)
        if pipeline is None:
            # cubin format is required to be cubin as we launch cuda module at python level.
            return "builtin.module(cute-to-nvvm{cubin-format=bin opt-level=3})"
        return pipeline

    def preprocess_pipeline(self, pipeline, arch) -> str:
        """Append ``external-kernel-for-gpu-launch`` to the outermost pass list."""
        pipeline = super().preprocess_pipeline(pipeline, arch)
        # Remove exactly one closing parenthesis (the outermost one).
        # ``rstrip(")")`` would also strip parentheses that close nested pass
        # lists, e.g. ``builtin.module(gpu.module(...))``, leaving the result
        # unbalanced.
        if pipeline.endswith(")"):
            pipeline = pipeline[:-1]
        return pipeline + ",external-kernel-for-gpu-launch)"

    def _enter_gpu_module(self):
        """Return an insertion point at the start of the GPU module body block."""
        return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0])

    def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict:
        """Build NVVM launch-bound attributes for a kernel from ``config``.

        :return: Dict with ``nvvm.reqntid`` (from ``config.block``). It also
            contains ``nvvm.minctasm`` when ``config.min_blocks_per_mp > 0``.
        """
        assert isinstance(
            config, BaseDSL.LaunchConfig
        ), f"Expect LaunchConfig for @kernel, but got {type(config)}"
        ret = {}
        # generate launch bound attr from LaunchConfig
        max_threads = ", ".join(map(str, config.block))
        ret["nvvm.reqntid"] = ir.Attribute.parse(f"array<i32 : {max_threads}>")
        # min_blocks_per_mp is optional for kernel
        min_blocks = config.min_blocks_per_mp
        if min_blocks > 0:
            ret["nvvm.minctasm"] = ir.Attribute.parse(f"{min_blocks} : i32")
        return ret

    @lru_cache(maxsize=1)
    def get_version(self):
        """
        Get the version of cutlass dsl, used for computing the hash key of the cache.

        The hash covers every Python module found under the package directory
        and the ``libCutlassIRPythonCAPI.so`` shared library.

        The same ``hashlib`` object is returned on every call because of
        ``lru_cache``. Treat it as read-only, and use ``.copy()`` before
        calling ``update``.

        :raises DSLRuntimeError: if a module file or the shared library cannot
            be read.
        """
        # NOTE(review): lru_cache on a method keys on ``self`` (ruff B019); with
        # maxsize=1 only the most recent instance is cached.
        dsl_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        version_hash = hashlib.sha256()
        # update the version hash of the source python files
        for lib in pkgutil.walk_packages([dsl_path], prefix="cutlass."):
            try:
                with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
                    version_hash.update(f.read())
            except Exception as e:
                raise DSLRuntimeError(
                    f"Failed to read module file {lib.name}. The file may not exist or may not be readable. "
                    "Please re-install the package."
                ) from e
        try:
            # update the version hash of the cutlass shared library, 1 MiB at a time
            with open(
                os.path.join(dsl_path, "_mlir/_mlir_libs/libCutlassIRPythonCAPI.so"),
                "rb",
            ) as f:
                while True:
                    chunk = f.read(1024**2)
                    if not chunk:
                        break
                    version_hash.update(chunk)
        except Exception as e:
            raise DSLRuntimeError(
                "Failed to read the shared library file libCutlassIRPythonCAPI.so. "
                "The file may not exist or may not be readable. "
                "Please re-install the package."
            ) from e
        return version_hash

    def _kernel_helper(self, funcBody, *args, **kwargs):
        """Wrap ``funcBody`` and its arguments in a :class:`KernelLauncher`."""

        class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper):
            """Emits kernels as ``func.func`` and launches via ``gpu.launch_func``."""

            def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
                super().generate_func_op(arg_types, arg_attrs, kernel_name)
                self.func_op = func.FuncOp(
                    kernel_name, ir.FunctionType.get(arg_types, []), loc=loc
                )
                if arg_attrs is not None:
                    log().debug(arg_attrs)
                    self.func_op.arg_attrs = arg_attrs
                return self.func_op

            def generate_func_ret_op(self):
                return func.ReturnOp([])

            def get_func_body_start(self):
                assert self.func_op is not None, "Invalid func_op is not expected!"
                return self.func_op.add_entry_block()

            def generate_launch_op(self, *args, **kwargs):
                # Extract args and do validation
                kernelSym = kwargs.get("kernelSym", None)
                kernelOperands = kwargs.get("kernelOperands", None)
                requiredArgs = kwargs.get("requiredArgs", None)
                assert kernelSym is not None, "kernelSym being None is not expected!"
                assert (
                    requiredArgs is not None
                ), "requiredArgs being None is not expected!"
                assert (
                    kernelOperands is not None
                ), "kernelOperands being None is not expected!"
                assert isinstance(
                    requiredArgs.config, BaseDSL.LaunchConfig
                ), f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}"
                cfg = requiredArgs.config
                # Convert grid, block and (if present) cluster sizes to index
                # values in place on the config.
                cfg.grid = [to_index(size) for size in cfg.grid]
                cfg.block = [to_index(size) for size in cfg.block]
                if cfg.has_cluster:
                    cfg.cluster = [to_index(size) for size in cfg.cluster]
                cfg.smem = const(cfg.smem)
                if not isinstance(cfg.async_deps, (list, tuple)):
                    cfg.async_deps = [cfg.async_deps]
                # An async token is requested only when there are dependencies.
                is_async = len(cfg.async_deps) > 0
                token = gpu.launch_func(
                    gpu.AsyncTokenType.get() if is_async else None,
                    cfg.async_deps,
                    kernelSym,
                    *cfg.grid,
                    *cfg.block,
                    kernelOperands,
                    **dict(
                        zip(
                            ("cluster_size_x", "cluster_size_y", "cluster_size_z"),
                            tuple(cfg.cluster),
                        )
                    ),
                    dynamic_shared_memory_size=cfg.smem,
                )
                return token if is_async else None

        return KernelLauncher(
            self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
        )

    def _get_globals(self):
        """Merge this module's globals with the recorded caller frame's globals and locals.

        Later sources win: caller locals override caller globals, which
        override this module's globals.
        """
        caller_globals = self.frame.f_globals
        caller_locals = self.frame.f_locals
        all_globals = globals().copy()
        all_globals.update(caller_globals)
        all_globals.update(caller_locals)
        return all_globals

    def _preprocess_launch_config_args(self, args, kwargs):
        """Helper to preprocess args and kwargs for LaunchConfig.

        Renames the user-facing ``stream`` keyword to ``async_deps`` in place.
        """
        if "stream" in kwargs:
            kwargs["async_deps"] = kwargs.pop("stream")

    def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec):
        """Mangle the name of the function to avoid conflicts with other functions"""
        function_name = "cutlass_" + function_name
        return super().mangle_name(function_name, args, args_spec)

    def _validate_arg(self, arg, arg_index, arg_name, arg_annotation):
        """
        Validates if the arg is really of the annotated type.

        Constexpr annotations are not checked. ``Type[X]`` annotations require
        a subclass of ``X``. ``Union`` annotations require an instance of some
        member (any tuple for tuple members). Plain class annotations require
        an instance of the class, or ``None``.

        :return: A ``DSLRuntimeError`` describing the mismatch (it is returned,
            not raised), or ``None`` when the argument is acceptable.
        """
        if is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None):
            pass
        else:
            origin = get_origin(arg_annotation)
            # Handle special case where annotation is Type[X] but arg is an actual type
            if origin is type and isinstance(arg, type):
                # Get the expected base type from Type[X]
                expected_base = get_args(arg_annotation)[0]
                if not issubclass(arg, expected_base):
                    return DSLRuntimeError(
                        f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}"
                    )
            # Handle Union types and generic types
            elif origin is Union:
                # For Union types, check if arg matches any of the allowed types
                allowed_types = get_args(arg_annotation)
                if not any(
                    (isinstance(ty, type) and isinstance(arg, ty))
                    or (get_origin(ty) is tuple and isinstance(arg, tuple))
                    for ty in allowed_types
                ):
                    return DSLRuntimeError(
                        f"expects argument #{arg_index+1} ({arg_name}) to be one of {allowed_types}, but got {type(arg)}"
                    )
            elif isinstance(arg_annotation, type):
                # Handle simple type annotations
                if not isinstance(arg, arg_annotation) and arg is not None:
                    return DSLRuntimeError(
                        f"expects argument #{arg_index+1} ({arg_name}) to be {arg_annotation}, but got {type(arg)}"
                    )
        # Everything looks good if we are here
        return None

    def _generate_jit_func_args_for_known_types(
        self,
        func,
        arg,
        arg_name,
        arg_spec,
        arg_index,
        *,
        is_host=True,
    ):
        """Extend the base handling with CuTe algebra types (Shape, Stride, ...).

        :return: ``(jit_exec_arg, jit_arg_type, jit_arg_attr)``.
        """
        default_attr = ir.DictAttr.get({})
        (
            jit_exec_arg,
            jit_arg_type,
            jit_arg_attr,
        ) = super()._generate_jit_func_args_for_known_types(
            func, arg, arg_name, arg_spec, arg_index, is_host=is_host
        )
        # An empty (but not None) type list means the base class did not
        # produce anything for this argument.
        if jit_arg_type is not None and len(jit_arg_type) == 0:
            # Handle DSL specific types
            if is_cute_algebra_type(arg_spec):
                dyn_vals = extract_mlir_values(arg)
                if dyn_vals:
                    # Handle dynamic types
                    jit_arg_type.extend([v.type for v in dyn_vals])
                    jit_arg_attr.extend([default_attr] * len(dyn_vals))
                    jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals)
                else:
                    # No dynamic values to pass for this argument.
                    jit_exec_arg = jit_arg_type = jit_arg_attr = None
        return jit_exec_arg, jit_arg_type, jit_arg_attr

    def _generate_execution_arguments_for_known_types(
        self, arg, arg_spec, arg_name, i, fop_args, iv_block_args
    ):
        """Rebuild CuTe algebra arguments from the function's block arguments.

        :return: ``(ir_arg, iv_block_args)``. ``iv_block_args`` is advanced by
            the number of block arguments consumed.
        """
        ir_arg, iv_block_args = super()._generate_execution_arguments_for_known_types(
            arg, arg_spec, arg_name, i, fop_args, iv_block_args
        )
        if not ir_arg:
            # Handling DSL specific types
            if is_cute_algebra_type(arg_spec):
                n_args = len(get_mlir_types(arg))
                blk_args = fop_args[iv_block_args : iv_block_args + n_args]
                ir_arg.append(new_from_mlir_values(arg, blk_args))
                iv_block_args += n_args
        return ir_arg, iv_block_args
# =============================================================================
# Cute DSL Class
# =============================================================================
class CuTeDSL(CutlassBaseDSL):
    """
    Concrete Cutlass DSL for the CuTe dialect.

    Registers itself as ``CUTE_DSL``. It uses a ``compiler.Compiler`` built on
    the MLIR pass manager and execution engine, and ``cubin-chip`` as the
    SM-architecture pass option name. Preprocessing is enabled.
    """

    def __init__(self):
        super().__init__(
            "CUTE_DSL",
            compiler.Compiler(passmanager, execution_engine),
            "cubin-chip",
            preprocess=True,
        )
# =============================================================================
# KernelLauncher
# =============================================================================
class KernelLauncher:
    """
    This class is used to launch a kernel function.

    Usage:

    ```python
    @cute.kernel
    def kernel(arg1, arg2, ...):
        ...

    @cute.jit
    def launch_kernel():
        kernel(arg1, arg2, ...).launch(grid=[1, 1, 1], block=[1, 1, 1], ...)
        # or
        kernel(arg1, arg2, ...)(grid=[1, 1, 1], block=[1, 1, 1], ...)
    ```
    """

    def __init__(
        self,
        dsl: "CutlassBaseDSL",
        kernelGenHelper: "BaseDSL._KernelGenHelper",
        funcBody,
        *func_args,
        **func_kwargs,
    ):
        """Record the kernel body and its arguments, validating them eagerly.

        :raises DSLRuntimeError: if ``func_args``/``func_kwargs`` cannot be
            bound to ``funcBody``'s signature.
        """
        self.dsl = dsl
        self.kernelGenHelper = kernelGenHelper
        self.funcBody = funcBody
        self.func_args = func_args
        self.func_kwargs = func_kwargs
        self._check_func_args(funcBody, *func_args, **func_kwargs)

    def _check_func_args(self, funcBody, *func_args, **func_kwargs):
        # Get function signature
        sig = inspect.signature(funcBody)
        # func_args and func_kwargs should match funcBody's signature,
        # no extra or missing arguments.
        try:
            sig.bind(*func_args, **func_kwargs)
        except TypeError as e:
            raise DSLRuntimeError(
                f"Failed to bind arguments to function `{funcBody.__name__}` with signature `{sig}`",
                cause=e,
            )

    def _launch(self, caller_frame, args, kwargs):
        """Shared implementation of :meth:`launch` and :meth:`__call__`.

        :param caller_frame: Frame of the user code that requested the launch;
            stored on the DSL as ``dsl.frame`` (read by ``_get_globals``).
        :param args: Positional arguments for ``LaunchConfig``.
        :param kwargs: Keyword arguments for ``LaunchConfig``; ``stream`` is
            renamed to ``async_deps``.
        :return: The ``launch_op_ret`` of the generated kernel launch.
        """
        self.dsl.frame = caller_frame
        self.dsl._preprocess_launch_config_args(args, kwargs)
        config = self.dsl.LaunchConfig(*args, **kwargs)
        kernel_generator = self.dsl.kernel_launcher(
            requiredArgs=["config"],
            unitAttrNames=["gpu.kernel", "cute.kernel"],
            valueAttrDict=self.dsl._generate_kernel_attrs(config),
            kernelGenHelper=self.kernelGenHelper,
        )(self.funcBody)
        ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
        self.dsl.kernel_symbols.append(name)
        return ret.launch_op_ret

    def launch(self, *args, **kwargs):
        """Launch the kernel with a ``LaunchConfig`` built from the arguments."""
        return self._launch(inspect.currentframe().f_back, args, kwargs)

    def __call__(self, *args, **kwargs):
        """Same as :meth:`launch`.

        The caller frame is captured here rather than inside ``launch``.
        Delegating to ``launch`` would record this method's frame instead of
        the user's.
        """
        return self._launch(inspect.currentframe().f_back, args, kwargs)
# =============================================================================
# Utils
# =============================================================================
def is_frozen_dataclass(obj_or_cls) -> bool:
    """
    Tell whether ``obj_or_cls`` is a dataclass declared with ``frozen=True``.

    :param obj_or_cls: A class or an instance (instances are resolved through
        their ``__class__``).
    :return: True for a frozen dataclass (class or instance), False otherwise.
    """
    cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__
    if not is_dataclass(cls):
        return False
    params = getattr(cls, "__dataclass_params__", None)
    if params is None:
        return False
    return cls.__dataclass_params__.frozen
def pack_from_irvalue(
    ir_values: List["ir.Value"],
    indices: Dict[int, Tuple[int, int]],
    class_types: List[Any],
) -> List[Any]:
    """
    Packs MLIR values into a list of mixed values.

    ``indices`` and ``class_types`` use the layout returned by
    ``unpack_to_irvalue``: ``indices[i] = (start, length)`` selects the slice
    of ``ir_values`` belonging to item ``i``, and ``class_types[i]`` is the
    original item.

    Each item is rebuilt as follows:

    * frozen dataclasses are returned as-is;
    * non-class objects with ``__new_from_mlir_values__`` are rebuilt from
      their slice;
    * a slice starting with ``None`` yields the original item;
    * otherwise the first value is wrapped with ``t.as_numeric``. If that
      raises ``DSLRuntimeError``, the raw value is used.

    :return: One entry per key of ``indices``.
    """
    log().debug("===--- Values Pack (%d)", len(ir_values))
    for idx, packed in enumerate(ir_values):
        # Log each value individually (previously the whole list was logged).
        log().debug("[%d]: will-packed: %s", idx, packed)
    for idx, unpacked in indices.items():
        log().debug("[%d]: indices: %s", idx, unpacked)
    for idx, c in enumerate(class_types):
        log().debug("[%d]: obj-types: %s", idx, type(c))
    mixed_values = [None] * len(indices)
    for idx, (start, length) in sorted(indices.items()):
        chunk = ir_values[start : start + length]
        obj = class_types[idx]
        if is_frozen_dataclass(obj):
            # Frozen dataclasses are passed through unchanged.
            mixed_values[idx] = obj
        elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
            mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
        else:
            try:
                if isinstance(chunk, list) and chunk[0] is None:
                    # A ``None`` placeholder: restore the original item.
                    mixed_values[idx] = obj
                else:
                    mixed_values[idx] = t.as_numeric(chunk[0])
            except DSLRuntimeError:
                mixed_values[idx] = chunk[0]
    log().debug("------------------ ")
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: packed: %s", idx, packed)
    log().debug("------------------ ")
    return mixed_values
def unpack_to_irvalue(
    mixed_values: List[Any], body_name: str
) -> Tuple[List[ir.Value], List[Any], Dict[int, Tuple[int, int]], List[Any]]:
    """
    Unpacks mixed values into ir.Value values.

    Each item is flattened into consecutive entries of ``ir_values``:

    * frozen dataclasses contribute a single ``None`` placeholder;
    * items for which ``extract_mlir_values`` returns values contribute those;
    * otherwise ``None`` contributes a ``None`` placeholder, and any other
      Python value is converted with ``t.as_numeric`` and its MLIR values are
      used.

    :param mixed_values: Values to flatten.
    :param body_name: Name of the statement being lowered; used only in the
        error message.
    :return: ``(ir_values, unpacked_values, indices, class_types)``.
        ``indices[i] = (offset, length)`` locates item ``i`` in ``ir_values``.
        ``class_types[i]`` is the original item. ``unpacked_values`` holds the
        same entries as ``ir_values``.
    :raises DSLRuntimeError: if an item cannot be converted into MLIR values.
    """
    unpacked_values = []
    ir_values = []
    indices = {}
    class_types = []
    current_offset = 0
    log().debug("===--- Values UNPack (%d)", len(mixed_values))
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
    for idx, item in enumerate(mixed_values):
        class_types.append(item)
        try:
            if is_frozen_dataclass(item):
                extracted_vals = [None]
            else:
                extracted_vals = extract_mlir_values(item)
                # it's consexpr (python value), so we create mlir value for it
                if extracted_vals == []:
                    if item is None:
                        extracted_vals = [None]
                    else:
                        dyn_expr = t.as_numeric(item)
                        extracted_vals = extract_mlir_values(dyn_expr)
            ir_values.extend(extracted_vals)
            unpacked_values.extend(extracted_vals)
            length = len(extracted_vals)
            indices[idx] = (current_offset, length)
            current_offset += length
        except Exception as e:
            # Key the context by the offending object when it is hashable.
            # Unhashable items (lists, eq-dataclasses, ...) would raise
            # TypeError while building the dict and hide the real error.
            try:
                hash(item)
                context_key = item
            except TypeError:
                context_key = repr(item)
            raise DSLRuntimeError(
                f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression (aka MLIR value).",
                context={
                    context_key: (
                        f"All expressions within '{body_name}' must be dynamic expressions, "
                        "mixing Python objects and dynamic expressions (aka MLIR values) is not supported. "
                        "The DSL failed to convert the Python object into MLIR values."
                    )
                },
                suggestion=(
                    f"Please ensure '{item}' implements the '{DynamicExpression.__name__}', "
                    f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
                ),
            ) from e
    log().debug("------------------ ")
    for idx, unpacked in enumerate(unpacked_values):
        log().debug("[%d]: unpacked values: %s", idx, unpacked)
    for idx, unpacked in enumerate(ir_values):
        log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
    for idx, unpacked in indices.items():
        log().debug("[%d]: indices: %s", idx, unpacked)
    for idx, unpacked in enumerate(class_types):
        log().debug("[%d]: initial-class-types: %s", idx, unpacked)
    log().debug("------------------ ")
    return ir_values, unpacked_values, indices, class_types
def to_index(value):
    """
    Convert ``value`` to an MLIR ``index`` value.

    Dynamic expressions must have an integer type and are cast with
    ``arith.index_cast``. ``Numeric`` wrappers are unwrapped first. Anything
    else is converted with ``int()`` and materialised as an index constant.
    """
    if not is_dynamic_expression(value):
        return const(int(value), ty=T.index())

    ir_val = value.ir_value() if isinstance(value, Numeric) else value
    assert ir.IntegerType.isinstance(
        ir_val.type
    ), f"expects integer type, but got {ir_val.type}"
    return arith.index_cast(T.index(), ir_val)
def _validate_iter_args_structure(iter_args, ir_values):
"""
Validates that iter_args structure contains the same number of atomic values
as there are IR values.
Args:
iter_args: Original iteration arguments, possibly nested sequences
ir_values: Flattened MLIR values extracted from iter_args
Returns:
bool: True if the number of atomic values in iter_args matches
the number of values in ir_values
"""
# Handle non-sequence case
if not isinstance(iter_args, (tuple, list, set)):
return not isinstance(ir_values, (tuple, list, set)) or len(ir_values) == 1
# If we have a sequence but ir_values isn't one, there's a mismatch
if not isinstance(ir_values, (tuple, list, set)):
return False
# Count all non-sequence values recursively
def count_values(args):
if not isinstance(args, (tuple, list, set)):
return 1
else:
return sum(count_values(arg) for arg in args)
return count_values(iter_args) == len(ir_values)
# =============================================================================
# DSL implementations of Python built-in operators
# =============================================================================
def _minmax(op, *args, loc=None, ip=None):
    """Computes the minimum or maximum value from the provided arguments.

    :param op: The DSL ``min`` or ``max`` function. It is re-invoked to reduce
        nested iterables. ``op.__name__`` selects the emitter
        ``cutlass_arith._<name>``.
    :param args: A single value, a single iterable, or several values/iterables.
    :param loc: Source location for MLIR operation tracking.
    :param ip: Insertion point for MLIR operation.
    :return: The reduced value.
    :raises DSLNotImplemented: if no arguments are given.
    """
    from ..base_dsl.typing import _binary_op_type_promote

    # AST Traversal doesn't support early exit in if executor
    x = None
    res = None
    if len(args) == 1:
        # Handle case for min([a, b, c, d, ..])
        if hasattr(args[0], "__iter__"):
            # Forward loc/ip so the reduction is emitted with the caller's
            # location and insertion point (they were previously dropped).
            x = op(*tuple(args[0]), loc=loc, ip=ip)
        # Handle case for min(a)
        else:
            x = args[0]
    # Handle case for min(a, b, c, ...) and min([x, y], [b]) and min(a, (x, y, z))
    elif len(args) > 1:
        # The emitter only depends on ``op``; look it up once.
        emitter = getattr(cutlass_arith, f"_{op.__name__}")
        res, *xs = tuple(args)
        for x in xs:
            # Reduce each operand first (it may itself be an iterable).
            lhs = as_numeric(op(res, loc=loc, ip=ip))
            rhs = as_numeric(op(x, loc=loc, ip=ip))
            lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool=True)
            # Integer operands carry their signedness into the arith op.
            if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance(
                lhs, Integer
            ):
                lhs_val = lhs.value.with_signedness(lhs.signed)
            else:
                lhs_val = lhs.value
            if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance(
                rhs, Integer
            ):
                rhs_val = rhs.value.with_signedness(rhs.signed)
            else:
                rhs_val = rhs.value
            res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip)
        x = res
    else:
        raise DSLNotImplemented(f"{type(args)} is not supported")
    return x
def min(*args, loc=None, ip=None):
    """Computes the minimum value from the provided arguments.

    Unlike Python's built-in ``min()``, the return type is determined by the
    static types of the inputs, not by their dynamic values.

    :param args: One or more values or iterables to find the minimum of
    :type args: tuple
    :param loc: Source location for MLIR operation tracking
    :type loc: object, optional
    :param ip: Insertion point for MLIR operation
    :type ip: object, optional
    :return: The minimum value among all inputs
    :rtype: Numeric
    :raises DSLNotImplemented: If the input type is not supported

    Supported calling patterns:

    - ``min(a)``: returns ``a``
    - ``min([a, b, c, ...])``: minimum of all elements of the iterable
    - ``min(a, b, c, ...)``: minimum of all arguments
    - ``min([x, y], [b])``: minimum across all elements of all iterables
    - ``min(a, (x, y, z))``: minimum across all elements

    Examples:

    .. code-block:: python

        result = min(x, y)            # two values
        result = min(a, b, c, d)      # several values
        values = [a, b, c, d]
        result = min(values)          # values in a list
        result = min(x, [y, z])       # mixed arguments

    Difference from Python's built-in ``min()``:

    .. code-block:: python

        # In Python, the return type depends on the dynamic values:
        a = 5
        b = 3.14
        result = min(a, b)  # Returns 3.14 (float)

        # In this DSL implementation, the return type is determined statically:
        a = Int32(5)
        b = Float32(3.14)
        result = min(a, b)  # Return type is determined by the operand types
    """
    reducer = min  # this module's DSL ``min``, not the Python builtin
    return _minmax(reducer, *args, loc=loc, ip=ip)
def max(*args, loc=None, ip=None):
    """Computes the maximum value from the provided arguments.

    Unlike Python's built-in ``max()``, the return type is determined by the
    static types of the inputs, not by their dynamic values.

    :param args: One or more values or iterables to find the maximum of
    :type args: tuple
    :param loc: Source location for MLIR operation tracking
    :type loc: object, optional
    :param ip: Insertion point for MLIR operation
    :type ip: object, optional
    :return: The maximum value among all inputs
    :rtype: Numeric
    :raises DSLNotImplemented: If the input type is not supported

    Supported calling patterns:

    - ``max(a)``: returns ``a``
    - ``max([a, b, c, ...])``: maximum of all elements of the iterable
    - ``max(a, b, c, ...)``: maximum of all arguments
    - ``max([x, y], [b])``: maximum across all elements of all iterables
    - ``max(a, (x, y, z))``: maximum across all elements

    Examples:

    .. code-block:: python

        result = max(x, y)            # two values
        result = max(a, b, c, d)      # several values
        values = [a, b, c, d]
        result = max(values)          # values in a list
        result = max(x, [y, z])       # mixed arguments

    Difference from Python's built-in ``max()``:

    .. code-block:: python

        # In Python, the return type depends on the dynamic values:
        a = 5
        b = 3.14
        result = max(a, b)  # Returns 5 (int)

        # In this DSL implementation, the return type is determined statically:
        a = Int32(5)
        b = Float32(3.14)
        result = max(a, b)  # Return type is determined by the operand types
    """
    reducer = max  # this module's DSL ``max``, not the Python builtin
    return _minmax(reducer, *args, loc=loc, ip=ip)
def and_(*args, loc=None, ip=None):
    """AND operation for value in DSL numeric types.

    Operands are folded left to right:

    * two plain Python scalars (``int``/``float``/``bool``) combine with
      Python's ``and``;
    * any other pair is converted with ``as_numeric`` and combined via
      ``__dsl_and__``.

    ``loc`` and ``ip`` are accepted for API symmetry but are not used here.

    :param *args: One or more numeric values to AND together
    :type *args: Numeric
    :return: The result of the logical AND operation
    :rtype: Numeric
    :raises ValueError: If no arguments are provided
    :raises DSLNotImplemented: If an accumulated left operand has an
        unsupported type

    Examples:

    .. code-block:: python

        a, b = 5, 3
        result = a and b       # Python: returns 3

        a, b = Int32(5), Int32(3)
        result = and_(a, b)    # DSL equivalent
    """
    if not args:
        raise ValueError("and_() requires at least one argument")
    first, *rest = args
    if not rest:
        return first

    python_scalars = (int, float, bool)
    accepted = (Numeric, cutlass_arith.ArithValue) + python_scalars
    acc = first
    for operand in rest:
        if not isinstance(acc, accepted):
            raise DSLNotImplemented(f"{type(acc)} is not supported")
        if isinstance(acc, python_scalars) and isinstance(operand, python_scalars):
            acc = acc and operand
        else:
            acc = as_numeric(acc).__dsl_and__(as_numeric(operand))
    return acc
def or_(*args, loc=None, ip=None):
    """Logical OR over one or more DSL numeric or plain Python scalar values.

    The operands are folded left to right, mirroring Python's ``a or b or c``:

    - ``or_(a)`` returns ``a`` unchanged.
    - ``or_(a, b, c, ...)`` combines ``a`` with ``b``, then that result with ``c``, and so on.

    Each pairwise step works as follows:

    - If both sides are Python ``int``/``float``/``bool``, it uses Python's ``or``.
    - Otherwise both sides are converted with ``as_numeric`` and combined via ``__dsl_or__``.

    :param *args: One or more values to OR together
    :type *args: Numeric
    :param loc: Source location for MLIR operation tracking (currently unused here)
    :type loc: object, optional
    :param ip: Insertion point for MLIR operation (currently unused here)
    :type ip: object, optional
    :return: The result of the logical OR operation
    :rtype: Numeric
    :raises ValueError: If no arguments are provided
    :raises DSLNotImplemented: If a left-hand operand has an unsupported type

    Examples:

    .. code-block:: python

        # Python's 'or' yields the first operand when it is truthy
        a = 5
        b = 3
        result = a or b  # Returns 5

        # The DSL version works the same way on DSL types
        a = Int32(5)
        b = Int32(3)
        result = or_(a, b)  # Returns a
    """
    if not args:
        raise ValueError("or_() requires at least one argument")

    first, *rest = args
    if not rest:
        return first

    def _or_pair(lhs, rhs):
        # Only the accumulated left-hand side is type-checked here; the
        # right-hand side is handed to as_numeric() on the DSL path.
        accepted = (Numeric, cutlass_arith.ArithValue, int, float, bool)
        if not isinstance(lhs, accepted):
            raise DSLNotImplemented(f"{type(lhs)} is not supported")
        py_scalar = (int, float, bool)
        if isinstance(lhs, py_scalar) and isinstance(rhs, py_scalar):
            # Pure-Python operands: keep Python's value-returning semantics.
            return lhs or rhs
        return as_numeric(lhs).__dsl_or__(as_numeric(rhs))

    return functools.reduce(_or_pair, rest, first)
def all_(iterable):
    """DSL counterpart of Python's built-in ``all()``.

    Every element is first wrapped as a ``Boolean``. The wrapped values are then
    folded with ``__dsl_and__``, starting from ``Boolean(True)``, so an empty
    iterable yields ``Boolean(True)``.

    :param iterable: Values to test
    :type iterable: Iterable
    :return: Truthiness of the conjunction of all elements
    :rtype: Boolean

    Examples:

    .. code-block:: python

        # Check if all values are non-zero
        values = [Int32(1), Int32(2), Int32(3)]
        result = all_(values)  # Returns True

        # Check if all conditions are met
        conditions = [a > 0, b < 10, c != 0]
        result = all_(conditions)  # Returns True if all conditions are met
    """
    # Convert every element before building the seed value, keeping the same
    # construction order as before (relevant if Boolean(...) emits IR).
    operands = [Boolean(item) for item in iterable]
    acc = Boolean(True)
    for operand in operands:
        # Defensive: an accumulator without __dsl_and__ is carried through unchanged.
        if hasattr(acc, "__dsl_and__"):
            acc = acc.__dsl_and__(operand)
    return acc
def any_(iterable):
    """DSL counterpart of Python's built-in ``any()``.

    Every element is first wrapped as a ``Boolean``. The wrapped values are then
    folded with ``__dsl_or__``, starting from ``Boolean(False)``, so an empty
    iterable yields ``Boolean(False)``.

    :param iterable: Values to test
    :type iterable: Iterable
    :return: Truthiness of the disjunction of all elements
    :rtype: Boolean

    Examples:

    .. code-block:: python

        # Check if any value is non-zero
        values = [Int32(0), Int32(0), Int32(3)]
        result = any_(values)  # Returns True

        # Check if any condition is met
        conditions = [a > 10, b < 0, c != 0]
        result = any_(conditions)  # Returns True if any condition is met
    """
    # Convert every element before building the seed value, keeping the same
    # construction order as before (relevant if Boolean(...) emits IR).
    operands = [Boolean(item) for item in iterable]
    acc = Boolean(False)
    for operand in operands:
        # Defensive: an accumulator without __dsl_or__ is carried through unchanged.
        if hasattr(acc, "__dsl_or__"):
            acc = acc.__dsl_or__(operand)
    return acc
# =============================================================================
# Conditional Expression
# =============================================================================
def select_(cond, if_value, else_value):
    """Emit an ``arith.select`` choosing ``if_value`` or ``else_value`` by ``cond``.

    ``cond`` must be a dynamic expression; static conditions are expected to be
    resolved by the caller before reaching this helper. Dynamic branch values
    are unpacked with ``extract_mlir_values``. Static ones are materialized
    with ``const``.

    :raises DSLRuntimeError: if ``cond`` is not dynamic, or if any operand
        unpacks to a list that does not hold exactly one value.
    :return: the result ``ir.Value`` of the select op.
    """

    def _single(vals):
        # Non-list values pass through; a list must hold exactly one element.
        if const_expr(not isinstance(vals, list)):
            return vals
        if const_expr(len(vals) != 1):
            raise DSLRuntimeError(
                "Conditional expression must have exactly one value in all expressions"
            )
        return vals[0]

    def _to_ir(val):
        # Dynamic values are unpacked; static Python values become constants.
        if const_expr(is_dynamic_expression(val)):
            return extract_mlir_values(val)
        return const(val)

    # Non-DSL dynamic cond should be handled before this.
    if const_expr(not is_dynamic_expression(cond)):
        raise DSLRuntimeError("Conditional expression must be dynamic")

    cond_ir = extract_mlir_values(cond)
    then_ir = _to_ir(if_value)
    else_ir = _to_ir(else_value)
    return arith.SelectOp(_single(cond_ir), _single(then_ir), _single(else_ir)).result
# =============================================================================
# Terminator
# =============================================================================
def yield_out(args=None, loc=None, ip=None):
    """
    Generate an ``scf.yield`` operation. It is used to return values from a loop, if-else, or while region.

    :param args: Values to yield; they are flattened with ``extract_mlir_values``.
        ``None`` (the default) yields no values.
    :param loc: Source location for the op.
    :param ip: Insertion point for the op.
    """
    # Avoid a shared mutable default; None means "yield nothing".
    if args is None:
        args = []
    scf.yield_(extract_mlir_values(args), loc=loc, ip=ip)
# =============================================================================
# For Loop
# =============================================================================
class LoopUnroll(ir.Attribute):
    """``#llvm.loop_annotation`` attribute carrying an ``unroll`` configuration.

    Recognized keyword arguments (all others are ignored):

    - ``count`` (int): rendered as ``<N> : i32``; ``count=1`` additionally sets
      ``disable = true``.
    - ``full`` (bool): rendered as ``true`` / ``false``.

    :raises DSLNotImplemented: if a recognized key has a non-bool/non-int value.
    """

    def __init__(self, **kwargs):
        # A tuple (not a set) so the emitted key order is fixed; iterating a set
        # of strings depends on hash randomization and makes the IR text
        # nondeterministic across runs.
        valid_keys = ("count", "full")

        def to_mlir_attr(val):
            # bool is checked first because it is a subclass of int.
            if isinstance(val, bool):
                return "true" if val else "false"
            elif isinstance(val, int):
                return f"{val} : i32"
            else:
                raise DSLNotImplemented(f"{type(val)} is not supported")

        cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs}
        if kwargs.get("count", None) == 1:
            # Unrolling by one is a no-op; mark unrolling as disabled instead.
            cfg["disable"] = "true"

        unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">"

        super().__init__(
            ir.Attribute.parse(f"#llvm.loop_annotation<unroll = {unroll}>")
        )
def for_generate(
    start,
    stop=None,
    step=None,
    iter_args: Optional[Sequence[ir.Value]] = None,
    *,
    unroll: LoopUnroll = None,
    loc=None,
    ip=None,
):
    """
    scf.for with yield support.

    This is a generator. It builds an ``scf.for`` op and then suspends inside
    ``ir.InsertionPoint(for_op.body)``, so code emitted by the consumer while
    the generator is suspended lands in the loop body.

    Bounds follow ``range`` conventions: ``for_generate(n)`` iterates from 0
    to ``n`` with step 1. Python ints become constants and DSL ``Integer``
    values are unwrapped with ``ir_value()``. Floats are rejected.

    :param start: Lower bound, or the upper bound if ``stop`` is None.
    :param stop: Upper bound.
    :param step: Step; defaults to 1.
    :param iter_args: Loop-carried values; they must be extractable as ``ir.Value``.
    :param unroll: Optional ``LoopUnroll`` attribute stored as ``loop_annotation``.
    :raises DSLRuntimeError: if a bound is a float, or ``iter_args`` cannot be
        flattened into ``ir.Value`` objects.
    :yields: depends on the number of loop-carried values:

        - none: ``iv``
        - one: ``(iv, iter_arg, result)``
        - several: ``(iv, iter_args_tuple, results)``
    """
    if step is None:
        step = 1
    if stop is None:
        stop = start
        start = 0
    # start goes through const() before the normalization loop below, so the
    # loop's float check would never see it; reject a float start here.
    if isinstance(start, float):
        raise DSLRuntimeError(f"{start=} must be int.")
    start = const(start)
    params = [start, stop, step]
    for i, p in enumerate(params):
        if isinstance(p, int):
            p = const(p)
        elif isinstance(p, float):
            raise DSLRuntimeError(f"{p=} must be int.")
        elif isinstance(p, Integer):
            p = p.ir_value()
        params[i] = p
    start, stop, step = params

    ir_iter_args = extract_mlir_values(iter_args) if iter_args is not None else None
    if not _validate_iter_args_structure(iter_args, ir_iter_args):
        raise DSLRuntimeError("iter_args: Elements should be extractable as ir.Value.")
    for_op = scf.ForOp(start, stop, step, ir_iter_args, loc=loc, ip=ip)
    if unroll is not None:
        for_op.attributes["loop_annotation"] = unroll

    iv = for_op.induction_variable
    # Re-wrap op results / block arguments with the structure of iter_args.
    new_results = new_from_mlir_values(iter_args, for_op.results)
    new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args)
    new_iter_args = () if new_iter_args is None else tuple(new_iter_args)

    with ir.InsertionPoint(for_op.body):
        if len(new_iter_args) > 1:
            yield iv, new_iter_args, new_results
        elif len(new_iter_args) == 1:
            yield iv, new_iter_args[0], new_results[0]
        else:
            yield iv
# =============================================================================
# Logical Operators
# =============================================================================
def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None):
    """
    Logical Not.

    - Python ``bool``: returns ``lhs ^ True``.
    - Objects with ``__dsl_not__``: delegates to it, forwarding ``loc`` and ``ip``.
    - Raw dynamic MLIR values:
        - an ``i1`` value is XOR-ed with 1;
        - any other value is compared ``== 0`` (``arith.cmpi eq``), yielding an ``i1``.
    - Anything else: ``bool(lhs) ^ True``.
    """
    res = None
    # Handle Python bool first to prevent infinite recursion.
    # (bool cannot be subclassed, so isinstance is equivalent to an exact type check.)
    if const_expr(isinstance(lhs, bool)):
        res = lhs ^ True
    elif const_expr(hasattr(lhs, "__dsl_not__")):
        res = lhs.__dsl_not__(loc=loc, ip=ip)
    elif const_expr(is_dynamic_expression(lhs)):
        if const_expr(lhs.type == ir.IntegerType.get_signless(1)):
            # For i1, XOR with 1 is exactly logical negation.
            res = arith.XOrIOp(lhs, const(1, lhs.type)).result
        else:
            # For wider integers, XOR with 1 only flips the low bit
            # (e.g. 2 -> 3, still truthy); compare against zero instead.
            res = arith.CmpIOp(
                arith.CmpIPredicate.eq, lhs, const(0, lhs.type)
            ).result
    else:
        res = bool(lhs) ^ True
    return res
# =============================================================================
# If/Else
# =============================================================================
def if_generate(
    cond: Boolean,
    then_body: Callable,
    else_body: Optional[Callable] = None,
    input_args: Optional[List[DslType]] = None,
    return_types: Optional[List[DslType]] = None,
    *,
    loc=None,
    ip=None,
) -> List:
    """
    Generate an ``scf.if`` op with an optional else branch and return values.

    Args:
        cond: The condition expression (converted with ``Boolean(cond)``).
        then_body: Function to execute in the then branch.
        else_body: Optional function to execute in the else branch.
        input_args: Arguments passed positionally to both branch bodies.
        return_types: DSL types of the values yielded by each branch.
        loc: Optional location information.
        ip: Optional insertion point.

    Returns:
        ``[]`` when ``return_types`` is None; the single wrapped result when
        exactly one value is produced; otherwise a list of results wrapped in
        their DSL types.

    Raises:
        DSLRuntimeError: in any of these cases:

            - an entry of ``return_types`` is not a ``DslType``;
            - ``return_types`` is non-empty but ``else_body`` is missing (an
              ``scf.if`` that yields values must have an else region);
            - a branch body returns a different number of values than
              ``return_types`` declares.
    """
    input_args = input_args or []

    # Validate and collect MLIR return types (if provided).
    mlir_return_types = []
    if return_types is not None:
        for t in return_types:
            if not isinstance(t, DslType):
                raise DSLRuntimeError(f"{t=} must be a DslType.")
            mlir_return_types.append(t.mlir_type)

    has_else = else_body is not None
    if mlir_return_types and not has_else:
        # The scf.if verifier rejects a value-producing if without an else
        # region; fail here with a clear message instead.
        raise DSLRuntimeError(
            "if_generate() with return_types requires an else_body: "
            "an scf.if that yields values must have an else region."
        )

    if_op = scf.IfOp(
        Boolean(cond).ir_value(), mlir_return_types, hasElse=has_else, loc=loc, ip=ip
    )

    def _execute_and_yield_out(body, input_args):
        yield_vals = body(*input_args)
        if return_types is not None:
            if not isinstance(yield_vals, Iterable):
                # body only returned a single element
                yield_vals = [yield_vals]
            yield_vals = list(yield_vals)
            # zip() would silently drop extra/missing values, producing an
            # scf.yield whose arity does not match the op's results.
            if return_types and len(yield_vals) != len(return_types):
                raise DSLRuntimeError(
                    f"if_generate() branch returned {len(yield_vals)} value(s) "
                    f"but {len(return_types)} return type(s) were declared."
                )
            yield_vals = [t(r) for t, r in zip(return_types, yield_vals)]
        yield_out(yield_vals)

    # Generate the body for 'then'.
    with ir.InsertionPoint(if_op.then_block):
        _execute_and_yield_out(then_body, input_args)
    # Generate the body for 'else' if provided.
    if has_else:
        with ir.InsertionPoint(if_op.else_block):
            _execute_and_yield_out(else_body, input_args)

    if return_types is None:
        return []

    # Read all op results directly. A multi-result OpResultList is not a
    # `list`, so wrapping it as one element would keep only a single value.
    mlir_results = list(if_op.results)
    vals = [t(r) for t, r in zip(return_types, mlir_results)]
    if len(vals) == 1:
        return vals[0]
    return vals
# =============================================================================
# While Loop
# =============================================================================
class WhileLoopContext:
    """
    Context manager for a dynamic while loop.

    Construction creates an ``scf.while`` op. Its "before" and "after" regions
    each get one block whose arguments have the types of the flattened inputs.
    ``__enter__`` emits the condition into the "before" block. It then keeps the
    insertion point inside the "after" (body) block until ``__exit__``.
    """

    def __init__(
        self,
        inputs: Sequence[Union[ir.Value, Numeric]],
        condition: Callable[[Sequence[ir.Value]], ir.Value],
        *,
        loc=None,
        ip=None,
    ):
        # Keep original inputs and allow recover original type information
        self.inputs = inputs
        self.input_ir_values = extract_mlir_values(inputs)
        if not _validate_iter_args_structure(inputs, self.input_ir_values):
            raise DSLRuntimeError("inputs: Elements should be extractable as ir.Value.")
        self.condition = condition
        self.input_ir_types = [i.type for i in self.input_ir_values]
        self.while_op = scf.WhileOp(
            self.input_ir_types, self.input_ir_values, loc=loc, ip=ip
        )
        self.before_region = self.while_op.before
        self.after_region = self.while_op.after
        self.before_region.blocks.append(*self.input_ir_types)
        self.before_block = self.before_region.blocks[0]
        self.after_region.blocks.append(*self.input_ir_types)
        self.after_block = self.after_region.blocks[0]

    def __enter__(self):
        """Emit the loop condition, then enter the body block.

        Returns the body-block arguments re-wrapped with the structure of
        ``inputs``.

        :raises DSLRuntimeError: if the condition yields no dynamic value.
        """
        with ir.InsertionPoint(self.before_block):
            args = new_from_mlir_values(self.inputs, self.before_block.arguments)
            cond = self.condition(*args)
            cond_ir_val = extract_mlir_values(cond)
            if not cond_ir_val:
                # Otherwise indexing below fails with an opaque IndexError/TypeError.
                raise DSLRuntimeError(
                    "while loop condition must produce a dynamic value, "
                    f"got {type(cond)}"
                )
            # Forward the before-block arguments unchanged to the after region.
            scf.ConditionOp(cond_ir_val[0], [*self.before_block.arguments])
        # Entered manually so the insertion point stays active until __exit__.
        self.ipoint_op = ir.InsertionPoint(self.after_block)
        self.ipoint_op.__enter__()
        return new_from_mlir_values(self.inputs, self.after_block.arguments)

    def __exit__(self, exc_type, exc_value, traceback):
        self.ipoint_op.__exit__(exc_type, exc_value, traceback)

    @property
    def results(self):
        """Results of the ``scf.while`` op, re-wrapped like ``inputs``."""
        return new_from_mlir_values(self.inputs, self.while_op.results_)
def while_generate(
    inputs: Sequence[Union[ir.Value, Numeric]],
    condition: Callable[[Sequence[Union[ir.Value, Numeric]]], Union[ir.Value, Numeric]],
    *,
    loc=None,
    ip=None,
) -> WhileLoopContext:
    """
    Build a ``WhileLoopContext`` for a dynamic loop.

    ``inputs`` are the loop-carried values. ``condition`` is called with them
    (re-wrapped) to produce the loop predicate. ``loc`` and ``ip`` are
    forwarded unchanged.
    """
    ctx = WhileLoopContext(inputs=inputs, condition=condition, loc=loc, ip=ip)
    return ctx
|