Kernels:
Trusted publisher
Uploaded using `kernel-builder` (batch 32/32).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py +997 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py +725 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py +269 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py +431 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py +184 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py +65 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py +41 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py +262 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py +362 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py +41 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py +196 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py +63 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py +621 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py +482 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py +250 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py +868 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py +1613 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py +0 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py +415 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py +175 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py +1531 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py +868 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py +438 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py +427 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py +342 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py +661 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py +212 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py +753 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py +440 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py +447 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py +132 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py +36 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py +225 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py +367 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py +129 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py +42 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py +74 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py +46 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py +46 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py +661 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py +146 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py +428 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py +44 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py +309 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py +198 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +173 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +142 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +319 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +180 -0
- build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py +44 -0
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py
ADDED
|
@@ -0,0 +1,997 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Ease-of-use interface for constructing, compiling, and running CONVs
|
| 35 |
+
|
| 36 |
+
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
|
| 37 |
+
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
|
| 38 |
+
Under the hood, the interface will select sensible default parameters for the many template
|
| 39 |
+
parameters for CUTLASS CONVs.
|
| 40 |
+
|
| 41 |
+
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
| 42 |
+
performance, one should specify and tune each configuration parameter.
|
| 43 |
+
|
| 44 |
+
The simplest example of using this interface is the following:
|
| 45 |
+
|
| 46 |
+
.. highlight:: python
|
| 47 |
+
.. code-block:: python
|
| 48 |
+
|
| 49 |
+
# A, B, C, and D are torch/numpy/cupy tensor objects
|
| 50 |
+
plan = cutlass_cppgen.op.Conv(A, B, C, D)
|
| 51 |
+
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 52 |
+
|
| 53 |
+
One can also use the interface by specifying data types of operands at construction
|
| 54 |
+
and using different tensor objects with these data types at runtime:
|
| 55 |
+
|
| 56 |
+
.. highlight:: python
|
| 57 |
+
.. code-block:: python
|
| 58 |
+
|
| 59 |
+
# The following is shorthand for:
|
| 60 |
+
# cutlass_cppgen.op.Conv2d(kind="fprop",
|
| 61 |
+
# element_A=torch.float32, element_B=torch.float32,
|
| 62 |
+
# element_C=torch.float32, element_D=torch.float32,
|
| 63 |
+
# element_accumulator=torch.float32)
|
| 64 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
|
| 65 |
+
|
| 66 |
+
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
| 67 |
+
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
|
| 68 |
+
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
| 69 |
+
D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda')
|
| 70 |
+
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 71 |
+
|
| 72 |
+
A = torch.rand((32, 128), dtype=torch.float32, device='cuda')
|
| 73 |
+
B = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
| 74 |
+
C = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
| 75 |
+
D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda')
|
| 76 |
+
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 77 |
+
|
| 78 |
+
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
| 79 |
+
kernel from its execution:
|
| 80 |
+
|
| 81 |
+
.. highlight:: python
|
| 82 |
+
.. code-block:: python
|
| 83 |
+
|
| 84 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
| 85 |
+
|
| 86 |
+
# Do other work...
|
| 87 |
+
|
| 88 |
+
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 89 |
+
|
| 90 |
+
# Do other work...
|
| 91 |
+
|
| 92 |
+
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 93 |
+
|
| 94 |
+
Elementwise activation functions are easily fused to the GEMM via the interface:
|
| 95 |
+
|
| 96 |
+
.. highlight:: python
|
| 97 |
+
.. code-block:: python
|
| 98 |
+
|
| 99 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
| 100 |
+
plan.activation = cutlass_cppgen.epilogue.relu
|
| 101 |
+
|
| 102 |
+
Operations can also be run asynchronously:
|
| 103 |
+
|
| 104 |
+
.. highlight:: python
|
| 105 |
+
.. code-block:: python
|
| 106 |
+
|
| 107 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
| 108 |
+
args = plan.run()
|
| 109 |
+
|
| 110 |
+
# Do other work...
|
| 111 |
+
|
| 112 |
+
args.sync()
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
from __future__ import annotations
|
| 116 |
+
from typing import Optional
|
| 117 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 118 |
+
cuda = lazy_import("cuda.cuda")
|
| 119 |
+
cudart = lazy_import("cuda.cudart")
|
| 120 |
+
from cutlass_library import (
|
| 121 |
+
ConvKind,
|
| 122 |
+
ConvMode,
|
| 123 |
+
DataTypeSize,
|
| 124 |
+
IteratorAlgorithm,
|
| 125 |
+
OperationKind,
|
| 126 |
+
SplitKMode,
|
| 127 |
+
StrideSupport,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
import cutlass_cppgen
|
| 131 |
+
from cutlass_cppgen import epilogue
|
| 132 |
+
from cutlass_cppgen.backend import compiler
|
| 133 |
+
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
| 134 |
+
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
|
| 135 |
+
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
| 136 |
+
from cutlass_cppgen.op.op import OperationBase
|
| 137 |
+
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
|
| 138 |
+
from cutlass_cppgen.utils import check, datatypes
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Conv2d(OperationBase):
|
| 142 |
+
"""
|
| 143 |
+
Constructs a ``Conv2d`` object.
|
| 144 |
+
|
| 145 |
+
The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C,
|
| 146 |
+
along with the data type of output D and that used for accumulation, are bound to the ``Conv``
|
| 147 |
+
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
|
| 148 |
+
|
| 149 |
+
The constructor has optional parameters for flexibly setting these parameters. The following
|
| 150 |
+
constructors are equivalent:
|
| 151 |
+
|
| 152 |
+
.. highlight:: python
|
| 153 |
+
.. code-block:: python
|
| 154 |
+
|
| 155 |
+
# Use F32 for A, B, C, D, and accumulation in fprop
|
| 156 |
+
|
| 157 |
+
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
|
| 158 |
+
Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
|
| 159 |
+
|
| 160 |
+
# Explicitly specify the data types to use for A, B, C, and D.
|
| 161 |
+
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
|
| 162 |
+
element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
|
| 163 |
+
|
| 164 |
+
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
| 165 |
+
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
| 166 |
+
# have the same data type as those passed in here).
|
| 167 |
+
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
|
| 168 |
+
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
|
| 169 |
+
|
| 170 |
+
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
|
| 171 |
+
# those passed in via the generic ``element``
|
| 172 |
+
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
|
| 173 |
+
element=cutlass_cppgen.DataType.f32)
|
| 174 |
+
|
| 175 |
+
The order of precedence for the setting of the data type for a given operand/output is as follows:
|
| 176 |
+
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
|
| 177 |
+
2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
|
| 178 |
+
3) Otherwise, use the generic values (e.g., ``element``)
|
| 179 |
+
|
| 180 |
+
:param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
|
| 181 |
+
:type kind: str
|
| 182 |
+
:param A: tensor representing data type of operand A
|
| 183 |
+
:param B: tensor representing data type of operand B
|
| 184 |
+
:param C: tensor representing data type of operand C
|
| 185 |
+
:param D: tensor representing data type of operand D
|
| 186 |
+
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 187 |
+
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 188 |
+
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
| 189 |
+
:type element: cutlass_cppgen.DataType
|
| 190 |
+
:param element_A: data type to be used for operand A
|
| 191 |
+
:type element_A: cutlass_cppgen.DataType
|
| 192 |
+
:param element_B: data type to be used for operand B
|
| 193 |
+
:type element_B: cutlass_cppgen.DataType
|
| 194 |
+
:param element_C: data type to be used for operand C
|
| 195 |
+
:type element_C: cutlass_cppgen.DataType
|
| 196 |
+
:param element_D: data type to be used for operand D
|
| 197 |
+
:type element_D: cutlass_cppgen.DataType
|
| 198 |
+
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
| 199 |
+
:type element_accumulator: cutlass_cppgen.DataType
|
| 200 |
+
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
| 201 |
+
:type cc: int
|
| 202 |
+
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
| 203 |
+
:type kernel_cc: int
|
| 204 |
+
"""
|
| 205 |
+
def __init__(
|
| 206 |
+
self, kind="fprop",
|
| 207 |
+
A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
|
| 208 |
+
element=None,
|
| 209 |
+
element_A=None, element_B=None, element_C=None, element_D=None,
|
| 210 |
+
element_accumulator=None,
|
| 211 |
+
cc: int = None, kernel_cc: int = None
|
| 212 |
+
):
|
| 213 |
+
super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
|
| 214 |
+
# Verify the kernel cc
|
| 215 |
+
if self.current_cc in [90, 100, 101, 103]:
|
| 216 |
+
# The Conv2d kernel on Hopper (SM90) is currently unsupported
|
| 217 |
+
# Revert to use SM80-tagged kernels
|
| 218 |
+
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
| 219 |
+
self.specified_kernel_cc = 80
|
| 220 |
+
self._reset_options(80)
|
| 221 |
+
|
| 222 |
+
# The arch is used in testing
|
| 223 |
+
self.arch = self.current_cc
|
| 224 |
+
self.name = "conv2d" + kind
|
| 225 |
+
|
| 226 |
+
# The convolution kind. (concept: cutlass_library.library.ConvKind)
|
| 227 |
+
self.conv_kind = datatypes.getattr_enum(ConvKind, kind)
|
| 228 |
+
|
| 229 |
+
# The element types (concept: cutlass library types) of A, B, C, and D
|
| 230 |
+
elements = []
|
| 231 |
+
layouts = []
|
| 232 |
+
|
| 233 |
+
# Complete the data types based on user-provided arguments
|
| 234 |
+
for elt, tens, name in zip([element_A, element_B, element_C, element_D],
|
| 235 |
+
[A, B, C, D],
|
| 236 |
+
["A", "B", "C", "D"]):
|
| 237 |
+
if elt is not None and tens is not None:
|
| 238 |
+
raise Exception(f'Must not specify both element_{name} and tensor {name}')
|
| 239 |
+
if elt is None and tens is None and element is None:
|
| 240 |
+
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
|
| 241 |
+
|
| 242 |
+
elt_to_set = None
|
| 243 |
+
lay_to_set = None
|
| 244 |
+
|
| 245 |
+
if tens is not None:
|
| 246 |
+
elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
|
| 247 |
+
else:
|
| 248 |
+
elt_to_set = elt if elt is not None else element
|
| 249 |
+
|
| 250 |
+
assert elt_to_set is not None
|
| 251 |
+
|
| 252 |
+
# Currently we only support layout TensorNHWC
|
| 253 |
+
lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
|
| 254 |
+
elements.append(datatypes.library_type(elt_to_set))
|
| 255 |
+
layouts.append(lay_to_set)
|
| 256 |
+
|
| 257 |
+
self._element_a, self._element_b, self._element_c, self._element_d = elements
|
| 258 |
+
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
|
| 259 |
+
|
| 260 |
+
self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta
|
| 261 |
+
|
| 262 |
+
if element_accumulator is None:
|
| 263 |
+
self._element_accumulator = self._element_c
|
| 264 |
+
else:
|
| 265 |
+
self._element_accumulator = datatypes.library_type(element_accumulator)
|
| 266 |
+
|
| 267 |
+
# Default inputs if none is supplied in run()
|
| 268 |
+
self.A = A
|
| 269 |
+
self.B = B
|
| 270 |
+
self.C = C
|
| 271 |
+
self.D = D
|
| 272 |
+
|
| 273 |
+
self.alpha = alpha
|
| 274 |
+
self.beta = beta
|
| 275 |
+
|
| 276 |
+
# We only specify the stride of the swizzling functor here
|
| 277 |
+
# The actual swizzling functor is determined in run based on conv_kind and stride
|
| 278 |
+
self._swizzling_stride = 1
|
| 279 |
+
|
| 280 |
+
# Arguments that will be set to default value in _reset_operations
|
| 281 |
+
# The default tile_description and op_class are fetched from manifest of cutlass library
|
| 282 |
+
self._tile_description = None
|
| 283 |
+
self.op_class = None
|
| 284 |
+
# The default identity epilogue will be created
|
| 285 |
+
self.epilogue_functor = None
|
| 286 |
+
|
| 287 |
+
self._reset_operations()
|
| 288 |
+
|
| 289 |
+
# Arguments that will be determined online based on arguments of "run"
|
| 290 |
+
# based on stride, input/output channels, alignment, and conv_kind
|
| 291 |
+
self._iterator_algorithm = None
|
| 292 |
+
self._stride_support = None
|
| 293 |
+
|
| 294 |
+
def _reset_operations(self, reset_epilogue: bool = True):
|
| 295 |
+
# Set the default op class
|
| 296 |
+
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
|
| 297 |
+
layout_comb = (self._layout_a, self._layout_b)
|
| 298 |
+
|
| 299 |
+
self.possible_op_classes = self.options.supporting_opclasses(
|
| 300 |
+
self._element_a, self._element_b, self._element_accumulator,
|
| 301 |
+
self._layout_a, self._layout_b, self._math_operation
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
|
| 305 |
+
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
| 306 |
+
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
|
| 307 |
+
self.opclass = cutlass_cppgen.OpcodeClass.Simt
|
| 308 |
+
else:
|
| 309 |
+
if self._math_operation is not None:
|
| 310 |
+
math_op_str = f' and math operation {self._math_operation}'
|
| 311 |
+
else:
|
| 312 |
+
math_op_str = ''
|
| 313 |
+
|
| 314 |
+
raise Exception(f'No kernel configuration found for supported data type and layout '
|
| 315 |
+
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
| 316 |
+
|
| 317 |
+
if reset_epilogue:
|
| 318 |
+
self._reset_epilogue_functor_activation(epilogue.identity)
|
| 319 |
+
|
| 320 |
+
self.alignment_pref_A = min(
|
| 321 |
+
128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
|
| 322 |
+
self.alignment_pref_B = min(
|
| 323 |
+
128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
|
| 324 |
+
self.alignment_pref_C = min(
|
| 325 |
+
128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C")))
|
| 326 |
+
|
| 327 |
+
#
|
| 328 |
+
# Tile description Related
|
| 329 |
+
#
|
| 330 |
+
|
| 331 |
+
@property
|
| 332 |
+
def tile_description(self) -> TileDescription:
|
| 333 |
+
"""
|
| 334 |
+
Returns the tile description
|
| 335 |
+
"""
|
| 336 |
+
return self._tile_description
|
| 337 |
+
|
| 338 |
+
@tile_description.setter
|
| 339 |
+
def tile_description(
|
| 340 |
+
self, td=None):
|
| 341 |
+
"""
|
| 342 |
+
Set the tile description
|
| 343 |
+
|
| 344 |
+
:param td: tile description
|
| 345 |
+
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
|
| 346 |
+
{
|
| 347 |
+
"threadblock_shape": [int, int, int],
|
| 348 |
+
"warp_count": [int, int, int],
|
| 349 |
+
"stages": int,
|
| 350 |
+
"instruction_shape": [int, int, int] (optional),
|
| 351 |
+
"cluster_shape": [int, int, int] (optional)
|
| 352 |
+
}
|
| 353 |
+
"""
|
| 354 |
+
if td is None:
|
| 355 |
+
return
|
| 356 |
+
if isinstance(td, dict):
|
| 357 |
+
if self._tile_description is None:
|
| 358 |
+
op = self.possible_operations.default_operation(self._math_operation)
|
| 359 |
+
self._tile_description = datatypes.td_from_profiler_op(op)
|
| 360 |
+
if "cluster_shape" in td.keys():
|
| 361 |
+
if td["cluster_shape"] != [1, 1, 1]:
|
| 362 |
+
cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
|
| 363 |
+
td["cluster_shape"] = [1, 1, 1]
|
| 364 |
+
td = self._tile_description.clone_and_update(td)
|
| 365 |
+
|
| 366 |
+
valid, msg = self._valid_tile_description(td)
|
| 367 |
+
if valid:
|
| 368 |
+
self._tile_description = td
|
| 369 |
+
else:
|
| 370 |
+
raise Exception(msg)
|
| 371 |
+
|
| 372 |
+
def _valid_tile_description(self, td: TileDescription) -> tuple:
|
| 373 |
+
"""
|
| 374 |
+
Checks whether the provided tile description is valid for the given compute capability. At present,
|
| 375 |
+
this checks the following:
|
| 376 |
+
|
| 377 |
+
- Does the tile description use a number of stages supported by the compute capability in question?
|
| 378 |
+
- Does the tile size requested fit within shared memory?
|
| 379 |
+
- Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
|
| 380 |
+
more non-unit cluster dimensions for pre-SM90 architectures)?
|
| 381 |
+
- Is the kernel schedule being used supported on the architecture in question?
|
| 382 |
+
|
| 383 |
+
:param td: tile description to validate
|
| 384 |
+
:type td: cutlass_cppgen.backend.TileDescription
|
| 385 |
+
:return: tuple in which the first element is a bool indicating that the tile description is valid
|
| 386 |
+
and the second element is a string providing an optional error message.
|
| 387 |
+
:rtype: tuple
|
| 388 |
+
"""
|
| 389 |
+
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
|
| 390 |
+
if not valid:
|
| 391 |
+
return (valid, msg)
|
| 392 |
+
|
| 393 |
+
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
|
| 394 |
+
if not valid:
|
| 395 |
+
return (valid, msg)
|
| 396 |
+
|
| 397 |
+
return valid, msg
|
| 398 |
+
|
| 399 |
+
def tile_descriptions(self) -> list:
|
| 400 |
+
"""
|
| 401 |
+
Returns a list of valid tile descriptions for the operations
|
| 402 |
+
|
| 403 |
+
:returns: list of valid tile descriptions for the operations
|
| 404 |
+
:rtype: list
|
| 405 |
+
"""
|
| 406 |
+
descriptions = []
|
| 407 |
+
description_str = []
|
| 408 |
+
for op in self.possible_operations.all_operations:
|
| 409 |
+
td = datatypes.td_from_profiler_op(op)
|
| 410 |
+
|
| 411 |
+
if self._math_operation is not None:
|
| 412 |
+
if td.math_instruction.math_operation != self._math_operation:
|
| 413 |
+
continue
|
| 414 |
+
|
| 415 |
+
if str(td) not in description_str:
|
| 416 |
+
description_str.append(str(td))
|
| 417 |
+
descriptions.append(td)
|
| 418 |
+
return descriptions
|
| 419 |
+
|
| 420 |
+
#
|
| 421 |
+
# Swizzling functor Related
|
| 422 |
+
#
|
| 423 |
+
|
| 424 |
+
@property
|
| 425 |
+
def swizzling_stride(self):
|
| 426 |
+
"""
|
| 427 |
+
Returns the stride of swizzling currently being used by the Conv2d
|
| 428 |
+
|
| 429 |
+
:return: swizzing stride
|
| 430 |
+
"""
|
| 431 |
+
return self._swizzling_stride
|
| 432 |
+
|
| 433 |
+
@swizzling_stride.setter
|
| 434 |
+
def swizzling_stride(self, stride: int):
|
| 435 |
+
"""
|
| 436 |
+
Sets the swizzling functor to the type specified by `swizzling_functor`
|
| 437 |
+
"""
|
| 438 |
+
if not isinstance(stride, int):
|
| 439 |
+
raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
|
| 440 |
+
self._swizzling_stride = stride
|
| 441 |
+
|
| 442 |
+
def _propose_swizzling_functor(self, stride):
|
| 443 |
+
"""
|
| 444 |
+
Automatically propose the swizzling functor based on the stride
|
| 445 |
+
"""
|
| 446 |
+
if self.conv_kind == ConvKind.Dgrad:
|
| 447 |
+
if stride[0] != 1 or stride[1] != 1:
|
| 448 |
+
return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
|
| 449 |
+
|
| 450 |
+
return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
|
| 451 |
+
|
| 452 |
+
#
|
| 453 |
+
# Iterator Algorithm Related
|
| 454 |
+
#
|
| 455 |
+
|
| 456 |
+
@property
|
| 457 |
+
def iterator_algorithm(self) -> IteratorAlgorithm:
|
| 458 |
+
"""
|
| 459 |
+
Returns the iterator algorithm
|
| 460 |
+
"""
|
| 461 |
+
return self._iterator_algorithm
|
| 462 |
+
|
| 463 |
+
@iterator_algorithm.setter
|
| 464 |
+
def iterator_algorithm(self, alg: str):
|
| 465 |
+
"""
|
| 466 |
+
Sets the iterator algorithm
|
| 467 |
+
|
| 468 |
+
:param alg: The iterator algorithm
|
| 469 |
+
:type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels"
|
| 470 |
+
"""
|
| 471 |
+
iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)
|
| 472 |
+
|
| 473 |
+
# Check if the iterator algorithm is valid
|
| 474 |
+
if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
|
| 475 |
+
raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
|
| 476 |
+
|
| 477 |
+
self._iterator_algorithm = iterator_alg
|
| 478 |
+
|
| 479 |
+
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
|
| 480 |
+
"""
|
| 481 |
+
Propose a valid iterator algorithm based on problem size and alignment
|
| 482 |
+
"""
|
| 483 |
+
if self.conv_kind == ConvKind.Fprop:
|
| 484 |
+
# Check whether the fixed channel is applicable
|
| 485 |
+
if problem_size.C == alignment_a:
|
| 486 |
+
return IteratorAlgorithm.FixedChannels
|
| 487 |
+
elif (problem_size.C % alignment_a == 0 and
|
| 488 |
+
problem_size.R <= 32 and problem_size.S <= 32):
|
| 489 |
+
return IteratorAlgorithm.Optimized
|
| 490 |
+
else:
|
| 491 |
+
return IteratorAlgorithm.Analytic
|
| 492 |
+
elif self.conv_kind == ConvKind.Dgrad:
|
| 493 |
+
if (problem_size.K % alignment_a == 0 and
|
| 494 |
+
problem_size.R <= 32 and problem_size.S <= 32 and
|
| 495 |
+
problem_size.C % alignment_b == 0):
|
| 496 |
+
return IteratorAlgorithm.Optimized
|
| 497 |
+
else:
|
| 498 |
+
return IteratorAlgorithm.Analytic
|
| 499 |
+
elif self.conv_kind == ConvKind.Wgrad:
|
| 500 |
+
if (problem_size.K % alignment_a == 0 and
|
| 501 |
+
problem_size.C % alignment_b == 0):
|
| 502 |
+
return IteratorAlgorithm.Optimized
|
| 503 |
+
else:
|
| 504 |
+
return IteratorAlgorithm.Analytic
|
| 505 |
+
|
| 506 |
+
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
|
| 507 |
+
"""
|
| 508 |
+
Validate whether the user provide iterator algorithm works for the given problem size
|
| 509 |
+
"""
|
| 510 |
+
if self.conv_kind == ConvKind.Fprop:
|
| 511 |
+
if iterator_algorithm == IteratorAlgorithm.FixedChannels:
|
| 512 |
+
return problem_size.C == alignment_a
|
| 513 |
+
elif iterator_algorithm == IteratorAlgorithm.Optimized:
|
| 514 |
+
return (problem_size.C % alignment_a == 0 and
|
| 515 |
+
problem_size.R <= 32 and problem_size.S <= 32)
|
| 516 |
+
elif iterator_algorithm == IteratorAlgorithm.FewChannels:
|
| 517 |
+
return problem_size.C % alignment_a == 0
|
| 518 |
+
elif self.conv_kind == ConvKind.Dgrad:
|
| 519 |
+
if iterator_algorithm == IteratorAlgorithm.Optimized:
|
| 520 |
+
return (problem_size.K % alignment_a == 0 and
|
| 521 |
+
problem_size.R <= 32 and problem_size.S <= 32 and
|
| 522 |
+
problem_size.C % alignment_b == 0)
|
| 523 |
+
elif self.conv_kind == ConvKind.Wgrad:
|
| 524 |
+
if iterator_algorithm == IteratorAlgorithm.Optimized:
|
| 525 |
+
return (problem_size.K % alignment_a == 0 and
|
| 526 |
+
problem_size.C % alignment_b == 0)
|
| 527 |
+
|
| 528 |
+
return True
|
| 529 |
+
|
| 530 |
+
#
|
| 531 |
+
# Stride Support Related
|
| 532 |
+
#
|
| 533 |
+
|
| 534 |
+
def _propose_stride_support(self, stride):
|
| 535 |
+
if self.conv_kind == ConvKind.Dgrad:
|
| 536 |
+
if stride[0] == 1 and stride[1] == 1:
|
| 537 |
+
return StrideSupport.Unity
|
| 538 |
+
|
| 539 |
+
return StrideSupport.Strided
|
| 540 |
+
|
| 541 |
+
#
|
| 542 |
+
# Construct and Compilation
|
| 543 |
+
#
|
| 544 |
+
|
| 545 |
+
def construct(
|
| 546 |
+
self, tile_description: TileDescription = None,
|
| 547 |
+
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
| 548 |
+
iterator_algorithm: IteratorAlgorithm = None,
|
| 549 |
+
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
|
| 550 |
+
epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
|
| 551 |
+
"""
|
| 552 |
+
Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
|
| 553 |
+
kernel specification of the ``Conv2d`` object.
|
| 554 |
+
|
| 555 |
+
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 556 |
+
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 557 |
+
:param alignment_A: alignment of operand A
|
| 558 |
+
:type alignment_A: int
|
| 559 |
+
:param alignment_B: alignment of operand B
|
| 560 |
+
:type alignment_B: int
|
| 561 |
+
:param alignment_C: alignment of operand C
|
| 562 |
+
:type alignment_C: int
|
| 563 |
+
:param iterator_algorithm: the iterator algorithm used
|
| 564 |
+
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
|
| 565 |
+
:param stride_support: the stride support of dgrad
|
| 566 |
+
:type stride_support: cutlass_library.library.StrideSupport
|
| 567 |
+
:param swizzling_functor: the swizzling functor
|
| 568 |
+
:type swizzling_functor: cutlass_cppgen.swizzle
|
| 569 |
+
:param epilogue_functor: the epilogue functor
|
| 570 |
+
|
| 571 |
+
:return: operation that was constructed
|
| 572 |
+
:rtype: cutlass_cppgen.backend.Conv2dOperation
|
| 573 |
+
"""
|
| 574 |
+
# Get alignment
|
| 575 |
+
alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
|
| 576 |
+
alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
|
| 577 |
+
alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)
|
| 578 |
+
|
| 579 |
+
tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
|
| 580 |
+
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
| 581 |
+
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
| 582 |
+
|
| 583 |
+
if tile_description is None:
|
| 584 |
+
if self.tile_description is not None:
|
| 585 |
+
tile_description = self.tile_description
|
| 586 |
+
else:
|
| 587 |
+
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
| 588 |
+
tile_description = datatypes.td_from_profiler_op(op)
|
| 589 |
+
else:
|
| 590 |
+
valid, err_str = self._valid_tile_description(tile_description)
|
| 591 |
+
if not valid:
|
| 592 |
+
raise Exception(f"Invalid tile description. {err_str}")
|
| 593 |
+
self.tile_description = tile_description
|
| 594 |
+
|
| 595 |
+
if iterator_algorithm is None:
|
| 596 |
+
# If the iterator algorithm is already set
|
| 597 |
+
if self.iterator_algorithm is not None:
|
| 598 |
+
iterator_algorithm = self.iterator_algorithm
|
| 599 |
+
else:
|
| 600 |
+
# Otherwise, we conservatively use the analytic iterator for correctness
|
| 601 |
+
iterator_algorithm = IteratorAlgorithm.Analytic
|
| 602 |
+
|
| 603 |
+
if stride_support is None:
|
| 604 |
+
# If the stride support is already set
|
| 605 |
+
if self._stride_support is not None:
|
| 606 |
+
stride_support = self._stride_support
|
| 607 |
+
else:
|
| 608 |
+
# Otherwise, we assume strided
|
| 609 |
+
stride_support = StrideSupport.Strided
|
| 610 |
+
|
| 611 |
+
if swizzling_functor is None:
|
| 612 |
+
# If the swizzling functor is already set
|
| 613 |
+
swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))
|
| 614 |
+
|
| 615 |
+
if epilogue_functor is None:
|
| 616 |
+
if self.epilogue_functor is not None:
|
| 617 |
+
epilogue_functor = self.epilogue_functor
|
| 618 |
+
else:
|
| 619 |
+
epilogue_functor = self._create_epilogue_functor_activation(self._activation)
|
| 620 |
+
|
| 621 |
+
# Reset the alignment of the epilogue functor
|
| 622 |
+
epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)
|
| 623 |
+
|
| 624 |
+
operation = Conv2dOperation(
|
| 625 |
+
conv_kind=self.conv_kind,
|
| 626 |
+
iterator_algorithm=iterator_algorithm,
|
| 627 |
+
arch=self.current_cc,
|
| 628 |
+
tile_description=tile_description,
|
| 629 |
+
A=tensor_A, B=tensor_B, C=tensor_C,
|
| 630 |
+
stride_support=stride_support,
|
| 631 |
+
epilogue_functor=epilogue_functor,
|
| 632 |
+
swizzling_functor=swizzling_functor,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
return operation
|
| 636 |
+
|
| 637 |
+
def compile(self, tile_description: TileDescription = None,
|
| 638 |
+
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
| 639 |
+
iterator_algorithm: IteratorAlgorithm = None,
|
| 640 |
+
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
|
| 641 |
+
epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
|
| 642 |
+
"""
|
| 643 |
+
Emits and compiles the kernel currently specified. If ``tile_description`` and any
|
| 644 |
+
of the ``alignment`` parameters are set, the kernel will be chosen using this
|
| 645 |
+
tile description and alignments. Otherwise, a default tile description and alignment
|
| 646 |
+
will be used.
|
| 647 |
+
|
| 648 |
+
::param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 649 |
+
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 650 |
+
:param alignment_A: alignment of operand A
|
| 651 |
+
:type alignment_A: int
|
| 652 |
+
:param alignment_B: alignment of operand B
|
| 653 |
+
:type alignment_B: int
|
| 654 |
+
:param alignment_C: alignment of operand C
|
| 655 |
+
:type alignment_C: int
|
| 656 |
+
:param iterator_algorithm: the iterator algorithm used
|
| 657 |
+
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
|
| 658 |
+
:param stride_support: the stride support of dgrad
|
| 659 |
+
:type stride_support: cutlass_library.library.StrideSupport
|
| 660 |
+
:param swizzling_functor: the swizzling functor
|
| 661 |
+
:type swizzling_functor: cutlass_cppgen.swizzle
|
| 662 |
+
:param epilogue_functor: the epilogue functor
|
| 663 |
+
|
| 664 |
+
:return: operation that was compiled
|
| 665 |
+
:rtype: cutlass_cppgen.backend.Conv2dOperation
|
| 666 |
+
"""
|
| 667 |
+
|
| 668 |
+
self.operation = self.construct(
|
| 669 |
+
tile_description, alignment_A, alignment_B, alignment_C,
|
| 670 |
+
iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
|
| 671 |
+
|
| 672 |
+
if print_module:
|
| 673 |
+
print(self.operation.rt_module.emit())
|
| 674 |
+
|
| 675 |
+
compiler.add_module([self.operation,])
|
| 676 |
+
return self.operation
|
| 677 |
+
|
| 678 |
+
#
|
| 679 |
+
# Run Related
|
| 680 |
+
#
|
| 681 |
+
|
| 682 |
+
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
|
| 683 |
+
"""
|
| 684 |
+
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
|
| 685 |
+
is raised if it does not.
|
| 686 |
+
|
| 687 |
+
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 688 |
+
:type tensor: numpy/cupy/torch array/tensor object
|
| 689 |
+
:param ref_dtype: data type for the tensor that this object was initialized to
|
| 690 |
+
:param name: identifier of the tensor to verify. Used in raising exceptions
|
| 691 |
+
:type name: str
|
| 692 |
+
"""
|
| 693 |
+
dtype, _ = datatypes.get_datatype_and_layout(tensor)
|
| 694 |
+
if dtype != ref_type:
|
| 695 |
+
raise Exception(f'Tensor {name} with type and layout {dtype} '
|
| 696 |
+
f'does not match the expected type of {ref_type}.')
|
| 697 |
+
|
| 698 |
+
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
|
| 699 |
+
if self.conv_kind == ConvKind.Fprop:
|
| 700 |
+
input = A
|
| 701 |
+
weight = B
|
| 702 |
+
output = C
|
| 703 |
+
output_tensor = "C"
|
| 704 |
+
elif self.conv_kind == ConvKind.Dgrad:
|
| 705 |
+
output = A
|
| 706 |
+
weight = B
|
| 707 |
+
input = C
|
| 708 |
+
output_tensor = "A"
|
| 709 |
+
elif self.conv_kind == ConvKind.Wgrad:
|
| 710 |
+
output = A
|
| 711 |
+
input = B
|
| 712 |
+
weight = C
|
| 713 |
+
output_tensor = "A"
|
| 714 |
+
else:
|
| 715 |
+
raise Exception(f"Convolution kind {self.conv_kind} is not supported")
|
| 716 |
+
|
| 717 |
+
N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
|
| 718 |
+
K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
|
| 719 |
+
_, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")
|
| 720 |
+
|
| 721 |
+
problem_size = Conv2DProblemSize(
|
| 722 |
+
N_, H_, W_, C_,
|
| 723 |
+
K_, R_, S_, C_,
|
| 724 |
+
padding[0], padding[1],
|
| 725 |
+
stride[0], stride[1],
|
| 726 |
+
dilation[0], dilation[1],
|
| 727 |
+
ConvMode.CrossCorrelation,
|
| 728 |
+
1, 1
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
if P_ != problem_size.P or Q_ != problem_size.Q:
|
| 732 |
+
raise Exception(
|
| 733 |
+
f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")
|
| 734 |
+
|
| 735 |
+
return problem_size
|
| 736 |
+
|
| 737 |
+
def run(self, A=None, B=None, C=None, D=None,
|
| 738 |
+
stride=(1, 1), padding=(0, 0), dilation=(1, 1),
|
| 739 |
+
alpha=None, beta=None,
|
| 740 |
+
split_k=("serial", 1), sync: bool = True,
|
| 741 |
+
print_module: bool = False,
|
| 742 |
+
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 743 |
+
"""
|
| 744 |
+
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
|
| 745 |
+
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
|
| 746 |
+
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
|
| 747 |
+
parameters provided in the call, or from those
|
| 748 |
+
passed in on the construction of this object -- one of the two must be specified.
|
| 749 |
+
|
| 750 |
+
By default, this call returns only once the kernel has completed. To launch the kernel
|
| 751 |
+
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
|
| 752 |
+
caller to syncrhonize the results of the kernel before attempting to access outputs
|
| 753 |
+
by calling ``sync()`` on the arguments returned from this call.
|
| 754 |
+
|
| 755 |
+
:param A: tensor representing data type and layout of operand A
|
| 756 |
+
:param B: tensor representing data type and layout of operand B
|
| 757 |
+
:param C: tensor representing data type and layout of operand C
|
| 758 |
+
:param D: tensor representing data type and layout of operand D
|
| 759 |
+
:param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
|
| 760 |
+
:param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
|
| 761 |
+
:param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
|
| 762 |
+
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 763 |
+
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 764 |
+
:param split_k: a tuple (split_k_mode, split_k_slices)
|
| 765 |
+
:param sync: whether the call should wait for the kernel to complete before returning
|
| 766 |
+
:type sync: bool
|
| 767 |
+
:param print_module: whether to print the emitted C++ code
|
| 768 |
+
:type print_module: bool
|
| 769 |
+
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 770 |
+
:type stream: :class:`cuda.cuda.CUstream`
|
| 771 |
+
|
| 772 |
+
:return: arguments passed in to the kernel
|
| 773 |
+
:rtype: cutlass_cppgen.backend.Conv2dArguments
|
| 774 |
+
"""
|
| 775 |
+
if not stream:
|
| 776 |
+
stream = cuda.CUstream(0)
|
| 777 |
+
super().run_setup()
|
| 778 |
+
|
| 779 |
+
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
|
| 780 |
+
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
|
| 781 |
+
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
|
| 782 |
+
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
|
| 783 |
+
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
| 784 |
+
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
| 785 |
+
|
| 786 |
+
# handle the case when there is no C
|
| 787 |
+
if C is None:
|
| 788 |
+
if beta != 0:
|
| 789 |
+
raise Exception(f"With beta {beta} != 0, C has to be provided.")
|
| 790 |
+
else:
|
| 791 |
+
C = D
|
| 792 |
+
|
| 793 |
+
# Construct problem size based on input
|
| 794 |
+
# It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
|
| 795 |
+
problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)
|
| 796 |
+
|
| 797 |
+
# Propose stride support based on input
|
| 798 |
+
stride_support = self._propose_stride_support(stride)
|
| 799 |
+
|
| 800 |
+
# Propose swizzling functor
|
| 801 |
+
swizzling_functor = self._propose_swizzling_functor(stride)
|
| 802 |
+
|
| 803 |
+
shape_a = datatypes.get_tensor_shape(A, op="CONV")
|
| 804 |
+
shape_b = datatypes.get_tensor_shape(B, op="CONV")
|
| 805 |
+
shape_c = datatypes.get_tensor_shape(C, op="CONV")
|
| 806 |
+
|
| 807 |
+
# Get the alignment
|
| 808 |
+
alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
|
| 809 |
+
alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
|
| 810 |
+
alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")
|
| 811 |
+
|
| 812 |
+
alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
|
| 813 |
+
alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
|
| 814 |
+
alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)
|
| 815 |
+
|
| 816 |
+
# Propose iterator algorithm based on input
|
| 817 |
+
if self._iterator_algorithm is None:
|
| 818 |
+
# Propose a default iterator algorithm based on the problem size
|
| 819 |
+
iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
|
| 820 |
+
else:
|
| 821 |
+
if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
|
| 822 |
+
iterator_algorithm = self._iterator_algorithm
|
| 823 |
+
else:
|
| 824 |
+
raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")
|
| 825 |
+
|
| 826 |
+
epilogue_args = [alpha, beta]
|
| 827 |
+
|
| 828 |
+
if hasattr(self, "_activation_args"):
|
| 829 |
+
if isinstance(self._activation_args, list):
|
| 830 |
+
epilogue_args += self._activation_args
|
| 831 |
+
else:
|
| 832 |
+
epilogue_args.append(self._activation_args)
|
| 833 |
+
|
| 834 |
+
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 835 |
+
epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
|
| 836 |
+
else:
|
| 837 |
+
epilogue_functor = self.epilogue_functor
|
| 838 |
+
|
| 839 |
+
# The alignment is determined by the iterator function (I believe)
|
| 840 |
+
self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
| 841 |
+
alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
|
| 842 |
+
swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)
|
| 843 |
+
|
| 844 |
+
# Create reduction operation for parallel split-k
|
| 845 |
+
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 846 |
+
epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
|
| 847 |
+
self.reduction_operation = ReductionOperation(
|
| 848 |
+
shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
|
| 849 |
+
element_accumulator=self._element_accumulator,
|
| 850 |
+
element_compute=self._element_accumulator,
|
| 851 |
+
epilogue_functor=epilogue_functor_reduction,
|
| 852 |
+
count=alignment_c
|
| 853 |
+
)
|
| 854 |
+
if print_module:
|
| 855 |
+
print(self.reduction_operation.rt_module.emit())
|
| 856 |
+
compiler.add_module([self.reduction_operation,])
|
| 857 |
+
|
| 858 |
+
arguments = Conv2dArguments(
|
| 859 |
+
operation=self.operation, problem_size=problem_size,
|
| 860 |
+
A=A, B=B, C=C, D=D,
|
| 861 |
+
output_op=self.operation.epilogue_type(*epilogue_args),
|
| 862 |
+
split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
|
| 863 |
+
split_k_slices=split_k[1],
|
| 864 |
+
stream=stream
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
self.operation.run(arguments)
|
| 868 |
+
|
| 869 |
+
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 870 |
+
implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
|
| 871 |
+
reduction_arguments = ReductionArguments(
|
| 872 |
+
self.reduction_operation,
|
| 873 |
+
problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
|
| 874 |
+
partitions=split_k[1],
|
| 875 |
+
workspace=arguments.ptr_D,
|
| 876 |
+
destination=D,
|
| 877 |
+
source=C,
|
| 878 |
+
output_op=self.reduction_operation.epilogue_type(*epilogue_args),
|
| 879 |
+
stream=stream
|
| 880 |
+
)
|
| 881 |
+
self.reduction_operation.run(reduction_arguments)
|
| 882 |
+
|
| 883 |
+
if sync:
|
| 884 |
+
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 885 |
+
reduction_arguments.sync()
|
| 886 |
+
|
| 887 |
+
# Free memory allocated by args because we are not
|
| 888 |
+
# calling `arguments.sync()` in this case (which will free memory)
|
| 889 |
+
arguments.free()
|
| 890 |
+
else:
|
| 891 |
+
arguments.sync()
|
| 892 |
+
|
| 893 |
+
return arguments
|
| 894 |
+
|
| 895 |
+
#
|
| 896 |
+
# Helper functions
|
| 897 |
+
#
|
| 898 |
+
@staticmethod
|
| 899 |
+
def output_size(input_size, weight_size, padding, stride, dilation):
|
| 900 |
+
problem_size = Conv2DProblemSize(
|
| 901 |
+
*input_size,
|
| 902 |
+
*weight_size,
|
| 903 |
+
padding[0], padding[1],
|
| 904 |
+
stride[0], stride[1],
|
| 905 |
+
dilation[0], dilation[1],
|
| 906 |
+
ConvMode.CrossCorrelation,
|
| 907 |
+
1, 1
|
| 908 |
+
)
|
| 909 |
+
return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
#
|
| 913 |
+
# Easy to use interfaces for fprop, wgrad, and dgrad
|
| 914 |
+
#
|
| 915 |
+
|
| 916 |
+
class Conv2dFprop(Conv2d):
|
| 917 |
+
def __init__(
|
| 918 |
+
self,
|
| 919 |
+
input=None, weight=None, C=None, output=None, alpha=1, beta=0,
|
| 920 |
+
element=None,
|
| 921 |
+
element_input=None, element_weight=None, element_C=None, element_output=None,
|
| 922 |
+
element_accumulator=None,
|
| 923 |
+
cc: int = None, kernel_cc: int = None):
|
| 924 |
+
A, B, D = input, weight, output
|
| 925 |
+
element_A, element_B, element_D = element_input, element_weight, element_output
|
| 926 |
+
super().__init__(
|
| 927 |
+
"fprop", A, B, C, D, alpha, beta, element,
|
| 928 |
+
element_A, element_B, element_C, element_D,
|
| 929 |
+
element_accumulator, cc, kernel_cc)
|
| 930 |
+
|
| 931 |
+
def run(
|
| 932 |
+
self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
|
| 933 |
+
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
|
| 934 |
+
sync: bool = True, print_module: bool = False,
|
| 935 |
+
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 936 |
+
|
| 937 |
+
if not stream:
|
| 938 |
+
stream = cuda.CUstream(0)
|
| 939 |
+
|
| 940 |
+
A, B, D = input, weight, output
|
| 941 |
+
return super().run(
|
| 942 |
+
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
|
| 943 |
+
|
| 944 |
+
|
| 945 |
+
class Conv2dDgrad(Conv2d):
|
| 946 |
+
def __init__(
|
| 947 |
+
self,
|
| 948 |
+
grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
|
| 949 |
+
element=None,
|
| 950 |
+
element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
|
| 951 |
+
element_accumulator=None,
|
| 952 |
+
cc: int = None, kernel_cc: int = None):
|
| 953 |
+
A, B, D = grad_output, weight, grad_input
|
| 954 |
+
element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
|
| 955 |
+
super().__init__(
|
| 956 |
+
"dgrad", A, B, C, D, alpha, beta, element,
|
| 957 |
+
element_A, element_B, element_C, element_D,
|
| 958 |
+
element_accumulator, cc, kernel_cc)
|
| 959 |
+
|
| 960 |
+
def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
|
| 961 |
+
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
|
| 962 |
+
sync: bool = True, print_module: bool = False,
|
| 963 |
+
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 964 |
+
#
|
| 965 |
+
if not stream:
|
| 966 |
+
stream = cuda.CUstream(0)
|
| 967 |
+
|
| 968 |
+
A, B, D = grad_output, weight, grad_input
|
| 969 |
+
return super().run(
|
| 970 |
+
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
class Conv2dWgrad(Conv2d):
|
| 974 |
+
def __init__(
|
| 975 |
+
self,
|
| 976 |
+
grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
|
| 977 |
+
element=None,
|
| 978 |
+
element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
|
| 979 |
+
element_accumulator=None,
|
| 980 |
+
cc: int = None, kernel_cc: int = None):
|
| 981 |
+
A, B, D = grad_output, input, grad_weight
|
| 982 |
+
element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
|
| 983 |
+
super().__init__(
|
| 984 |
+
"wgrad", A, B, C, D, alpha, beta, element,
|
| 985 |
+
element_A, element_B, element_C, element_D,
|
| 986 |
+
element_accumulator, cc, kernel_cc)
|
| 987 |
+
|
| 988 |
+
def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
|
| 989 |
+
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
|
| 990 |
+
sync: bool = True, print_module: bool = False,
|
| 991 |
+
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 992 |
+
if not stream:
|
| 993 |
+
stream = cuda.CUstream(0)
|
| 994 |
+
|
| 995 |
+
A, B, D = grad_output, input, grad_weight
|
| 996 |
+
return super().run(
|
| 997 |
+
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py
ADDED
|
@@ -0,0 +1,725 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Ease-of-use interface for constructing, compiling, and running GEMMs.
|
| 35 |
+
|
| 36 |
+
The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
|
| 37 |
+
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
|
| 38 |
+
Under the hood, the interface will select sensible default parameters for the many template
|
| 39 |
+
parameters for CUTLASS GEMMs.
|
| 40 |
+
|
| 41 |
+
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
| 42 |
+
performance, one should specify and tune each configuration parameter.
|
| 43 |
+
|
| 44 |
+
The simplest example of using this interface is the following:
|
| 45 |
+
|
| 46 |
+
.. highlight:: python
|
| 47 |
+
.. code-block:: python
|
| 48 |
+
|
| 49 |
+
# A, B, C, and D are torch/numpy/cupy tensor objects
|
| 50 |
+
plan = cutlass_cppgen.op.Gemm(A, B, C, D)
|
| 51 |
+
plan.run()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
One can also use the interface by specifying data types of operands at construction
|
| 55 |
+
and using different tensor objects with these data types at runtime:
|
| 56 |
+
|
| 57 |
+
.. highlight:: python
|
| 58 |
+
.. code-block:: python
|
| 59 |
+
|
| 60 |
+
# The following is shorthand for:
|
| 61 |
+
# cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
|
| 62 |
+
# element_C=torch.float32, element_D=torch.float32,
|
| 63 |
+
# element_accumulator=torch.float32,
|
| 64 |
+
# layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 65 |
+
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 66 |
+
|
| 67 |
+
A0 = torch.rand((128, 256), device='cuda')
|
| 68 |
+
B0 = torch.rand((256, 64), device='cuda')
|
| 69 |
+
C0 = torch.zeros((128, 64), device='cuda')
|
| 70 |
+
D0 = torch.zeros((128, 64), device.'cuda')
|
| 71 |
+
plan.run(A0, B0, C0, D0)
|
| 72 |
+
|
| 73 |
+
A = torch.rand((32, 128), device='cuda')
|
| 74 |
+
B = torch.rand((128, 256), device='cuda')
|
| 75 |
+
C = torch.zeros((32, 256), device='cuda')
|
| 76 |
+
D = torch.zeros((32, 256), device.'cuda')
|
| 77 |
+
plan.run(A1, B1, C1, D1)
|
| 78 |
+
|
| 79 |
+
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
| 80 |
+
kernel from its execution:
|
| 81 |
+
|
| 82 |
+
.. highlight:: python
|
| 83 |
+
.. code-block:: python
|
| 84 |
+
|
| 85 |
+
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 86 |
+
plan.compile()
|
| 87 |
+
|
| 88 |
+
# Do other work...
|
| 89 |
+
|
| 90 |
+
plan.run(A0, B0, C0, D0)
|
| 91 |
+
|
| 92 |
+
# Do other work...
|
| 93 |
+
|
| 94 |
+
plan.run(A1, B1, C1, D1)
|
| 95 |
+
|
| 96 |
+
Elementwise activation functions are easily fused to the GEMM via the interface:
|
| 97 |
+
|
| 98 |
+
.. highlight:: python
|
| 99 |
+
.. code-block:: python
|
| 100 |
+
|
| 101 |
+
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 102 |
+
plan.activation = cutlass_cppgen.epilogue.relu
|
| 103 |
+
|
| 104 |
+
Operations can also be run asynchronously:
|
| 105 |
+
|
| 106 |
+
.. highlight:: python
|
| 107 |
+
.. code-block:: python
|
| 108 |
+
|
| 109 |
+
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 110 |
+
args = plan.run()
|
| 111 |
+
|
| 112 |
+
# Do other work...
|
| 113 |
+
|
| 114 |
+
args.sync()
|
| 115 |
+
"""
|
| 116 |
+
from __future__ import annotations
|
| 117 |
+
from typing import Optional
|
| 118 |
+
from math import prod
|
| 119 |
+
|
| 120 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 121 |
+
cuda = lazy_import("cuda.cuda")
|
| 122 |
+
from cutlass_library import (
|
| 123 |
+
DataType,
|
| 124 |
+
DataTypeSize,
|
| 125 |
+
GemmUniversalMode,
|
| 126 |
+
KernelScheduleSuffixes,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
import cutlass_cppgen
|
| 130 |
+
from cutlass_cppgen import epilogue, swizzle
|
| 131 |
+
from cutlass_cppgen.backend import compiler
|
| 132 |
+
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
| 133 |
+
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
| 134 |
+
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
| 135 |
+
from cutlass_cppgen.op.op import OperationBase
|
| 136 |
+
from cutlass_cppgen.shape import GemmCoord
|
| 137 |
+
from cutlass_cppgen.utils import check, datatypes
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class Gemm(OperationBase):
|
| 141 |
+
"""
|
| 142 |
+
Constructs a ``Gemm`` object.
|
| 143 |
+
|
| 144 |
+
The data types and layouts of operands A, B, and C, along with the data type of output D
|
| 145 |
+
and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
|
| 146 |
+
these are not to be changed after a ``Gemm`` has been constructed.
|
| 147 |
+
|
| 148 |
+
The constructor has optional parameters for flexibly setting these parameters. The following
|
| 149 |
+
constructors are equivalent:
|
| 150 |
+
|
| 151 |
+
.. highlight:: python
|
| 152 |
+
.. code-block:: python
|
| 153 |
+
|
| 154 |
+
# Use F32 for A, B, C, D, and accumulation. All operands are row major.
|
| 155 |
+
|
| 156 |
+
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
|
| 157 |
+
# for operands to the same values.
|
| 158 |
+
Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 159 |
+
|
| 160 |
+
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
|
| 161 |
+
Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
|
| 162 |
+
element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 163 |
+
|
| 164 |
+
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
| 165 |
+
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
| 166 |
+
# have the same data type and layout as those passed in here).
|
| 167 |
+
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
|
| 168 |
+
Gemm(A=A, B=B, C=C, D=D)
|
| 169 |
+
|
| 170 |
+
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
|
| 171 |
+
# the same as that for D, at present)
|
| 172 |
+
Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
|
| 173 |
+
layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
|
| 174 |
+
|
| 175 |
+
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
|
| 176 |
+
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
|
| 177 |
+
Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
|
| 178 |
+
element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 179 |
+
|
| 180 |
+
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
|
| 181 |
+
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
|
| 182 |
+
2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
|
| 183 |
+
3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
|
| 184 |
+
|
| 185 |
+
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
| 186 |
+
:type cc: int
|
| 187 |
+
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
| 188 |
+
:type kernel_cc: int
|
| 189 |
+
:param A: tensor representing data type and layout of operand A
|
| 190 |
+
:param B: tensor representing data type and layout of operand B
|
| 191 |
+
:param C: tensor representing data type and layout of operand C
|
| 192 |
+
:param D: tensor representing data type and layout of operand D
|
| 193 |
+
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 194 |
+
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 195 |
+
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
| 196 |
+
:type element_accumulator: cutlass_cppgen.DataType
|
| 197 |
+
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
| 198 |
+
:type element: cutlass_cppgen.DataType
|
| 199 |
+
:param layout: generic layout type to be used for operands A, B, C, and D
|
| 200 |
+
:type layout: cutlass_cppgen.LayoutType
|
| 201 |
+
:param element_A: data type to be used for operand A
|
| 202 |
+
:type element_A: cutlass_cppgen.DataType
|
| 203 |
+
:param element_B: data type to be used for operand B
|
| 204 |
+
:type element_B: cutlass_cppgen.DataType
|
| 205 |
+
:param element_C: data type to be used for operand C
|
| 206 |
+
:type element_C: cutlass_cppgen.DataType
|
| 207 |
+
:param element_D: data type to be used for operand D
|
| 208 |
+
:type element_D: cutlass_cppgen.DataType
|
| 209 |
+
:param layout_A: layout of operand A
|
| 210 |
+
:type layout_A: cutlass_cppgen.LayoutType
|
| 211 |
+
:param layout_B: layout of operand B
|
| 212 |
+
:type layout_B: cutlass_cppgen.LayoutType
|
| 213 |
+
:param layout_C: layout of operand C
|
| 214 |
+
:type layout_C: cutlass_cppgen.LayoutType
|
| 215 |
+
:param layout_D: layout of operand D
|
| 216 |
+
:type layout_D: cutlass_cppgen.LayoutType
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
def __init__(
|
| 220 |
+
self, A=None, B=None, C=None, D=None,
|
| 221 |
+
alpha=1.0, beta=0.0, element_accumulator=None,
|
| 222 |
+
element=None, layout=None,
|
| 223 |
+
element_A=None, element_B=None, element_C=None, element_D=None,
|
| 224 |
+
layout_A=None, layout_B=None, layout_C=None,
|
| 225 |
+
cc: int = None, kernel_cc: int = None
|
| 226 |
+
):
|
| 227 |
+
super().__init__(cc=cc, kernel_cc=kernel_cc)
|
| 228 |
+
self.name = "gemm"
|
| 229 |
+
self.compiled = False
|
| 230 |
+
|
| 231 |
+
elements = []
|
| 232 |
+
layouts = []
|
| 233 |
+
|
| 234 |
+
# Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
|
| 235 |
+
# ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
|
| 236 |
+
for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
|
| 237 |
+
[layout_A, layout_B, layout_C, layout_C],
|
| 238 |
+
[A, B, C, D],
|
| 239 |
+
["A", "B", "C", "D"]):
|
| 240 |
+
if elt is not None and tens is not None:
|
| 241 |
+
raise Exception(f'Must not specify both element_{name} and tensor {name}')
|
| 242 |
+
if lay is not None and tens is not None:
|
| 243 |
+
raise Exception(f'Must not specify both layout_{name} and tensor {name}')
|
| 244 |
+
if elt is None and tens is None and element is None:
|
| 245 |
+
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
|
| 246 |
+
if lay is None and tens is None and layout is None:
|
| 247 |
+
raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
|
| 248 |
+
|
| 249 |
+
elt_to_set = None
|
| 250 |
+
lay_to_set = None
|
| 251 |
+
if tens is not None:
|
| 252 |
+
elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
|
| 253 |
+
else:
|
| 254 |
+
elt_to_set = elt if elt is not None else element
|
| 255 |
+
lay_to_set = lay if lay is not None else layout
|
| 256 |
+
|
| 257 |
+
elements.append(datatypes.library_type(elt_to_set))
|
| 258 |
+
layouts.append(lay_to_set)
|
| 259 |
+
|
| 260 |
+
self._element_a, self._element_b, self._element_c, self._element_d = elements
|
| 261 |
+
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
|
| 262 |
+
|
| 263 |
+
if element_accumulator is None:
|
| 264 |
+
self._element_accumulator = self._element_c
|
| 265 |
+
else:
|
| 266 |
+
self._element_accumulator = datatypes.library_type(element_accumulator)
|
| 267 |
+
|
| 268 |
+
self.A = A
|
| 269 |
+
self.B = B
|
| 270 |
+
self.C = C
|
| 271 |
+
self.D = D
|
| 272 |
+
|
| 273 |
+
self.alpha = alpha
|
| 274 |
+
self.beta = beta
|
| 275 |
+
|
| 276 |
+
self.epilogue_functor = None
|
| 277 |
+
self.op_class = None
|
| 278 |
+
self._tile_description = None
|
| 279 |
+
|
| 280 |
+
self._reset_operations()
|
| 281 |
+
|
| 282 |
+
self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
|
| 283 |
+
|
| 284 |
+
def _reset_operations(self, reset_epilogue: bool = True):
|
| 285 |
+
# Set the default op class
|
| 286 |
+
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
|
| 287 |
+
layout_comb = (self._layout_a, self._layout_b)
|
| 288 |
+
|
| 289 |
+
self.possible_op_classes = self.options.supporting_opclasses(
|
| 290 |
+
self._element_a, self._element_b, self._element_accumulator,
|
| 291 |
+
self._layout_a, self._layout_b, self._math_operation)
|
| 292 |
+
|
| 293 |
+
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
|
| 294 |
+
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
| 295 |
+
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
|
| 296 |
+
self.opclass = cutlass_cppgen.OpcodeClass.Simt
|
| 297 |
+
else:
|
| 298 |
+
if self._math_operation is not None:
|
| 299 |
+
math_op_str = f' and math operation {self._math_operation}'
|
| 300 |
+
else:
|
| 301 |
+
math_op_str = ''
|
| 302 |
+
|
| 303 |
+
raise Exception(f'No kernel configuration found for supported data type and layout '
|
| 304 |
+
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
| 305 |
+
|
| 306 |
+
if reset_epilogue:
|
| 307 |
+
self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
|
| 308 |
+
|
| 309 |
+
@property
|
| 310 |
+
def swizzling_functor(self):
|
| 311 |
+
"""
|
| 312 |
+
Returns the type of the swizzling functor currently being used by the GEMM
|
| 313 |
+
|
| 314 |
+
:return: swizzing functor type
|
| 315 |
+
"""
|
| 316 |
+
return self._swizzling_functor
|
| 317 |
+
|
| 318 |
+
@swizzling_functor.setter
|
| 319 |
+
def swizzling_functor(self, swizzling_functor):
|
| 320 |
+
"""
|
| 321 |
+
Sets the swizzling functor to the type specified by `swizzling_functor`
|
| 322 |
+
"""
|
| 323 |
+
if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
|
| 324 |
+
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
| 325 |
+
raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
|
| 326 |
+
|
| 327 |
+
if self.current_cc in [90, 100, 101, 103]:
|
| 328 |
+
raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')
|
| 329 |
+
self._swizzling_functor = swizzling_functor
|
| 330 |
+
|
| 331 |
+
#
|
| 332 |
+
# Tile description Related
|
| 333 |
+
#
|
| 334 |
+
|
| 335 |
+
@property
|
| 336 |
+
def tile_description(self) -> TileDescription:
|
| 337 |
+
"""
|
| 338 |
+
Returns the tile description
|
| 339 |
+
"""
|
| 340 |
+
return self._tile_description
|
| 341 |
+
|
| 342 |
+
@tile_description.setter
|
| 343 |
+
def tile_description(
|
| 344 |
+
self, td=None):
|
| 345 |
+
"""
|
| 346 |
+
Set the tile description
|
| 347 |
+
|
| 348 |
+
:param td: tile description
|
| 349 |
+
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
|
| 350 |
+
{
|
| 351 |
+
"threadblock_shape": [int, int, int],
|
| 352 |
+
"warp_count": [int, int, int],
|
| 353 |
+
"stages": int,
|
| 354 |
+
"instruction_shape": [int, int, int] (optional),
|
| 355 |
+
"cluster_shape": [int, int, int] (optional)
|
| 356 |
+
}
|
| 357 |
+
"""
|
| 358 |
+
if td is None:
|
| 359 |
+
return
|
| 360 |
+
if isinstance(td, dict):
|
| 361 |
+
if self._tile_description is None:
|
| 362 |
+
op = self.possible_operations.default_operation(self._math_operation)
|
| 363 |
+
self._tile_description = datatypes.td_from_profiler_op(op)
|
| 364 |
+
td = self._tile_description.clone_and_update(td)
|
| 365 |
+
|
| 366 |
+
valid, msg = self._valid_tile_description(td)
|
| 367 |
+
if valid:
|
| 368 |
+
self._tile_description = td
|
| 369 |
+
else:
|
| 370 |
+
raise Exception(msg)
|
| 371 |
+
|
| 372 |
+
def _valid_tile_description(self, td: TileDescription) -> tuple:
|
| 373 |
+
"""
|
| 374 |
+
Checks whether the provided tile description is valid for the given compute capability. At present,
|
| 375 |
+
this checks the following:
|
| 376 |
+
|
| 377 |
+
- Does the tile description use a number of stages supported by the compute capability in question?
|
| 378 |
+
- Does the tile size requested fit within shared memory?
|
| 379 |
+
- Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
|
| 380 |
+
more non-unit cluster dimensions for pre-SM90 architectures)?
|
| 381 |
+
- Is the kernel schedule being used supported on the architecture in question?
|
| 382 |
+
|
| 383 |
+
:param td: tile description to validate
|
| 384 |
+
:type td: cutlass_cppgen.backend.TileDescription
|
| 385 |
+
:return: tuple in which the first element is a bool indicating that the tile description is valid
|
| 386 |
+
and the second element is a string providing an optional error message.
|
| 387 |
+
:rtype: tuple
|
| 388 |
+
"""
|
| 389 |
+
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
|
| 390 |
+
if not valid:
|
| 391 |
+
return (valid, msg)
|
| 392 |
+
|
| 393 |
+
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
|
| 394 |
+
if not valid:
|
| 395 |
+
return (valid, msg)
|
| 396 |
+
|
| 397 |
+
valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
|
| 398 |
+
|
| 399 |
+
if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0:
|
| 400 |
+
valid = False
|
| 401 |
+
msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"
|
| 402 |
+
|
| 403 |
+
return valid, msg
|
| 404 |
+
|
| 405 |
+
def tile_descriptions(self) -> list:
|
| 406 |
+
"""
|
| 407 |
+
Returns a list of valid tile descriptions for the operations
|
| 408 |
+
|
| 409 |
+
:returns: list of valid tile descriptions for the operations
|
| 410 |
+
:rtype: list
|
| 411 |
+
"""
|
| 412 |
+
tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
|
| 413 |
+
if self._math_operation is not None:
|
| 414 |
+
tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation]
|
| 415 |
+
return tds
|
| 416 |
+
|
| 417 |
+
def construct(
|
| 418 |
+
self, tile_description: TileDescription = None,
|
| 419 |
+
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
|
| 420 |
+
"""
|
| 421 |
+
Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
|
| 422 |
+
kernel specification of the ``Gemm`` object.
|
| 423 |
+
|
| 424 |
+
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 425 |
+
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 426 |
+
:param alignment_A: alignment of operand A
|
| 427 |
+
:type alignment_A: int
|
| 428 |
+
:param alignment_B: alignment of operand B
|
| 429 |
+
:type alignment_B: int
|
| 430 |
+
:param alignment_C: alignment of operand C
|
| 431 |
+
:type alignment_C: int
|
| 432 |
+
|
| 433 |
+
:return: operation that was constructed
|
| 434 |
+
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
|
| 435 |
+
"""
|
| 436 |
+
alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
|
| 437 |
+
alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
|
| 438 |
+
alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
|
| 439 |
+
alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
|
| 440 |
+
|
| 441 |
+
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
|
| 442 |
+
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
| 443 |
+
|
| 444 |
+
if alignment_C is None:
|
| 445 |
+
alignment_C = max(self.possible_operations.alignments("C"))
|
| 446 |
+
if self._element_c != DataType.void:
|
| 447 |
+
alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
|
| 448 |
+
|
| 449 |
+
if tile_description is None:
|
| 450 |
+
if self._tile_description is None:
|
| 451 |
+
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
| 452 |
+
tile_description = datatypes.td_from_profiler_op(op)
|
| 453 |
+
|
| 454 |
+
# The selected op may have lower alignment than that determined above, so we must
|
| 455 |
+
# reset alignment here.
|
| 456 |
+
alignment_C = op.C.alignment
|
| 457 |
+
else:
|
| 458 |
+
tile_description = self._tile_description
|
| 459 |
+
else:
|
| 460 |
+
valid, err_str = self._valid_tile_description(tile_description)
|
| 461 |
+
if not valid:
|
| 462 |
+
raise Exception(f"Invalid tile description. {err_str}")
|
| 463 |
+
self._tile_description = tile_description
|
| 464 |
+
|
| 465 |
+
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
| 466 |
+
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
| 467 |
+
|
| 468 |
+
operation = GemmOperationUniversal(
|
| 469 |
+
arch=self.current_cc,
|
| 470 |
+
tile_description=tile_description,
|
| 471 |
+
A=tensor_A, B=tensor_B, C=tensor_C,
|
| 472 |
+
epilogue_functor=self.epilogue_functor,
|
| 473 |
+
swizzling_functor=self._swizzling_functor,
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
return operation
|
| 477 |
+
|
| 478 |
+
def compile(self, tile_description: TileDescription = None,
|
| 479 |
+
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
| 480 |
+
print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
|
| 481 |
+
"""
|
| 482 |
+
Emits and compiles the kernel currently specified. If ``tile_description`` and any
|
| 483 |
+
of the ``alignment`` parameters are set, the kernel will be chosen using this
|
| 484 |
+
tile description and alignments. Otherwise, a default tile description and alignment
|
| 485 |
+
will be used.
|
| 486 |
+
|
| 487 |
+
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 488 |
+
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 489 |
+
:param alignment_A: alignment of operand A
|
| 490 |
+
:type alignment_A: int
|
| 491 |
+
:param alignment_B: alignment of operand B
|
| 492 |
+
:type alignment_B: int
|
| 493 |
+
:param alignment_C: alignment of operand C
|
| 494 |
+
:type alignment_C: int
|
| 495 |
+
:param print_module: whether to print the emitted C++ code
|
| 496 |
+
:type print_module: bool
|
| 497 |
+
|
| 498 |
+
:return: operation that was compiled
|
| 499 |
+
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
|
| 500 |
+
"""
|
| 501 |
+
self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
|
| 502 |
+
|
| 503 |
+
if print_module:
|
| 504 |
+
print(self.operation.rt_module.emit())
|
| 505 |
+
|
| 506 |
+
compiler.add_module([self.operation,])
|
| 507 |
+
return self.operation
|
| 508 |
+
|
| 509 |
+
def _verify_rank(self, tensor):
|
| 510 |
+
"""
|
| 511 |
+
Verifies that ``tensor`` has rank greater than 1
|
| 512 |
+
|
| 513 |
+
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 514 |
+
:type tensor: numpy/cupy/torch array/tensor object
|
| 515 |
+
"""
|
| 516 |
+
if len(tensor.shape) < 2:
|
| 517 |
+
raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
|
| 518 |
+
|
| 519 |
+
def _get_batch_count(self, A, B, C, D) -> int:
|
| 520 |
+
"""
|
| 521 |
+
Returns the batch count specified by the tensors A, B, C, and D and verifies that these
|
| 522 |
+
tensors match in batch size. Presence of a batch dimension is detected by one of the
|
| 523 |
+
tensors being rank 3. If a batch dimension is present, it must be present in one of
|
| 524 |
+
operands A, B, or C (but need not be in all), and must be present in D.
|
| 525 |
+
|
| 526 |
+
:param A: tensor A
|
| 527 |
+
:type A: numpy/cupy/torch array/tensor object
|
| 528 |
+
:param B: tensor B
|
| 529 |
+
:type B: numpy/cupy/torch array/tensor object
|
| 530 |
+
:param C: tensor C
|
| 531 |
+
:type C: numpy/cupy/torch array/tensor object
|
| 532 |
+
:param D: tensor D
|
| 533 |
+
:type D: numpy/cupy/torch array/tensor object
|
| 534 |
+
|
| 535 |
+
:return: tuple of batch count dimensions
|
| 536 |
+
:rtype: tuple
|
| 537 |
+
"""
|
| 538 |
+
A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
|
| 539 |
+
B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1
|
| 540 |
+
|
| 541 |
+
if 1 not in [A_batch, B_batch]:
|
| 542 |
+
if A_batch != B_batch:
|
| 543 |
+
raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
|
| 544 |
+
return max(A_batch, B_batch)
|
| 545 |
+
|
| 546 |
+
def _get_batch_stride(self, tensor) -> int:
|
| 547 |
+
"""
|
| 548 |
+
Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
|
| 549 |
+
|
| 550 |
+
:param tensor: tensor object to process
|
| 551 |
+
:type tensor: numpy/cupy/torch array/tensor object
|
| 552 |
+
|
| 553 |
+
:return: stride between each matrix in the batch
|
| 554 |
+
:rtype: int
|
| 555 |
+
"""
|
| 556 |
+
if tensor is not None and len(tensor.shape) > 2:
|
| 557 |
+
return tensor.shape[-2] * tensor.shape[-1]
|
| 558 |
+
else:
|
| 559 |
+
return 0
|
| 560 |
+
|
| 561 |
+
def _get_problem_args(self, A, B, C, D) -> tuple:
|
| 562 |
+
"""
|
| 563 |
+
Returns the problem size and GEMM universal mode to use for the
|
| 564 |
+
given operands.
|
| 565 |
+
|
| 566 |
+
:param A: tensor A
|
| 567 |
+
:type A: numpy/cupy/torch array/tensor object
|
| 568 |
+
:param B: tensor B
|
| 569 |
+
:type B: numpy/cupy/torch array/tensor object
|
| 570 |
+
:param C: tensor C
|
| 571 |
+
:type C: numpy/cupy/torch array/tensor object
|
| 572 |
+
:param D: tensor D
|
| 573 |
+
:type D: numpy/cupy/torch array/tensor object
|
| 574 |
+
|
| 575 |
+
:return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
|
| 576 |
+
:rtype: tuple
|
| 577 |
+
"""
|
| 578 |
+
M, K = A.shape[-2:]
|
| 579 |
+
N = B.shape[-1]
|
| 580 |
+
mode = GemmUniversalMode.Gemm
|
| 581 |
+
|
| 582 |
+
batch_count = self._get_batch_count(A, B, C, D)
|
| 583 |
+
returned_batch_count = batch_count
|
| 584 |
+
|
| 585 |
+
# If we are running a batched GEMM in which there is a nonzero batch stride
|
| 586 |
+
# only for A, then we can fold the batched dimension of A into the M dimension
|
| 587 |
+
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
|
| 588 |
+
# and C are row major. A similar operation can be performed if only B has a nonzero
|
| 589 |
+
# batch dimension
|
| 590 |
+
if batch_count > 1:
|
| 591 |
+
A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
|
| 592 |
+
B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
|
| 593 |
+
C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor
|
| 594 |
+
|
| 595 |
+
# Consider a Tensor to be batched if its rank is > 2 and
|
| 596 |
+
# the product of the modes beyond rank 2 equals our pre-determined batch size.
|
| 597 |
+
batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)
|
| 598 |
+
|
| 599 |
+
if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
|
| 600 |
+
M *= batch_count
|
| 601 |
+
returned_batch_count = 1
|
| 602 |
+
elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
|
| 603 |
+
N *= batch_count
|
| 604 |
+
returned_batch_count = 1
|
| 605 |
+
else:
|
| 606 |
+
mode = GemmUniversalMode.Batched
|
| 607 |
+
|
| 608 |
+
return GemmCoord(M, N, K), mode, returned_batch_count
|
| 609 |
+
|
| 610 |
+
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
|
| 611 |
+
"""
|
| 612 |
+
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
|
| 613 |
+
is raised if it does not.
|
| 614 |
+
|
| 615 |
+
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 616 |
+
:type tensor: numpy/cupy/torch array/tensor object
|
| 617 |
+
:param ref_dtype: data type for the tensor that this object was initialized to
|
| 618 |
+
:param ref_layout: layout for the tensor that this object was initialized to
|
| 619 |
+
:param name: identifier of the tensor to verify. Used in raising exceptions
|
| 620 |
+
:type name: str
|
| 621 |
+
"""
|
| 622 |
+
dtype, layout = datatypes.get_datatype_and_layout(tensor)
|
| 623 |
+
if dtype != ref_type or layout != ref_layout:
|
| 624 |
+
try:
|
| 625 |
+
# Attempt to transpose the tensor to fit the desired layout
|
| 626 |
+
tensor = tensor.transpose(-1, -2)
|
| 627 |
+
except:
|
| 628 |
+
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
|
| 629 |
+
f'does not match the expected type and '
|
| 630 |
+
f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
|
| 631 |
+
|
| 632 |
+
def run(self, A=None, B=None, C=None, D=None,
|
| 633 |
+
alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
|
| 634 |
+
stream: Optional[cuda.CUstream] = None) -> GemmArguments:
|
| 635 |
+
"""
|
| 636 |
+
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
|
| 637 |
+
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
|
| 638 |
+
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
|
| 639 |
+
parameters provided in this call, or from those
|
| 640 |
+
passed in on the construction of this object -- one of the two must be specified.
|
| 641 |
+
|
| 642 |
+
By default, this call returns only once the kernel has completed. To launch the kernel
|
| 643 |
+
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
|
| 644 |
+
caller to syncrhonize the results of the kernel before attempting to access outputs
|
| 645 |
+
by calling ``sync()`` on the arguments returned from this call.
|
| 646 |
+
|
| 647 |
+
:param A: tensor representing data type and layout of operand A
|
| 648 |
+
:param B: tensor representing data type and layout of operand B
|
| 649 |
+
:param C: tensor representing data type and layout of operand C
|
| 650 |
+
:param D: tensor representing data type and layout of operand D
|
| 651 |
+
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 652 |
+
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 653 |
+
:param sync: whether the call should wait for the kernel to complete before returning
|
| 654 |
+
:type sync: bool
|
| 655 |
+
:param print_module: whether to print the emitted C++ code
|
| 656 |
+
:type print_module: bool
|
| 657 |
+
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 658 |
+
:type stream: :class:`cuda.cuda.CUstream`
|
| 659 |
+
|
| 660 |
+
:return: arguments passed in to the kernel
|
| 661 |
+
:rtype: cutlass_cppgen.backend.GemmArguments
|
| 662 |
+
"""
|
| 663 |
+
if not stream:
|
| 664 |
+
stream = cuda.CUstream(0)
|
| 665 |
+
super().run_setup()
|
| 666 |
+
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
|
| 667 |
+
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
|
| 668 |
+
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
|
| 669 |
+
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
|
| 670 |
+
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
| 671 |
+
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
| 672 |
+
|
| 673 |
+
is_void_c = self._element_c == DataType.void
|
| 674 |
+
|
| 675 |
+
self._verify_rank(A)
|
| 676 |
+
self._verify_rank(B)
|
| 677 |
+
if not is_void_c:
|
| 678 |
+
self._verify_rank(C)
|
| 679 |
+
self._verify_rank(D)
|
| 680 |
+
|
| 681 |
+
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
|
| 682 |
+
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")
|
| 683 |
+
|
| 684 |
+
# Set C alignment based on D.shape so as to correctly get an alignment with void-C
|
| 685 |
+
# kernels, for which `C` is None.
|
| 686 |
+
alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
|
| 687 |
+
self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
| 688 |
+
alignment_C=alignment_c, print_module=print_module)
|
| 689 |
+
|
| 690 |
+
problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
|
| 691 |
+
|
| 692 |
+
if mode == GemmUniversalMode.Gemm or batch_count == 1:
|
| 693 |
+
kwargs = {'split_k_slices': 1}
|
| 694 |
+
else:
|
| 695 |
+
kwargs = {
|
| 696 |
+
'batch': batch_count,
|
| 697 |
+
'batch_strides': {
|
| 698 |
+
'A': self._get_batch_stride(A),
|
| 699 |
+
'B': self._get_batch_stride(B),
|
| 700 |
+
'C': self._get_batch_stride(C),
|
| 701 |
+
'D': self._get_batch_stride(D)
|
| 702 |
+
}
|
| 703 |
+
}
|
| 704 |
+
|
| 705 |
+
kwargs['stream'] = stream
|
| 706 |
+
|
| 707 |
+
if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
|
| 708 |
+
output_op = self.operation.epilogue_type(visitor_args)
|
| 709 |
+
else:
|
| 710 |
+
output_op = self.operation.epilogue_type(alpha, beta)
|
| 711 |
+
|
| 712 |
+
arguments = GemmArguments(
|
| 713 |
+
operation=self.operation, problem_size=problem_size,
|
| 714 |
+
A=A, B=B, C=C, D=D,
|
| 715 |
+
output_op=output_op,
|
| 716 |
+
gemm_mode=mode,
|
| 717 |
+
**kwargs
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
self.operation.run(arguments)
|
| 721 |
+
|
| 722 |
+
if sync:
|
| 723 |
+
arguments.sync()
|
| 724 |
+
|
| 725 |
+
return arguments
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Ease-of-use interface for constructing, compiling, and running GEMMs.
|
| 35 |
+
|
| 36 |
+
The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
|
| 37 |
+
grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
|
| 38 |
+
Under the hood, the interface will select sensible default parameters for the many template
|
| 39 |
+
parameters for CUTLASS grouped GEMMs.
|
| 40 |
+
|
| 41 |
+
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
| 42 |
+
performance, one should specify and tune each configuration parameter.
|
| 43 |
+
|
| 44 |
+
The simplest example of using this interface is the following:
|
| 45 |
+
|
| 46 |
+
.. highlight:: python
|
| 47 |
+
.. code-block:: python
|
| 48 |
+
|
| 49 |
+
# As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
|
| 50 |
+
plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 51 |
+
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
|
| 52 |
+
"""
|
| 53 |
+
from __future__ import annotations
|
| 54 |
+
from typing import Optional
|
| 55 |
+
from cutlass_library import DataTypeSize
|
| 56 |
+
|
| 57 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 58 |
+
cuda = lazy_import("cuda.cuda")
|
| 59 |
+
from cutlass_cppgen.backend.gemm_operation import (
|
| 60 |
+
GemmGroupedArguments,
|
| 61 |
+
GemmOperationGrouped,
|
| 62 |
+
)
|
| 63 |
+
from cutlass_cppgen.backend.library import (
|
| 64 |
+
SchedulerMode,
|
| 65 |
+
TensorDescription,
|
| 66 |
+
TileDescription,
|
| 67 |
+
)
|
| 68 |
+
from cutlass_cppgen.op.gemm import Gemm
|
| 69 |
+
from cutlass_cppgen.shape import GemmCoord
|
| 70 |
+
from cutlass_cppgen.utils import check, datatypes
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class GroupedGemm(Gemm):
|
| 74 |
+
"""
|
| 75 |
+
Constructs a ``GroupedGemm`` object.
|
| 76 |
+
|
| 77 |
+
The data types and layouts of operands A, B, and C, along with the data type of output D
|
| 78 |
+
and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
|
| 79 |
+
these are not to be changed after a ``GroupedGemm`` has been constructed.
|
| 80 |
+
|
| 81 |
+
The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
|
| 82 |
+
for ``Gemm`` for examples of these.
|
| 83 |
+
|
| 84 |
+
:param cc: compute capability of device to generate kernels for
|
| 85 |
+
:type cc: int
|
| 86 |
+
:param A: tensor representing data type and layout of operands A
|
| 87 |
+
:param B: tensor representing data type and layout of operands B
|
| 88 |
+
:param C: tensor representing data type and layout of operands C
|
| 89 |
+
:param D: tensor representing data type and layout of operands D
|
| 90 |
+
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 91 |
+
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 92 |
+
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
| 93 |
+
:type element_accumulator: cutlass_cppgen.DataType
|
| 94 |
+
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
| 95 |
+
:type element: cutlass_cppgen.DataType
|
| 96 |
+
:param layout: generic layout type to be used for operands A, B, C, and D
|
| 97 |
+
:type layout: cutlass_cppgen.LayoutType
|
| 98 |
+
:param element_A: data type to be used for operand A
|
| 99 |
+
:type element_A: cutlass_cppgen.DataType
|
| 100 |
+
:param element_B: data type to be used for operand B
|
| 101 |
+
:type element_B: cutlass_cppgen.DataType
|
| 102 |
+
:param element_C: data type to be used for operand C
|
| 103 |
+
:type element_C: cutlass_cppgen.DataType
|
| 104 |
+
:param element_D: data type to be used for operand D
|
| 105 |
+
:type element_D: cutlass_cppgen.DataType
|
| 106 |
+
:type layout_A: layout of operand A
|
| 107 |
+
:param layout_A: cutlass_cppgen.LayoutType
|
| 108 |
+
:type layout_B: layout of operand B
|
| 109 |
+
:param layout_B: cutlass_cppgen.LayoutType
|
| 110 |
+
:type layout_C: layout of operand C
|
| 111 |
+
:param layout_C: cutlass_cppgen.LayoutType
|
| 112 |
+
:type layout_D: layout of operand D
|
| 113 |
+
:param layout_D: cutlass_cppgen.LayoutType
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
def __init__(
|
| 117 |
+
self, A=None, B=None, C=None, D=None,
|
| 118 |
+
alpha=1.0, beta=0.0, element_accumulator=None,
|
| 119 |
+
element=None, layout=None,
|
| 120 |
+
element_A=None, element_B=None, element_C=None, element_D=None,
|
| 121 |
+
layout_A=None, layout_B=None, layout_C=None,
|
| 122 |
+
cc: int = None,
|
| 123 |
+
):
|
| 124 |
+
super().__init__(
|
| 125 |
+
A=A, B=B, C=C, D=D,
|
| 126 |
+
alpha=alpha, beta=beta,
|
| 127 |
+
element_accumulator=element_accumulator,
|
| 128 |
+
element=element, layout=layout,
|
| 129 |
+
element_A=element_A, element_B=element_B,
|
| 130 |
+
element_C=element_C, element_D=element_D,
|
| 131 |
+
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
|
| 132 |
+
cc=cc
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
|
| 136 |
+
if self.current_cc in [90, 100, 101, 103]:
|
| 137 |
+
self._reset_options(80)
|
| 138 |
+
self._reset_operations(reset_epilogue=False)
|
| 139 |
+
|
| 140 |
+
self.name = "grouped_gemm"
|
| 141 |
+
|
| 142 |
+
@Gemm.swizzling_functor.setter
|
| 143 |
+
def swizzling_functor(self, swizzling_functor):
|
| 144 |
+
"""
|
| 145 |
+
Sets the swizzling functor to the type specified by `swizzling_functor`
|
| 146 |
+
"""
|
| 147 |
+
raise Exception('Grouped GEMM does not currently support different swizzling functors')
|
| 148 |
+
|
| 149 |
+
def construct(self, tile_description: TileDescription = None,
|
| 150 |
+
alignment_A: int = None,
|
| 151 |
+
alignment_B: int = None,
|
| 152 |
+
alignment_C: int = None) -> GemmOperationGrouped:
|
| 153 |
+
"""
|
| 154 |
+
Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
|
| 155 |
+
kernel specification of the ``Gemm`` object.
|
| 156 |
+
|
| 157 |
+
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 158 |
+
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 159 |
+
:param alignment_A: alignment of operand A
|
| 160 |
+
:type alignment_A: int
|
| 161 |
+
:param alignment_B: alignment of operand B
|
| 162 |
+
:type alignment_B: int
|
| 163 |
+
:param alignment_C: alignment of operand C
|
| 164 |
+
:type alignment_C: int
|
| 165 |
+
|
| 166 |
+
:return: operation that was constructed
|
| 167 |
+
:rtype: cutlass_cppgen.backend.GemmOperationGrouped
|
| 168 |
+
"""
|
| 169 |
+
alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
|
| 170 |
+
alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
|
| 171 |
+
alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))
|
| 172 |
+
|
| 173 |
+
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
| 174 |
+
|
| 175 |
+
tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
|
| 176 |
+
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
| 177 |
+
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
| 178 |
+
|
| 179 |
+
if tile_description is None:
|
| 180 |
+
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
| 181 |
+
tile_description = datatypes.td_from_profiler_op(op)
|
| 182 |
+
else:
|
| 183 |
+
valid, err_str = self._valid_tile_description(tile_description)
|
| 184 |
+
if not valid:
|
| 185 |
+
raise Exception(f"Invalid tile description. {err_str}")
|
| 186 |
+
self.tile_description = tile_description
|
| 187 |
+
|
| 188 |
+
operation = GemmOperationGrouped(
|
| 189 |
+
arch=self.current_cc,
|
| 190 |
+
tile_description=tile_description,
|
| 191 |
+
A=tensor_A, B=tensor_B, C=tensor_C,
|
| 192 |
+
epilogue_functor=self.epilogue_functor,
|
| 193 |
+
swizzling_functor=self._swizzling_functor,
|
| 194 |
+
precompute_mode=SchedulerMode.Device)
|
| 195 |
+
|
| 196 |
+
return operation
|
| 197 |
+
|
| 198 |
+
def run(self, A, B, C, D,
|
| 199 |
+
alpha=None, beta=None, sync: bool = True,
|
| 200 |
+
print_module: bool = False,
|
| 201 |
+
stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
|
| 202 |
+
"""
|
| 203 |
+
Runs the kernel currently specified.
|
| 204 |
+
|
| 205 |
+
By default, this call returns only once the kernel has completed. To launch the kernel
|
| 206 |
+
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
|
| 207 |
+
caller to syncrhonize the results of the kernel before attempting to access outputs
|
| 208 |
+
by calling ``sync()`` on the arguments returned from this call.
|
| 209 |
+
|
| 210 |
+
:param A: list of tensors representing data type and layout of operand A
|
| 211 |
+
:type A: list
|
| 212 |
+
:param B: list of tensors representing data type and layout of operand B
|
| 213 |
+
:type B: list
|
| 214 |
+
:param C: list of tensors representing data type and layout of operand C
|
| 215 |
+
:type C: list
|
| 216 |
+
:param D: list of tensors representing data type and layout of operand D
|
| 217 |
+
:type D: list
|
| 218 |
+
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 219 |
+
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 220 |
+
:param sync: whether the call should wait for the kernel to complete before returning
|
| 221 |
+
:type sync: bool
|
| 222 |
+
:param print_module: whether to print the emitted C++ code
|
| 223 |
+
:type print_module: bool
|
| 224 |
+
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 225 |
+
:type stream: :class:`cuda.cuda.CUstream`
|
| 226 |
+
|
| 227 |
+
:return: arguments passed in to the kernel
|
| 228 |
+
:rtype: cutlass_cppgen.backend.GemmGroupedArguments
|
| 229 |
+
"""
|
| 230 |
+
if not stream:
|
| 231 |
+
stream = cuda.CUstream(0)
|
| 232 |
+
|
| 233 |
+
super().run_setup()
|
| 234 |
+
|
| 235 |
+
if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
|
| 236 |
+
raise Exception("Lengths of A, B, C, and D lists must be equal")
|
| 237 |
+
|
| 238 |
+
problem_sizes = []
|
| 239 |
+
As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
|
| 240 |
+
for i in range(len(A)):
|
| 241 |
+
As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
|
| 242 |
+
Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
|
| 243 |
+
Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
|
| 244 |
+
Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
|
| 245 |
+
problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
|
| 246 |
+
|
| 247 |
+
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
| 248 |
+
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
| 249 |
+
|
| 250 |
+
alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
|
| 251 |
+
alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
|
| 252 |
+
alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
|
| 253 |
+
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
| 254 |
+
alignment_C=alignment_c, print_module=print_module)
|
| 255 |
+
|
| 256 |
+
arguments = GemmGroupedArguments(
|
| 257 |
+
operation=self.operation,
|
| 258 |
+
problem_sizes=problem_sizes,
|
| 259 |
+
A=As, B=Bs, C=Cs, D=Ds,
|
| 260 |
+
output_op=self.operation.epilogue_type(alpha, beta),
|
| 261 |
+
stream=stream
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
self.operation.run(arguments)
|
| 265 |
+
|
| 266 |
+
if sync:
|
| 267 |
+
arguments.sync()
|
| 268 |
+
|
| 269 |
+
return arguments
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from bisect import bisect_left
|
| 38 |
+
|
| 39 |
+
from cutlass_library import (
|
| 40 |
+
DataType,
|
| 41 |
+
DataTypeSize,
|
| 42 |
+
MathOperation,
|
| 43 |
+
OperationKind,
|
| 44 |
+
SharedMemPerCC
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
import cutlass_cppgen
|
| 48 |
+
from cutlass_cppgen import get_option_registry
|
| 49 |
+
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
| 50 |
+
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 51 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 52 |
+
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
|
| 53 |
+
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
|
| 54 |
+
from cutlass_cppgen.swizzle import get_swizzling_functors
|
| 55 |
+
from cutlass_cppgen.utils import datatypes, check
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class OperationBase:
|
| 59 |
+
"""
|
| 60 |
+
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
|
| 64 |
+
"""
|
| 65 |
+
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
| 66 |
+
:type cc: int
|
| 67 |
+
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
| 68 |
+
:type kernel_cc: int
|
| 69 |
+
:param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
|
| 70 |
+
:type operation_kind: cutlass_library.OperationKind
|
| 71 |
+
"""
|
| 72 |
+
self.operation_kind = operation_kind
|
| 73 |
+
self.cc = cc if cc is not None else device_cc()
|
| 74 |
+
self.specified_kernel_cc = kernel_cc is not None
|
| 75 |
+
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
|
| 76 |
+
self.tile_description = None
|
| 77 |
+
self._math_operation = None
|
| 78 |
+
|
| 79 |
+
self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)
|
| 80 |
+
|
| 81 |
+
if self.options is None:
|
| 82 |
+
raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
|
| 83 |
+
|
| 84 |
+
# Default activation function: identity
|
| 85 |
+
self._activation = identity
|
| 86 |
+
|
| 87 |
+
def _find_closest_cc(self, cc: int) -> int:
|
| 88 |
+
"""
|
| 89 |
+
Returns the closest CC in _generator_ccs less than or equal to `cc`
|
| 90 |
+
|
| 91 |
+
:param cc: compute capability to query
|
| 92 |
+
:type cc: int
|
| 93 |
+
|
| 94 |
+
:returns: closest CC in _generator_ccs less than or equal to `cc`
|
| 95 |
+
:rtype: int
|
| 96 |
+
"""
|
| 97 |
+
if cc in _generator_ccs:
|
| 98 |
+
return cc
|
| 99 |
+
|
| 100 |
+
# Find closest CC lower than this CC
|
| 101 |
+
idx = bisect_left(_generator_ccs, cc)
|
| 102 |
+
if idx == 0:
|
| 103 |
+
raise Exception(f'No valid CC to fall back to for {cc}')
|
| 104 |
+
return _generator_ccs[idx-1]
|
| 105 |
+
|
| 106 |
+
def activations(self) -> list:
|
| 107 |
+
"""
|
| 108 |
+
Returns possible activation functions that can be used
|
| 109 |
+
|
| 110 |
+
:return: list of activation functions that can be used
|
| 111 |
+
:rtype: list
|
| 112 |
+
"""
|
| 113 |
+
return get_activations()
|
| 114 |
+
|
| 115 |
+
def swizzling_functors(self) -> list:
|
| 116 |
+
"""
|
| 117 |
+
Returns possible swizzling functions that can be used
|
| 118 |
+
|
| 119 |
+
:return: list of swizzling functions that can be used
|
| 120 |
+
:rtype: list
|
| 121 |
+
"""
|
| 122 |
+
return get_swizzling_functors()
|
| 123 |
+
|
| 124 |
+
def _reset_options(self, cc: int):
|
| 125 |
+
"""
|
| 126 |
+
Resets the kernel options based on cc
|
| 127 |
+
|
| 128 |
+
:param cc: compute capability to reset to
|
| 129 |
+
:type cc: int
|
| 130 |
+
"""
|
| 131 |
+
if cc != self.current_cc:
|
| 132 |
+
if cc not in _generator_ccs:
|
| 133 |
+
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
|
| 134 |
+
self.current_cc = cc
|
| 135 |
+
self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)
|
| 136 |
+
|
| 137 |
+
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
|
| 138 |
+
"""
|
| 139 |
+
Verifies the following properties:
|
| 140 |
+
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
|
| 141 |
+
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
|
| 142 |
+
set by the plan (i.e., those in ``ref_dtype``)
|
| 143 |
+
|
| 144 |
+
If either of these properties does not hold, an exception is raised. If these properties hold and
|
| 145 |
+
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
|
| 146 |
+
|
| 147 |
+
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 148 |
+
:type scalar: numpy/cupy/torch scalar
|
| 149 |
+
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
| 150 |
+
:type ref_scalar: numpy/cupy/torch scalar
|
| 151 |
+
:param ref_dtype: data type for the scalar that this object was initialized to
|
| 152 |
+
:param name: identifier of the scalar to verify. Used in raising exceptions
|
| 153 |
+
:type name: str
|
| 154 |
+
|
| 155 |
+
:return: valid scalar to use
|
| 156 |
+
:rtype: numpy/cupy/torch scalar
|
| 157 |
+
"""
|
| 158 |
+
if scalar is None:
|
| 159 |
+
if ref_scalar is None:
|
| 160 |
+
raise Exception(f"Scalar {name} must be set.")
|
| 161 |
+
return ref_scalar
|
| 162 |
+
if hasattr(scalar, "dtype"):
|
| 163 |
+
dtype = datatypes.library_type(scalar.dtype)
|
| 164 |
+
if dtype != ref_dtype:
|
| 165 |
+
raise Exception(
|
| 166 |
+
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
|
| 167 |
+
)
|
| 168 |
+
return scalar
|
| 169 |
+
|
| 170 |
+
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
|
| 171 |
+
"""
|
| 172 |
+
Verifies the following properties:
|
| 173 |
+
If ref_dtype is not void:
|
| 174 |
+
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
|
| 175 |
+
2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
|
| 176 |
+
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
|
| 177 |
+
If ref_dtype is void:
|
| 178 |
+
Neither ``tensor`` nor ``ref_tensor`` are set
|
| 179 |
+
|
| 180 |
+
If either of these properties does not hold, an exception is raised. If these properties hold and
|
| 181 |
+
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
|
| 182 |
+
|
| 183 |
+
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 184 |
+
:type tensor: numpy/cupy/torch array/tensor object
|
| 185 |
+
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
| 186 |
+
:type ref_tensor: numpy/cupy/torch array/tensor object
|
| 187 |
+
:param ref_dtype: data type for the tensor that this object was initialized to
|
| 188 |
+
:param ref_layout: layout for the tensor that this object was initialized to
|
| 189 |
+
:param name: identifier of the tensor to verify. Used in raising exceptions
|
| 190 |
+
:type name: str
|
| 191 |
+
|
| 192 |
+
:return: valid tensor object to use
|
| 193 |
+
:rtype: numpy/cupy/torch array/tensor object
|
| 194 |
+
"""
|
| 195 |
+
if ref_dtype == DataType.void:
|
| 196 |
+
if tensor is not None or ref_tensor is not None:
|
| 197 |
+
raise Exception("Operands with element DataType.void must not be provided a tensor")
|
| 198 |
+
return None
|
| 199 |
+
|
| 200 |
+
if tensor is None:
|
| 201 |
+
if ref_tensor is None:
|
| 202 |
+
raise Exception(f"Tensor {name} must be set.")
|
| 203 |
+
return ref_tensor
|
| 204 |
+
|
| 205 |
+
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
|
| 206 |
+
return tensor
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def opclass(self) -> cutlass_cppgen.OpcodeClass:
|
| 210 |
+
"""
|
| 211 |
+
Returns the opcode class currently in use
|
| 212 |
+
|
| 213 |
+
:return: opcode class currently in use
|
| 214 |
+
:rtype: cutlass_cppgen.OpcodeClass
|
| 215 |
+
"""
|
| 216 |
+
return self.op_class
|
| 217 |
+
|
| 218 |
+
@opclass.setter
|
| 219 |
+
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
|
| 220 |
+
if isinstance(oc, str):
|
| 221 |
+
oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
|
| 222 |
+
if oc in self.possible_op_classes:
|
| 223 |
+
self.op_class = oc
|
| 224 |
+
else:
|
| 225 |
+
raise Exception(
|
| 226 |
+
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
|
| 227 |
+
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
|
| 228 |
+
f'layout combination ({self._layout_a}, {self._layout_b}).')
|
| 229 |
+
|
| 230 |
+
# Changing the op class also changes the possible operations available. Reset these.
|
| 231 |
+
self.possible_operations = self.options.operations(
|
| 232 |
+
self.op_class, self._element_a, self._element_b,
|
| 233 |
+
self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
|
| 234 |
+
|
| 235 |
+
# Changing the op class changes the elements per access in the epilogue. Reset this.
|
| 236 |
+
if self.epilogue_functor is not None:
|
| 237 |
+
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def math_operation(self) -> cutlass_cppgen.MathOperation:
|
| 241 |
+
"""
|
| 242 |
+
Returns the math operation currently in use
|
| 243 |
+
|
| 244 |
+
:return: math operation currently in use
|
| 245 |
+
:rtype: cutlass_cppgen.MathOperation
|
| 246 |
+
"""
|
| 247 |
+
return self._math_operation
|
| 248 |
+
|
| 249 |
+
@math_operation.setter
|
| 250 |
+
def math_operation(self, mo: cutlass_cppgen.MathOperation):
|
| 251 |
+
if isinstance(mo, str):
|
| 252 |
+
mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
|
| 253 |
+
|
| 254 |
+
if not self.specified_kernel_cc:
|
| 255 |
+
if self.current_cc in [90, 100, 101, 103]:
|
| 256 |
+
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
|
| 257 |
+
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
| 258 |
+
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
| 259 |
+
self._reset_options(80)
|
| 260 |
+
self._reset_operations(reset_epilogue=False)
|
| 261 |
+
elif self.current_cc in [90, 100, 101, 103]:
|
| 262 |
+
raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
|
| 263 |
+
"To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
|
| 264 |
+
"parameter when constructing the plan.")
|
| 265 |
+
|
| 266 |
+
self._math_operation = mo
|
| 267 |
+
self._reset_operations()
|
| 268 |
+
|
| 269 |
+
def _elements_per_access(self):
|
| 270 |
+
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
| 271 |
+
return 1
|
| 272 |
+
elif self._element_c != DataType.void:
|
| 273 |
+
return 128 // DataTypeSize[self._element_c]
|
| 274 |
+
else:
|
| 275 |
+
return 128 // max(self.possible_operations.alignments("C"))
|
| 276 |
+
|
| 277 |
+
def _create_epilogue_functor_activation(self, activation):
|
| 278 |
+
"""
|
| 279 |
+
Returns the epilogue functor with given activation function
|
| 280 |
+
"""
|
| 281 |
+
if self.epilogue_functor is None:
|
| 282 |
+
elements_per_access = self._elements_per_access()
|
| 283 |
+
else:
|
| 284 |
+
elements_per_access = self.epilogue_functor.epilogue_vector_length
|
| 285 |
+
|
| 286 |
+
if not self.specified_kernel_cc:
|
| 287 |
+
if self.current_cc in [90, 100, 101, 103] and activation != identity:
|
| 288 |
+
# CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
|
| 289 |
+
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
| 290 |
+
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
| 291 |
+
if self._element_c != self._element_d:
|
| 292 |
+
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
|
| 293 |
+
self._reset_options(80)
|
| 294 |
+
self._reset_operations(reset_epilogue=False)
|
| 295 |
+
elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None):
|
| 296 |
+
# SM80 fallback kernels are currently used. Since an identity activation is requested,
|
| 297 |
+
# we can switch back to using SM90 kernels.
|
| 298 |
+
self._reset_options(self.cc)
|
| 299 |
+
self._reset_operations(reset_epilogue=False)
|
| 300 |
+
else:
|
| 301 |
+
if self.current_cc in [90, 100, 101, 103] and activation != identity:
|
| 302 |
+
raise Exception("Epilogues with elementwise fusion are not currently supported "
|
| 303 |
+
"in the Python interface for 3.x kernels. To use 2.x kernels "
|
| 304 |
+
"with fused elementwise epilogues, do not set the `kernel_cc` "
|
| 305 |
+
"parameter when constructing the plan.")
|
| 306 |
+
|
| 307 |
+
return get_activation_epilogue(
|
| 308 |
+
activation,
|
| 309 |
+
self._element_d,
|
| 310 |
+
elements_per_access,
|
| 311 |
+
self._element_accumulator,
|
| 312 |
+
self._element_accumulator,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
def _reset_epilogue_functor_activation(self, activation):
|
| 316 |
+
"""
|
| 317 |
+
Set the epilogue functor based on the provided activation function
|
| 318 |
+
"""
|
| 319 |
+
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
|
| 320 |
+
|
| 321 |
+
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
|
| 322 |
+
"""
|
| 323 |
+
Reset the alignment of the current epilogue functor based on alignment C
|
| 324 |
+
"""
|
| 325 |
+
if isinstance(epilogue_functor, EpilogueFunctorVisitor):
|
| 326 |
+
return epilogue_functor
|
| 327 |
+
|
| 328 |
+
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
|
| 329 |
+
# Identity epilogue does not have 'activation_functor'
|
| 330 |
+
activation = identity
|
| 331 |
+
else:
|
| 332 |
+
activation = epilogue_functor.activation_functor
|
| 333 |
+
|
| 334 |
+
epilogue_functor = get_activation_epilogue(
|
| 335 |
+
activation,
|
| 336 |
+
self._element_d,
|
| 337 |
+
alignment,
|
| 338 |
+
self._element_accumulator,
|
| 339 |
+
self._element_accumulator,
|
| 340 |
+
)
|
| 341 |
+
return epilogue_functor
|
| 342 |
+
|
| 343 |
+
@property
|
| 344 |
+
def activation(self):
|
| 345 |
+
"""
|
| 346 |
+
Returns the type of the current activation function used
|
| 347 |
+
"""
|
| 348 |
+
if hasattr(self.epilogue_functor, "activation_functor"):
|
| 349 |
+
return self.epilogue_functor.activation_functor
|
| 350 |
+
else:
|
| 351 |
+
return identity
|
| 352 |
+
|
| 353 |
+
@activation.setter
|
| 354 |
+
def activation(self, act):
|
| 355 |
+
"""
|
| 356 |
+
Sets the type of the activation function to use
|
| 357 |
+
Activation can come with a set of arguments
|
| 358 |
+
|
| 359 |
+
:param act: type of activation function to use
|
| 360 |
+
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
|
| 361 |
+
|
| 362 |
+
"""
|
| 363 |
+
if isinstance(act, tuple):
|
| 364 |
+
if isinstance(act[0], str):
|
| 365 |
+
act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
|
| 366 |
+
else:
|
| 367 |
+
act_fn = act[0]
|
| 368 |
+
self._reset_epilogue_functor_activation(act_fn)
|
| 369 |
+
self._activation_args = act[1]
|
| 370 |
+
self._activation = act[0]
|
| 371 |
+
else:
|
| 372 |
+
if isinstance(act, str):
|
| 373 |
+
act = getattr(cutlass_cppgen.backend.epilogue, act)
|
| 374 |
+
self._reset_epilogue_functor_activation(act)
|
| 375 |
+
self._activation = act
|
| 376 |
+
|
| 377 |
+
@property
|
| 378 |
+
def epilogue_visitor(self):
|
| 379 |
+
"""
|
| 380 |
+
Return the epilogue functor
|
| 381 |
+
"""
|
| 382 |
+
return self.epilogue_functor
|
| 383 |
+
|
| 384 |
+
@epilogue_visitor.setter
|
| 385 |
+
def epilogue_visitor(self, visitor):
|
| 386 |
+
"""
|
| 387 |
+
Create the epilogue visitor
|
| 388 |
+
"""
|
| 389 |
+
self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)
|
| 390 |
+
|
| 391 |
+
# The epilogue_functor may consume too much shared memory
|
| 392 |
+
# Reset the possible operations
|
| 393 |
+
if self.cc not in [90, 100, 101, 103]:
|
| 394 |
+
# The shared memory is only a concern for sm90+ epilogue
|
| 395 |
+
# In sm80, the epilogue and mainloop share the shared memory
|
| 396 |
+
return
|
| 397 |
+
|
| 398 |
+
datatype_comb = self.possible_operations.datatype_comb
|
| 399 |
+
layout_comb = self.possible_operations.layout_comb
|
| 400 |
+
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
|
| 401 |
+
for operation in self.possible_operations.all_operations:
|
| 402 |
+
td = datatypes.td_from_profiler_op(operation)
|
| 403 |
+
# Filter invalid epilogue schedules
|
| 404 |
+
if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
|
| 405 |
+
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
|
| 406 |
+
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
|
| 407 |
+
continue
|
| 408 |
+
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
|
| 409 |
+
|
| 410 |
+
# Verify the maximum number of mainloop stages
|
| 411 |
+
mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
|
| 412 |
+
smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
|
| 413 |
+
mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
|
| 414 |
+
if mainloop_stages < 2:
|
| 415 |
+
# Mainloop stages must >= 2
|
| 416 |
+
continue
|
| 417 |
+
|
| 418 |
+
new_possible_operations.add(operation)
|
| 419 |
+
if len(new_possible_operations.all_operations) == 0:
|
| 420 |
+
raise RuntimeError(
|
| 421 |
+
"The epilogue consumes too much shared memory. "
|
| 422 |
+
"No valid tile description is found in the generator.")
|
| 423 |
+
self.possible_operations = new_possible_operations
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def run_setup(self):
|
| 427 |
+
"""
|
| 428 |
+
Steps that must be taken before caling `plan.run()`
|
| 429 |
+
"""
|
| 430 |
+
# Initialize the memory pool if, if not already done
|
| 431 |
+
cutlass_cppgen.get_memory_pool()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for expressing shapes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_library import (
|
| 38 |
+
ConvMode,
|
| 39 |
+
ConvKind,
|
| 40 |
+
LayoutType
|
| 41 |
+
)
|
| 42 |
+
from cutlass_cppgen.backend.c_types import (
|
| 43 |
+
Conv2DProblemSize_,
|
| 44 |
+
GemmCoord_,
|
| 45 |
+
GemmCoordBatched_
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class MatrixCoord:
|
| 50 |
+
def __init__(self, row, col):
|
| 51 |
+
self._row = row
|
| 52 |
+
self._col = col
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def row(self):
|
| 56 |
+
return self._row
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def column(self):
|
| 60 |
+
return self._col
|
| 61 |
+
|
| 62 |
+
def leading_dimension(self, layout: LayoutType) -> int:
|
| 63 |
+
"""
|
| 64 |
+
Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.
|
| 65 |
+
|
| 66 |
+
:param layout: layout of matrix
|
| 67 |
+
:type layout: cutlass_library.LayoutType
|
| 68 |
+
|
| 69 |
+
:returns: leading dimension
|
| 70 |
+
:rtype: int
|
| 71 |
+
"""
|
| 72 |
+
if layout == LayoutType.RowMajor:
|
| 73 |
+
return self._col
|
| 74 |
+
elif layout == LayoutType.ColumnMajor:
|
| 75 |
+
return self._row
|
| 76 |
+
else:
|
| 77 |
+
raise Exception(f'Unsupported layout for leading dimension calculation: {layout}')
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class GemmCoord:
|
| 81 |
+
def __init__(self, m: int, n: int, k: int):
|
| 82 |
+
self._m = m
|
| 83 |
+
self._n = n
|
| 84 |
+
self._k = k
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def m(self) -> int:
|
| 88 |
+
return self._m
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def n(self) -> int:
|
| 92 |
+
return self._n
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def k(self) -> int:
|
| 96 |
+
return self._k
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def mk(self) -> MatrixCoord:
|
| 100 |
+
return MatrixCoord(self._m, self._k)
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def mn(self) -> MatrixCoord:
|
| 104 |
+
return MatrixCoord(self._m, self._n)
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def kn(self) -> MatrixCoord:
|
| 108 |
+
return MatrixCoord(self._k, self._n)
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def ctype(self) -> GemmCoord_:
|
| 112 |
+
return GemmCoord_(self._m, self._n, self._k)
|
| 113 |
+
|
| 114 |
+
def batched_ctype(self, batch_count: int) -> GemmCoordBatched_:
|
| 115 |
+
return GemmCoordBatched_(self._m, self._n, self._k, batch_count)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Conv2DProblemSize:
|
| 119 |
+
def __init__(
|
| 120 |
+
self, n: int, h: int, w: int, c: int,
|
| 121 |
+
k: int, r: int, s: int, c_: int,
|
| 122 |
+
pad_h: int, pad_w: int, stride_h: int, stride_w: int,
|
| 123 |
+
dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation,
|
| 124 |
+
split_k_slices: int=1, groups: int=1):
|
| 125 |
+
|
| 126 |
+
self.N = n
|
| 127 |
+
self.H = h
|
| 128 |
+
self.W = w
|
| 129 |
+
self.C = c
|
| 130 |
+
self.K = k
|
| 131 |
+
self.R = r
|
| 132 |
+
self.S = s
|
| 133 |
+
self.pad_h = pad_h
|
| 134 |
+
self.pad_w = pad_w
|
| 135 |
+
self.stride_h = stride_h
|
| 136 |
+
self.stride_w = stride_w
|
| 137 |
+
self.dilation_h = dilation_h
|
| 138 |
+
self.dilation_w = dilation_w
|
| 139 |
+
self.mode = int(mode)
|
| 140 |
+
self.split_k_slices = split_k_slices
|
| 141 |
+
self.groups = groups
|
| 142 |
+
self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1
|
| 143 |
+
self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
def ctype(self) -> Conv2DProblemSize_:
|
| 147 |
+
return Conv2DProblemSize_(self)
|
| 148 |
+
|
| 149 |
+
def implicit_gemm_size(self, kind: ConvKind):
|
| 150 |
+
if kind == ConvKind.Fprop:
|
| 151 |
+
return GemmCoord(
|
| 152 |
+
self.N * self.P * self.Q,
|
| 153 |
+
self.K,
|
| 154 |
+
self.R * self.S * self.C // self.groups
|
| 155 |
+
)
|
| 156 |
+
elif kind == ConvKind.Dgrad:
|
| 157 |
+
return GemmCoord(
|
| 158 |
+
self.N * self.H * self.W,
|
| 159 |
+
self.C,
|
| 160 |
+
self.R * self.S * self.K
|
| 161 |
+
)
|
| 162 |
+
elif kind == ConvKind.Wgrad:
|
| 163 |
+
return GemmCoord(
|
| 164 |
+
self.K,
|
| 165 |
+
self.R * self.S * self.C,
|
| 166 |
+
self.N * self.P * self.Q
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
@staticmethod
|
| 170 |
+
def from_sizes(input_size, weight_size):
|
| 171 |
+
K, R, S, _ = weight_size
|
| 172 |
+
pad_h = R // 2
|
| 173 |
+
pad_w = S // 2
|
| 174 |
+
stride_h = 1
|
| 175 |
+
stride_w = 1
|
| 176 |
+
dilation_h = 1
|
| 177 |
+
dilation_w = 1
|
| 178 |
+
return Conv2DProblemSize(
|
| 179 |
+
*input_size,
|
| 180 |
+
*weight_size,
|
| 181 |
+
pad_h, pad_w,
|
| 182 |
+
stride_h, stride_w,
|
| 183 |
+
dilation_h, dilation_w
|
| 184 |
+
)
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Registry of swizzling functions
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_library import SwizzlingFunctor
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
IdentitySwizzle1 = SwizzlingFunctor.Identity1
|
| 41 |
+
IdentitySwizzle2 = SwizzlingFunctor.Identity2
|
| 42 |
+
IdentitySwizzle4 = SwizzlingFunctor.Identity4
|
| 43 |
+
IdentitySwizzle8 = SwizzlingFunctor.Identity8
|
| 44 |
+
HorizontalSwizzle = SwizzlingFunctor.Horizontal
|
| 45 |
+
ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK
|
| 46 |
+
StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1
|
| 47 |
+
StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4
|
| 48 |
+
StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
_swizzling_functors = [
|
| 52 |
+
IdentitySwizzle1,
|
| 53 |
+
IdentitySwizzle2,
|
| 54 |
+
IdentitySwizzle4,
|
| 55 |
+
IdentitySwizzle8,
|
| 56 |
+
HorizontalSwizzle,
|
| 57 |
+
ThreadblockSwizzleStreamK,
|
| 58 |
+
StridedDgradIdentitySwizzle1,
|
| 59 |
+
StridedDgradIdentitySwizzle4,
|
| 60 |
+
StridedDgradHorizontalSwizzle,
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_swizzling_functors():
|
| 65 |
+
return _swizzling_functors
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.utils.check import (
|
| 34 |
+
alignment_or_default,
|
| 35 |
+
calculate_smem_usage,
|
| 36 |
+
calculate_smem_usage_per_stage,
|
| 37 |
+
valid_cluster_shape,
|
| 38 |
+
valid_schedule,
|
| 39 |
+
valid_stage_count,
|
| 40 |
+
update_alignment,
|
| 41 |
+
)
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utility functions for checking constraints on kernels and calculating kernel attributes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import ctypes
|
| 38 |
+
|
| 39 |
+
from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.library import TileDescription
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
|
| 46 |
+
"""
|
| 47 |
+
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
|
| 48 |
+
|
| 49 |
+
:param td: tile description to compute shared memory of
|
| 50 |
+
:type td: TileDescription
|
| 51 |
+
:param operation_kind: identifier for the type of operation being performed
|
| 52 |
+
:type operation_kind: cutlass_library.OperationKind
|
| 53 |
+
|
| 54 |
+
:return: number of bytes of shared memory consumed by a single stage
|
| 55 |
+
:rtype: int
|
| 56 |
+
"""
|
| 57 |
+
m, n, k = td.blackwell_threadblock_shape
|
| 58 |
+
if td.is_2sm:
|
| 59 |
+
m //= 2
|
| 60 |
+
|
| 61 |
+
if operation_kind == OperationKind.Gemm:
|
| 62 |
+
stage_barrier_bytes = 32
|
| 63 |
+
return (
|
| 64 |
+
(DataTypeSize[td.math_instruction.element_a] * m * k // 8)
|
| 65 |
+
+ (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
|
| 66 |
+
+ stage_barrier_bytes
|
| 67 |
+
)
|
| 68 |
+
else:
|
| 69 |
+
raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def calculate_smem_usage(operation) -> int:
|
| 73 |
+
"""
|
| 74 |
+
Returns the amount of shared memory in bytes consumed by a kernel.
|
| 75 |
+
|
| 76 |
+
:return: number of bytes of shared memory consumed by the operation
|
| 77 |
+
:return: int
|
| 78 |
+
"""
|
| 79 |
+
_per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind)
|
| 80 |
+
return _per_stage * operation.tile_description.stages
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def valid_stage_count(
|
| 84 |
+
cc: int,
|
| 85 |
+
kernel_cc: int,
|
| 86 |
+
td: TileDescription,
|
| 87 |
+
element_C: cutlass_cppgen.DataType = None,
|
| 88 |
+
element_D: cutlass_cppgen.DataType = None,
|
| 89 |
+
verbose: bool = True) -> tuple:
|
| 90 |
+
"""
|
| 91 |
+
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
|
| 92 |
+
based on raw limits on the number of stages and based on shared memory capacity
|
| 93 |
+
|
| 94 |
+
:param cc: compute capability of device in question
|
| 95 |
+
:type cc: int
|
| 96 |
+
:param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
|
| 97 |
+
:type kernel_cc: int
|
| 98 |
+
:param td: tile description to check
|
| 99 |
+
:type td: TileDescription
|
| 100 |
+
:param element_C: data type of operand C
|
| 101 |
+
:type element_C: cutlass_cppgen.DataType
|
| 102 |
+
:param element_D: data type of operand D
|
| 103 |
+
:type element_D: cutlass_cppgen.DataType
|
| 104 |
+
:param verbose: whether to log warnings
|
| 105 |
+
:type verbose: bool
|
| 106 |
+
|
| 107 |
+
:return: tuple with the first element indicating whether the provided tile description is
|
| 108 |
+
valid for the provided device and the second element being an error message
|
| 109 |
+
:rtype: tuple
|
| 110 |
+
"""
|
| 111 |
+
if kernel_cc in [90, 100, 101, 103]:
|
| 112 |
+
if (td.stages is None or td.stages == 0):
|
| 113 |
+
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
|
| 114 |
+
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
| 115 |
+
return (True, "")
|
| 116 |
+
elif verbose:
|
| 117 |
+
cutlass_cppgen.logger.warning(
|
| 118 |
+
"Setting an explicit stage count for SM90 kernels currently may "
|
| 119 |
+
"result in compilation errors if the combination of tile shape, "
|
| 120 |
+
"stage count, and shared memory requirement of the epilogue exceeds "
|
| 121 |
+
"the available shared memory per SM.")
|
| 122 |
+
|
| 123 |
+
if td.stages <= 0:
|
| 124 |
+
return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
|
| 125 |
+
|
| 126 |
+
if cc < 80 and td.stages != 2:
|
| 127 |
+
return (False, f"Tile description has stage count of {td.stages}, "
|
| 128 |
+
f"but only 2 stages are supported on SM{cc}.")
|
| 129 |
+
|
| 130 |
+
# The calculation below does not consider shared memory used by the epilogue and, thus,
|
| 131 |
+
# only catches cases in which the mainloop exceeds the device's shared memory capacity.
|
| 132 |
+
# This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
|
| 133 |
+
# mainloop and epilogue is shared.
|
| 134 |
+
smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
|
| 135 |
+
smem_usage_mainloop = (smem_per_stage * td.stages)
|
| 136 |
+
smem_arch = SharedMemPerCC[cc] << 10
|
| 137 |
+
if smem_usage_mainloop > smem_arch:
|
| 138 |
+
return ( False,
|
| 139 |
+
"Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
|
| 140 |
+
f"Details:\n"
|
| 141 |
+
f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
|
| 142 |
+
f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
|
| 143 |
+
f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
|
| 144 |
+
|
| 145 |
+
return (True, "")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
|
| 149 |
+
"""
|
| 150 |
+
Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.
|
| 151 |
+
|
| 152 |
+
:param cc: compute capability of device in question
|
| 153 |
+
:type cc: int
|
| 154 |
+
:param cluster_shape: dimensions of thread block cluster shape to check
|
| 155 |
+
:type cluster_shape: list
|
| 156 |
+
|
| 157 |
+
:return: tuple with the first element indicating whether the provided cluster shape is
|
| 158 |
+
valid for the provided device and the second element being an error message
|
| 159 |
+
:rtype: tuple
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
if cc < 90 or cc in [120, 121]:
|
| 163 |
+
if cluster_shape != [1, 1, 1]:
|
| 164 |
+
return (False,
|
| 165 |
+
f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of "
|
| 166 |
+
f"{cluster_shape} for SM{cc}.")
|
| 167 |
+
else:
|
| 168 |
+
return (True, "")
|
| 169 |
+
|
| 170 |
+
if len(cluster_shape) != 3:
|
| 171 |
+
return (False,
|
| 172 |
+
f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}")
|
| 173 |
+
|
| 174 |
+
if cluster_shape[2] != 1:
|
| 175 |
+
return (False,
|
| 176 |
+
"CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
|
| 177 |
+
f"Received cluster shape of {cluster_shape}.")
|
| 178 |
+
|
| 179 |
+
return (True, "")
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def valid_schedule(
|
| 183 |
+
cc: int,
|
| 184 |
+
kernel_schedule: cutlass_cppgen.KernelScheduleType,
|
| 185 |
+
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
|
| 186 |
+
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
|
| 187 |
+
"""
|
| 188 |
+
Checks that the kernel and epilogue schedules passed in are a valid combination for
|
| 189 |
+
a device of compute capability ``cc``.
|
| 190 |
+
|
| 191 |
+
:param cc: compute capability of device in question
|
| 192 |
+
:type cc: int
|
| 193 |
+
:param kernel_schedule: kernel schedule type
|
| 194 |
+
:type kernel_schedule: cutlass_cppgen.KernelScheduleType
|
| 195 |
+
:param epilogue_schedule: epilogue schedule type
|
| 196 |
+
:type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
|
| 197 |
+
:param tile_scheduler: tile scheduler type
|
| 198 |
+
:type tile_scheduler: cutlass_cppgen.TileSchedulerType
|
| 199 |
+
|
| 200 |
+
:return: tuple with the first element indicating whether the provided schedules are
|
| 201 |
+
valid for the provided device and the second element being an error message
|
| 202 |
+
:rtype: tuple
|
| 203 |
+
"""
|
| 204 |
+
kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
|
| 205 |
+
epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
|
| 206 |
+
tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
|
| 207 |
+
if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default):
|
| 208 |
+
return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)")
|
| 209 |
+
|
| 210 |
+
if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)):
|
| 211 |
+
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
|
| 212 |
+
|
| 213 |
+
if not tile_scheduler_default:
|
| 214 |
+
cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
|
| 215 |
+
cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
|
| 216 |
+
if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
|
| 217 |
+
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
|
| 218 |
+
return (True, "")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
|
| 222 |
+
"""
|
| 223 |
+
Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
|
| 224 |
+
that `alignment_provided` does not exceed `default_alignment`.
|
| 225 |
+
|
| 226 |
+
:param alignment_provided: alignment preference specified. Can be None.
|
| 227 |
+
:type alignment_provided: int
|
| 228 |
+
:param default_alignment: alignment to use if `alignment_provided` is None
|
| 229 |
+
:type default_alignment: int
|
| 230 |
+
|
| 231 |
+
:return: alignment to use
|
| 232 |
+
:rtype: int
|
| 233 |
+
"""
|
| 234 |
+
if alignment_provided is not None:
|
| 235 |
+
if alignment_provided > default_alignment:
|
| 236 |
+
raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
|
| 237 |
+
return alignment_provided
|
| 238 |
+
|
| 239 |
+
return default_alignment
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def update_alignment(alignment_provided:int, default_alignment: int) -> int:
|
| 243 |
+
"""
|
| 244 |
+
Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
|
| 245 |
+
that `alignment_provided` does not exceed `default_alignment`.
|
| 246 |
+
|
| 247 |
+
:param alignment_provided: alignment preference specified. Can be None.
|
| 248 |
+
:type alignment_provided: int
|
| 249 |
+
:param default_alignment: alignment to use if `alignment_provided` is None
|
| 250 |
+
:type default_alignment: int
|
| 251 |
+
|
| 252 |
+
:return: alignment to use
|
| 253 |
+
:rtype: int
|
| 254 |
+
"""
|
| 255 |
+
if alignment_provided is not None:
|
| 256 |
+
if alignment_provided > default_alignment:
|
| 257 |
+
if alignment_provided % default_alignment == 0:
|
| 258 |
+
return default_alignment
|
| 259 |
+
raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
|
| 260 |
+
return alignment_provided
|
| 261 |
+
|
| 262 |
+
return default_alignment
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utility functions for converting between frontend datatypes and CUTLASS datatypes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import cutlass_cppgen
|
| 38 |
+
from cutlass_library import (
|
| 39 |
+
DataTypeSize,
|
| 40 |
+
MathOperation,
|
| 41 |
+
MathInstruction
|
| 42 |
+
)
|
| 43 |
+
from cutlass_cppgen.backend.library import (
|
| 44 |
+
TileDescription,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
bfloat16_available = None
|
| 48 |
+
cupy_available = None
|
| 49 |
+
numpy_available = None
|
| 50 |
+
torch_available = None
|
| 51 |
+
_library_to_cupy_dict = None
|
| 52 |
+
_library_to_numpy_dict = None
|
| 53 |
+
_library_to_torch_dict = None
|
| 54 |
+
_torch_to_library_dict = None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def is_numpy_available():
|
| 58 |
+
global numpy_available, _library_to_numpy_dict
|
| 59 |
+
if numpy_available is None:
|
| 60 |
+
try:
|
| 61 |
+
import numpy as np
|
| 62 |
+
|
| 63 |
+
numpy_available = True
|
| 64 |
+
_library_to_numpy_dict = {
|
| 65 |
+
cutlass_cppgen.DataType.f16: np.float16,
|
| 66 |
+
cutlass_cppgen.DataType.f32: np.float32,
|
| 67 |
+
cutlass_cppgen.DataType.f64: np.float64,
|
| 68 |
+
cutlass_cppgen.DataType.s8: np.int8,
|
| 69 |
+
cutlass_cppgen.DataType.s32: np.int32,
|
| 70 |
+
}
|
| 71 |
+
except ImportError:
|
| 72 |
+
numpy_available = False
|
| 73 |
+
_library_to_numpy_dict = {}
|
| 74 |
+
return numpy_available
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def is_numpy_tensor(inp) -> bool:
|
| 78 |
+
if is_numpy_available():
|
| 79 |
+
import numpy as np
|
| 80 |
+
return isinstance(inp, np.ndarray)
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def numpy_library_type(inp) -> cutlass_cppgen.DataType:
|
| 85 |
+
if is_numpy_available():
|
| 86 |
+
import numpy as np
|
| 87 |
+
if inp == np.float16:
|
| 88 |
+
return cutlass_cppgen.DataType.f16
|
| 89 |
+
elif inp == np.float32:
|
| 90 |
+
return cutlass_cppgen.DataType.f32
|
| 91 |
+
elif inp == np.float64:
|
| 92 |
+
return cutlass_cppgen.DataType.f64
|
| 93 |
+
elif inp == np.int8:
|
| 94 |
+
return cutlass_cppgen.DataType.s8
|
| 95 |
+
elif inp == np.int32:
|
| 96 |
+
return cutlass_cppgen.DataType.s32
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def numpy_type(inp):
|
| 101 |
+
return _library_to_numpy_dict.get(inp, None)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def is_cupy_available():
|
| 105 |
+
global cupy_available
|
| 106 |
+
if cupy_available is None:
|
| 107 |
+
try:
|
| 108 |
+
import cupy as cp
|
| 109 |
+
|
| 110 |
+
cupy_available = True
|
| 111 |
+
_library_to_cupy_dict = {
|
| 112 |
+
cutlass_cppgen.DataType.f16: cp.float16,
|
| 113 |
+
cutlass_cppgen.DataType.f32: cp.float32,
|
| 114 |
+
cutlass_cppgen.DataType.f64: cp.float64,
|
| 115 |
+
cutlass_cppgen.DataType.s8: cp.int8,
|
| 116 |
+
cutlass_cppgen.DataType.s32: cp.int32,
|
| 117 |
+
}
|
| 118 |
+
except ImportError:
|
| 119 |
+
cupy_available = False
|
| 120 |
+
_library_to_cupy_dict = {}
|
| 121 |
+
return cupy_available
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def is_cupy_tensor(inp) -> bool:
|
| 125 |
+
if is_cupy_available():
|
| 126 |
+
import cupy as cp
|
| 127 |
+
return isinstance(inp, cp.ndarray)
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def cupy_library_type(inp) -> cutlass_cppgen.DataType:
|
| 132 |
+
if is_cupy_available():
|
| 133 |
+
import cupy as cp
|
| 134 |
+
if inp == cp.float16:
|
| 135 |
+
return cutlass_cppgen.DataType.f16
|
| 136 |
+
elif inp == cp.float32:
|
| 137 |
+
return cutlass_cppgen.DataType.f32
|
| 138 |
+
elif inp == cp.float64:
|
| 139 |
+
return cutlass_cppgen.DataType.f64
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def cupy_type(inp):
|
| 144 |
+
return _library_to_cupy_dict.get(inp, None)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def is_torch_available():
|
| 148 |
+
global torch_available, _library_to_torch_dict, _torch_to_library_dict
|
| 149 |
+
if torch_available is None:
|
| 150 |
+
try:
|
| 151 |
+
import torch
|
| 152 |
+
|
| 153 |
+
torch_available = True
|
| 154 |
+
_torch_to_library_dict = {
|
| 155 |
+
torch.half: cutlass_cppgen.DataType.f16,
|
| 156 |
+
torch.float16: cutlass_cppgen.DataType.f16,
|
| 157 |
+
torch.bfloat16: cutlass_cppgen.DataType.bf16,
|
| 158 |
+
torch.float: cutlass_cppgen.DataType.f32,
|
| 159 |
+
torch.float32: cutlass_cppgen.DataType.f32,
|
| 160 |
+
torch.double: cutlass_cppgen.DataType.f64,
|
| 161 |
+
torch.float64: cutlass_cppgen.DataType.f64,
|
| 162 |
+
torch.int8: cutlass_cppgen.DataType.s8,
|
| 163 |
+
torch.int32: cutlass_cppgen.DataType.s32,
|
| 164 |
+
torch.uint8: cutlass_cppgen.DataType.u8,
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
_library_to_torch_dict = {
|
| 168 |
+
cutlass_cppgen.DataType.f16: torch.half,
|
| 169 |
+
cutlass_cppgen.DataType.f16: torch.float16,
|
| 170 |
+
cutlass_cppgen.DataType.bf16: torch.bfloat16,
|
| 171 |
+
cutlass_cppgen.DataType.f32: torch.float,
|
| 172 |
+
cutlass_cppgen.DataType.f32: torch.float32,
|
| 173 |
+
cutlass_cppgen.DataType.f64: torch.double,
|
| 174 |
+
cutlass_cppgen.DataType.f64: torch.float64,
|
| 175 |
+
cutlass_cppgen.DataType.s8: torch.int8,
|
| 176 |
+
cutlass_cppgen.DataType.s32: torch.int32,
|
| 177 |
+
cutlass_cppgen.DataType.u8: torch.uint8,
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
def possibly_add_type(torch_type_name, cutlass_type):
|
| 181 |
+
# Only try adding the type if the version of torch being used supports it
|
| 182 |
+
if hasattr(torch, torch_type_name):
|
| 183 |
+
torch_type = getattr(torch, torch_type_name)
|
| 184 |
+
_torch_to_library_dict[torch_type] = cutlass_type
|
| 185 |
+
_library_to_torch_dict[cutlass_type] = torch_type
|
| 186 |
+
|
| 187 |
+
possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
|
| 188 |
+
possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
|
| 189 |
+
|
| 190 |
+
except ImportError:
|
| 191 |
+
torch_available = False
|
| 192 |
+
_torch_to_library_dict = {}
|
| 193 |
+
_library_to_torch_dict = {}
|
| 194 |
+
return torch_available
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def is_torch_tensor(inp) -> bool:
|
| 198 |
+
if is_torch_available():
|
| 199 |
+
import torch
|
| 200 |
+
return isinstance(inp, torch.Tensor)
|
| 201 |
+
return False
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def torch_library_type(inp) -> cutlass_cppgen.DataType:
|
| 205 |
+
return _torch_to_library_dict.get(inp, None)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def torch_type(inp):
|
| 209 |
+
return _library_to_torch_dict.get(inp, None)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def is_bfloat16_available():
|
| 213 |
+
global bfloat16_available
|
| 214 |
+
|
| 215 |
+
if bfloat16_available is None:
|
| 216 |
+
try:
|
| 217 |
+
import bfloat16
|
| 218 |
+
|
| 219 |
+
bfloat16_available = True
|
| 220 |
+
except ImportError:
|
| 221 |
+
bfloat16_available = False
|
| 222 |
+
return bfloat16_available
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
|
| 226 |
+
if is_bfloat16_available():
|
| 227 |
+
import bfloat16
|
| 228 |
+
if inp == bfloat16.bfloat16:
|
| 229 |
+
return cutlass_cppgen.DataType.bf16
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def bfloat16_type(inp):
|
| 233 |
+
if is_bfloat16_available():
|
| 234 |
+
import bfloat16
|
| 235 |
+
if inp == cutlass_cppgen.DataType.bf16:
|
| 236 |
+
return bfloat16.bfloat16
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def library_type(inp):
|
| 240 |
+
if inp in DataTypeSize:
|
| 241 |
+
return inp
|
| 242 |
+
|
| 243 |
+
for cvt_fn in [
|
| 244 |
+
bfloat16_library_type,
|
| 245 |
+
cupy_library_type,
|
| 246 |
+
numpy_library_type,
|
| 247 |
+
torch_library_type,
|
| 248 |
+
]:
|
| 249 |
+
out = cvt_fn(inp)
|
| 250 |
+
if out is not None:
|
| 251 |
+
return out
|
| 252 |
+
|
| 253 |
+
raise Exception(f"No available conversion from type {inp} to a library type.")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _tensor_from_numpy(np_tensor):
|
| 257 |
+
dtype = library_type(np_tensor.dtype)
|
| 258 |
+
if np_tensor.flags.c_contiguous:
|
| 259 |
+
layout = cutlass_cppgen.LayoutType.RowMajor
|
| 260 |
+
elif np_tensor.flags.f_contiguous:
|
| 261 |
+
layout = cutlass_cppgen.LayoutType.ColumnMajor
|
| 262 |
+
return (dtype, layout)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _tensor_from_torch(pt_tensor):
|
| 266 |
+
dtype = library_type(pt_tensor.dtype)
|
| 267 |
+
return (dtype, cutlass_cppgen.LayoutType.RowMajor)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def get_datatype_and_layout(tensor):
|
| 271 |
+
if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
|
| 272 |
+
return _tensor_from_numpy(tensor)
|
| 273 |
+
elif is_torch_tensor(tensor):
|
| 274 |
+
return _tensor_from_torch(tensor)
|
| 275 |
+
elif isinstance(tensor, float) or isinstance(tensor, int):
|
| 276 |
+
return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
|
| 277 |
+
else:
|
| 278 |
+
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def get_tensor_shape(tensor, op="GEMM"):
|
| 282 |
+
if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
|
| 283 |
+
return tensor.shape
|
| 284 |
+
elif is_torch_tensor(tensor):
|
| 285 |
+
size = tensor.size()
|
| 286 |
+
if op == "CONV":
|
| 287 |
+
# PyTorch Tensors have shape NCHW
|
| 288 |
+
return (size[0], size[2], size[3], size[1])
|
| 289 |
+
else:
|
| 290 |
+
return tuple(tensor.size())
|
| 291 |
+
elif isinstance(tensor, float) or isinstance(tensor, int):
|
| 292 |
+
return (1,)
|
| 293 |
+
else:
|
| 294 |
+
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
_math_operation_value_map = {x.value: x for x in MathOperation}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def backend_math_operation(math_op: MathOperation):
|
| 301 |
+
if math_op.value not in _math_operation_value_map.keys():
|
| 302 |
+
raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
|
| 303 |
+
return _math_operation_value_map[math_op.value]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def construct_backend_td(td: cutlass_cppgen.TileDescription,
|
| 307 |
+
kernel_schedule: cutlass_cppgen.KernelScheduleType,
|
| 308 |
+
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
|
| 309 |
+
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
|
| 310 |
+
mi = td.math_instruction
|
| 311 |
+
backend_mi = MathInstruction(
|
| 312 |
+
mi.instruction_shape,
|
| 313 |
+
mi.element_a,
|
| 314 |
+
mi.element_b,
|
| 315 |
+
mi.element_accumulator,
|
| 316 |
+
mi.opcode_class,
|
| 317 |
+
backend_math_operation(mi.math_operation)
|
| 318 |
+
)
|
| 319 |
+
cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
|
| 320 |
+
return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
|
| 321 |
+
backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def td_from_profiler_op(op) -> TileDescription:
|
| 325 |
+
"""
|
| 326 |
+
Converts the profiler's TileDescription in ``op`` into the backend TileDescription
|
| 327 |
+
|
| 328 |
+
:param op: profiler Operation
|
| 329 |
+
|
| 330 |
+
:returns: backend TileDescription
|
| 331 |
+
:rtype: cutlass_cppgen.backend.TileDescription
|
| 332 |
+
"""
|
| 333 |
+
kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
|
| 334 |
+
eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
|
| 335 |
+
tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
|
| 336 |
+
return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def td_from_profiler_td(td: TileDescription) -> TileDescription:
|
| 340 |
+
"""
|
| 341 |
+
Converts the profiler's TileDescription into the backend TileDescription
|
| 342 |
+
|
| 343 |
+
:param td: profiler TileDescription
|
| 344 |
+
:type td: cutlass_cppgen.TileDescription
|
| 345 |
+
|
| 346 |
+
:returns: backend TileDescription
|
| 347 |
+
:rtype: cutlass_cppgen.backend.TileDescription
|
| 348 |
+
"""
|
| 349 |
+
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def to_camel_case(snake_str):
|
| 353 |
+
return "".join(x.capitalize() for x in snake_str.lower().split("_"))
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def getattr_enum(obj, attr_name):
|
| 357 |
+
# The attr_name is under the snake_case
|
| 358 |
+
camel_attr = to_camel_case(attr_name)
|
| 359 |
+
if hasattr(obj, camel_attr):
|
| 360 |
+
return getattr(obj, camel_attr)
|
| 361 |
+
else:
|
| 362 |
+
raise Exception(f"Invalid option: {attr_name}")
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
import importlib
|
| 33 |
+
from typing import Any
|
| 34 |
+
|
| 35 |
+
def lazy_import(mod_name: str) -> Any:
|
| 36 |
+
class Lazy:
|
| 37 |
+
def __getattr__(self, name:str) -> Any:
|
| 38 |
+
module = importlib.import_module(mod_name)
|
| 39 |
+
return getattr(module, name)
|
| 40 |
+
|
| 41 |
+
return Lazy()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Profiler based on the cuda events
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import re
|
| 38 |
+
import subprocess
|
| 39 |
+
|
| 40 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 41 |
+
cuda = lazy_import("cuda.cuda")
|
| 42 |
+
cudart = lazy_import("cuda.cudart")
|
| 43 |
+
import numpy as np
|
| 44 |
+
|
| 45 |
+
from cutlass_cppgen import CUTLASS_PATH
|
| 46 |
+
from cutlass_cppgen.backend.library import DataTypeSize
|
| 47 |
+
from cutlass_cppgen.op.op import OperationBase
|
| 48 |
+
from cutlass_cppgen.shape import GemmCoord
|
| 49 |
+
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class GpuTimer:
|
| 53 |
+
def __init__(self) -> None:
|
| 54 |
+
self.events = [
|
| 55 |
+
cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
|
| 56 |
+
cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
def start(self, stream=None):
|
| 60 |
+
if not stream:
|
| 61 |
+
stream = cuda.CUstream(0)
|
| 62 |
+
|
| 63 |
+
(err,) = cuda.cuEventRecord(self.events[0], stream)
|
| 64 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 65 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 66 |
+
|
| 67 |
+
def stop(self, stream=None):
|
| 68 |
+
if not stream:
|
| 69 |
+
stream = cuda.CUstream(0)
|
| 70 |
+
|
| 71 |
+
(err,) = cuda.cuEventRecord(self.events[1], stream)
|
| 72 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 73 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
def stop_and_wait(self, stream=None):
|
| 77 |
+
if not stream:
|
| 78 |
+
stream = cuda.CUstream(0)
|
| 79 |
+
|
| 80 |
+
self.stop(stream)
|
| 81 |
+
if stream:
|
| 82 |
+
(err,) = cuda.cuStreamSynchronize(stream)
|
| 83 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 84 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 85 |
+
else:
|
| 86 |
+
(err,) = cudart.cudaDeviceSynchronize()
|
| 87 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 88 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 89 |
+
|
| 90 |
+
def duration(self, iterations=1):
|
| 91 |
+
err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
|
| 92 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 93 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 94 |
+
return duration / float(iterations)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CUDAEventProfiler:
|
| 98 |
+
def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None:
|
| 99 |
+
self.arguments = op.run(*args, **kwargs)
|
| 100 |
+
self.operation = op.operation
|
| 101 |
+
self.warmup_iterations = warmup_iterations
|
| 102 |
+
self.iterations = iterations
|
| 103 |
+
self.timer = GpuTimer()
|
| 104 |
+
|
| 105 |
+
#
|
| 106 |
+
# Cutlass Python Interface Profiler
|
| 107 |
+
#
|
| 108 |
+
|
| 109 |
+
def __call__(self):
|
| 110 |
+
for _ in range(self.warmup_iterations):
|
| 111 |
+
self.operation.run(self.arguments)
|
| 112 |
+
|
| 113 |
+
self.timer.start()
|
| 114 |
+
for _ in range(self.iterations):
|
| 115 |
+
self.operation.run(self.arguments)
|
| 116 |
+
|
| 117 |
+
self.timer.stop_and_wait()
|
| 118 |
+
runtime = self.timer.duration(self.iterations)
|
| 119 |
+
return runtime
|
| 120 |
+
|
| 121 |
+
#
|
| 122 |
+
# CUTLASS Profiler
|
| 123 |
+
#
|
| 124 |
+
|
| 125 |
+
def run_cutlass_profiler(self):
|
| 126 |
+
alpha = 1.0
|
| 127 |
+
beta = 1.0
|
| 128 |
+
|
| 129 |
+
profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
|
| 130 |
+
kernel_name = self.operation.procedural_name()
|
| 131 |
+
verification_providers = "device"
|
| 132 |
+
provider = "cutlass"
|
| 133 |
+
problem_size = self.arguments.problem_size
|
| 134 |
+
|
| 135 |
+
if "cutlass3x" in kernel_name:
|
| 136 |
+
# cutlass3x generator only have column-major output
|
| 137 |
+
layout_name = self.operation.layout_name_3x()
|
| 138 |
+
if layout_name[-1] == "t":
|
| 139 |
+
new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"])
|
| 140 |
+
problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
|
| 141 |
+
kernel_name = kernel_name.replace(layout_name, new_layout_name)
|
| 142 |
+
|
| 143 |
+
batch_count = self.arguments.batch_count
|
| 144 |
+
|
| 145 |
+
cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
|
| 146 |
+
f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
|
| 147 |
+
f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\
|
| 148 |
+
f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"
|
| 149 |
+
|
| 150 |
+
result = subprocess.getoutput(cmd)
|
| 151 |
+
|
| 152 |
+
m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
|
| 153 |
+
runtime = float(m.group("runtime"))
|
| 154 |
+
|
| 155 |
+
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
|
| 156 |
+
bytes = int(m.group("bytes"))
|
| 157 |
+
|
| 158 |
+
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
|
| 159 |
+
flops = int(m.group("flops"))
|
| 160 |
+
|
| 161 |
+
# check if the problem size matches
|
| 162 |
+
assert bytes == self.bytes(problem_size, batch_count, beta)
|
| 163 |
+
assert flops == self.flops(problem_size, batch_count, beta)
|
| 164 |
+
|
| 165 |
+
return runtime
|
| 166 |
+
|
| 167 |
+
def bytes(self, problem_size, batch_count=1, beta=0.0):
|
| 168 |
+
m = problem_size.m()
|
| 169 |
+
n = problem_size.n()
|
| 170 |
+
k = problem_size.k()
|
| 171 |
+
|
| 172 |
+
bytes = (
|
| 173 |
+
(DataTypeSize[self.operation.A.element] * m // 8) * k
|
| 174 |
+
+ (DataTypeSize[self.operation.B.element] * n // 8) * k
|
| 175 |
+
+ (DataTypeSize[self.operation.C.element] * m // 8) * n
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if beta != 0:
|
| 179 |
+
bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
|
| 180 |
+
|
| 181 |
+
bytes *= batch_count
|
| 182 |
+
|
| 183 |
+
return bytes
|
| 184 |
+
|
| 185 |
+
def flops(self, problem_size, batch_count=1, beta=0.0):
|
| 186 |
+
m = problem_size.m()
|
| 187 |
+
n = problem_size.n()
|
| 188 |
+
k = problem_size.k()
|
| 189 |
+
|
| 190 |
+
flops_ = (m * n * k) * 2 * batch_count
|
| 191 |
+
|
| 192 |
+
if beta != 0:
|
| 193 |
+
flops_ += m * n * batch_count * 2
|
| 194 |
+
|
| 195 |
+
return flops_
|
| 196 |
+
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import os
|
| 34 |
+
import sys
|
| 35 |
+
|
| 36 |
+
from . import conv2d_operation
|
| 37 |
+
from . import conv3d_operation
|
| 38 |
+
from . import emit_kernel_listing
|
| 39 |
+
from . import gemm_operation
|
| 40 |
+
|
| 41 |
+
if '-m' not in sys.argv:
|
| 42 |
+
# Do not import generator when running python -m cutlass_library.generator to
|
| 43 |
+
# avoid double-import warnings
|
| 44 |
+
from . import generator
|
| 45 |
+
|
| 46 |
+
from . import library
|
| 47 |
+
from . import manifest
|
| 48 |
+
from . import rank_2k_operation
|
| 49 |
+
from . import rank_k_operation
|
| 50 |
+
from . import symm_operation
|
| 51 |
+
from . import trmm_operation
|
| 52 |
+
# Make enum types from library.py accessible via cutlass_library.*
|
| 53 |
+
from .library import *
|
| 54 |
+
|
| 55 |
+
# Set up `source` to point to the path containing the CUTLASS source.
|
| 56 |
+
# Check first if the path contains a `source` subdirectory -- this will
|
| 57 |
+
# be the case when the package has been installed via pip. Otherwise,
|
| 58 |
+
# default to the root of CUTLASS.
|
| 59 |
+
install_source_path = os.path.join(__path__[0], 'source')
|
| 60 |
+
if os.path.isdir(install_source_path):
|
| 61 |
+
source_path = install_source_path
|
| 62 |
+
else:
|
| 63 |
+
source_path = os.path.join(__path__[0], '../..')
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting Conv2d kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import logging
|
| 39 |
+
import os.path
|
| 40 |
+
import shutil
|
| 41 |
+
from string import Template
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 49 |
+
except ImportError:
|
| 50 |
+
from library import *
|
| 51 |
+
from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 52 |
+
|
| 53 |
+
_LOGGER = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
###################################################################################################
|
| 56 |
+
|
| 57 |
+
#
|
| 58 |
+
class Conv2dOperation:
|
| 59 |
+
#
|
| 60 |
+
def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
|
| 61 |
+
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
|
| 62 |
+
group_mode = GroupMode.NoneGroup):
|
| 63 |
+
|
| 64 |
+
self.operation_kind = OperationKind.Conv2d
|
| 65 |
+
self.arch = arch
|
| 66 |
+
self.tile_description = tile_description
|
| 67 |
+
self.conv_kind = conv_kind
|
| 68 |
+
self.A = A
|
| 69 |
+
self.B = B
|
| 70 |
+
self.C = C
|
| 71 |
+
self.element_epilogue = element_epilogue
|
| 72 |
+
self.epilogue_functor = epilogue_functor
|
| 73 |
+
self.iterator_algorithm = iterator_algorithm
|
| 74 |
+
self.stride_support = stride_support
|
| 75 |
+
self.swizzling_functor = swizzling_functor
|
| 76 |
+
self.group_mode = group_mode
|
| 77 |
+
#
|
| 78 |
+
def is_complex(self):
|
| 79 |
+
complex_operators = [
|
| 80 |
+
MathOperation.multiply_add_complex,
|
| 81 |
+
MathOperation.multiply_add_complex_gaussian
|
| 82 |
+
]
|
| 83 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 84 |
+
|
| 85 |
+
#
|
| 86 |
+
def is_mixed_input(self):
|
| 87 |
+
return self.A.element != self.B.element
|
| 88 |
+
|
| 89 |
+
#
|
| 90 |
+
def accumulator_type(self):
|
| 91 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 92 |
+
|
| 93 |
+
if self.is_complex():
|
| 94 |
+
return get_complex_from_real(accum)
|
| 95 |
+
|
| 96 |
+
return accum
|
| 97 |
+
|
| 98 |
+
#
|
| 99 |
+
def core_name(self):
|
| 100 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 101 |
+
|
| 102 |
+
intermediate_type = ''
|
| 103 |
+
|
| 104 |
+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
|
| 105 |
+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 106 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 107 |
+
self.tile_description.math_instruction.element_a != self.accumulator_type():
|
| 108 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 109 |
+
else:
|
| 110 |
+
inst_shape = ''
|
| 111 |
+
|
| 112 |
+
return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
|
| 113 |
+
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
|
| 114 |
+
|
| 115 |
+
#
|
| 116 |
+
def extended_name(self):
|
| 117 |
+
''' Append data types if they differ from compute type. '''
|
| 118 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 119 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 120 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 121 |
+
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 122 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 123 |
+
extended_name = "${core_name}_${element_a}"
|
| 124 |
+
else:
|
| 125 |
+
extended_name = "${core_name}"
|
| 126 |
+
|
| 127 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 128 |
+
'element_a': DataTypeNames[self.A.element],
|
| 129 |
+
'element_c': DataTypeNames[self.C.element],
|
| 130 |
+
'core_name': self.core_name()
|
| 131 |
+
})
|
| 132 |
+
|
| 133 |
+
return extended_name
|
| 134 |
+
|
| 135 |
+
#
|
| 136 |
+
def layout_name(self):
|
| 137 |
+
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
| 138 |
+
|
| 139 |
+
#
|
| 140 |
+
def configuration_name(self):
|
| 141 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 142 |
+
|
| 143 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 144 |
+
|
| 145 |
+
threadblock = self.tile_description.procedural_name()
|
| 146 |
+
|
| 147 |
+
# grouped conv
|
| 148 |
+
if self.group_mode != GroupMode.NoneGroup:
|
| 149 |
+
group_conv_name = f"{GroupModeNames[self.group_mode]}_"
|
| 150 |
+
else:
|
| 151 |
+
group_conv_name = ""
|
| 152 |
+
|
| 153 |
+
if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad:
|
| 154 |
+
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
|
| 155 |
+
else:
|
| 156 |
+
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
|
| 157 |
+
|
| 158 |
+
return SubstituteTemplate(
|
| 159 |
+
configuration_name,
|
| 160 |
+
{
|
| 161 |
+
'opcode_class': opcode_class_name,
|
| 162 |
+
'extended_name': self.extended_name(),
|
| 163 |
+
'threadblock': threadblock,
|
| 164 |
+
'layout': self.layout_name(),
|
| 165 |
+
'alignment': "%d" % self.A.alignment,
|
| 166 |
+
'group_conv_name': group_conv_name
|
| 167 |
+
}
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
#
|
| 171 |
+
def procedural_name(self):
|
| 172 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 173 |
+
return self.configuration_name()
|
| 174 |
+
|
| 175 |
+
###################################################################################################
|
| 176 |
+
#
|
| 177 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 178 |
+
#
|
| 179 |
+
###################################################################################################
|
| 180 |
+
|
| 181 |
+
class EmitConv2dInstance:
|
| 182 |
+
def __init__(self):
|
| 183 |
+
# Emitter for CUTLASS 3 convolution operations
|
| 184 |
+
self.conv3x_emitter = EmitConv3xInstance()
|
| 185 |
+
self.template = """
|
| 186 |
+
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 187 |
+
using ${operation_name}_base =
|
| 188 |
+
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
| 189 |
+
${element_a},
|
| 190 |
+
${layout_a},
|
| 191 |
+
${element_b},
|
| 192 |
+
${layout_b},
|
| 193 |
+
${element_c},
|
| 194 |
+
${layout_c},
|
| 195 |
+
${element_accumulator},
|
| 196 |
+
${opcode_class},
|
| 197 |
+
${arch},
|
| 198 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 199 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
| 200 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 201 |
+
${epilogue_functor}<
|
| 202 |
+
${element_c},
|
| 203 |
+
${epilogue_vector_length},
|
| 204 |
+
${element_accumulator},
|
| 205 |
+
${element_epilogue}
|
| 206 |
+
>,
|
| 207 |
+
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
| 208 |
+
${stages},
|
| 209 |
+
${math_operator},
|
| 210 |
+
${iterator_algorithm},
|
| 211 |
+
${stride_support},
|
| 212 |
+
${align_a},
|
| 213 |
+
${align_b}
|
| 214 |
+
>::Kernel;
|
| 215 |
+
"""
|
| 216 |
+
self.template_group_conv = """
|
| 217 |
+
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 218 |
+
using ${operation_name}_base =
|
| 219 |
+
typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
|
| 220 |
+
${element_a},
|
| 221 |
+
${layout_a},
|
| 222 |
+
${element_b},
|
| 223 |
+
${layout_b},
|
| 224 |
+
${element_c},
|
| 225 |
+
${layout_c},
|
| 226 |
+
${element_accumulator},
|
| 227 |
+
${opcode_class},
|
| 228 |
+
${arch},
|
| 229 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 230 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
| 231 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 232 |
+
${epilogue_functor}<
|
| 233 |
+
${element_c},
|
| 234 |
+
${epilogue_vector_length},
|
| 235 |
+
${element_accumulator},
|
| 236 |
+
${element_epilogue}
|
| 237 |
+
>,
|
| 238 |
+
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
| 239 |
+
${stages},
|
| 240 |
+
${math_operator},
|
| 241 |
+
${group_mode},
|
| 242 |
+
${iterator_algorithm},
|
| 243 |
+
${stride_support},
|
| 244 |
+
${align_a},
|
| 245 |
+
${align_b}
|
| 246 |
+
>::Kernel;
|
| 247 |
+
"""
|
| 248 |
+
self.template_depthwise_direct_conv = """
|
| 249 |
+
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 250 |
+
using ${operation_name}_base =
|
| 251 |
+
typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
|
| 252 |
+
${element_a},
|
| 253 |
+
${layout_a},
|
| 254 |
+
${element_b},
|
| 255 |
+
${layout_b},
|
| 256 |
+
${element_c},
|
| 257 |
+
${layout_c},
|
| 258 |
+
${element_accumulator},
|
| 259 |
+
${opcode_class},
|
| 260 |
+
${arch},
|
| 261 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 262 |
+
cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
|
| 263 |
+
cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
|
| 264 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 265 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 266 |
+
${epilogue_functor}<
|
| 267 |
+
${element_c},
|
| 268 |
+
${epilogue_vector_length},
|
| 269 |
+
${element_accumulator},
|
| 270 |
+
${element_epilogue},
|
| 271 |
+
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
| 272 |
+
>,
|
| 273 |
+
|
| 274 |
+
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
|
| 275 |
+
1,
|
| 276 |
+
${threadblock_output_shape_n},
|
| 277 |
+
${threadblock_output_shape_p},
|
| 278 |
+
${threadblock_output_shape_q}>,
|
| 279 |
+
${stages},
|
| 280 |
+
${math_operator},
|
| 281 |
+
${iterator_algorithm},
|
| 282 |
+
${stride_support},
|
| 283 |
+
cutlass::MatrixShape<${stride_r}, ${stride_s}>,
|
| 284 |
+
cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
|
| 285 |
+
>::Kernel;
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
def arch_number_to_type(self, arch: int):
|
| 289 |
+
return f"cutlass::arch::Sm{arch}"
|
| 290 |
+
|
| 291 |
+
def emit(self, operation):
|
| 292 |
+
_LOGGER.debug("*** EmitConv2dInstance::emit")
|
| 293 |
+
_LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
|
| 294 |
+
|
| 295 |
+
if hasattr(operation, 'is_3x') and operation.is_3x:
|
| 296 |
+
_LOGGER.debug("*** CUTLASS 3 operation")
|
| 297 |
+
return self.conv3x_emitter.emit(operation)
|
| 298 |
+
|
| 299 |
+
_LOGGER.debug("*** CUTLASS 2 operation")
|
| 300 |
+
|
| 301 |
+
warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
|
| 302 |
+
|
| 303 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 304 |
+
|
| 305 |
+
values = {
|
| 306 |
+
'operation_name': operation.procedural_name(),
|
| 307 |
+
'conv_kind': ConvKindTag[operation.conv_kind],
|
| 308 |
+
'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
|
| 309 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 310 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 311 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 312 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 313 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 314 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 315 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 316 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 317 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 318 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 319 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 320 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 321 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 322 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 323 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 324 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 325 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 326 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 327 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 328 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 329 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 330 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 331 |
+
'stages': str(operation.tile_description.stages),
|
| 332 |
+
'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
|
| 333 |
+
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
| 334 |
+
'stride_support': StrideSupportTag[operation.stride_support],
|
| 335 |
+
'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
|
| 336 |
+
MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 337 |
+
'align_a': str(operation.A.alignment),
|
| 338 |
+
'align_b': str(operation.B.alignment),
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
if operation.group_mode == GroupMode.NoneGroup:
|
| 342 |
+
_LOGGER.debug("*** group_mode=NoneGroup")
|
| 343 |
+
return SubstituteTemplate(self.template, values)
|
| 344 |
+
|
| 345 |
+
elif operation.group_mode == GroupMode.Depthwise:
|
| 346 |
+
_LOGGER.debug("*** group_mode=Depthwise")
|
| 347 |
+
values['group_mode'] = GroupModeTag[operation.group_mode]
|
| 348 |
+
# Setup other template params
|
| 349 |
+
values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
|
| 350 |
+
values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
|
| 351 |
+
values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
|
| 352 |
+
|
| 353 |
+
values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
|
| 354 |
+
|
| 355 |
+
values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
|
| 356 |
+
values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
|
| 357 |
+
|
| 358 |
+
values['stride_r'] = str(operation.tile_description.stride[0])
|
| 359 |
+
values['stride_s'] = str(operation.tile_description.stride[1])
|
| 360 |
+
|
| 361 |
+
values['dilation_r'] = str(operation.tile_description.dilation[0])
|
| 362 |
+
values['dilation_s'] = str(operation.tile_description.dilation[1])
|
| 363 |
+
|
| 364 |
+
return SubstituteTemplate(self.template_depthwise_direct_conv, values)
|
| 365 |
+
|
| 366 |
+
else:
|
| 367 |
+
_LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode])
|
| 368 |
+
values['group_mode'] = GroupModeTag[operation.group_mode]
|
| 369 |
+
return SubstituteTemplate(self.template_group_conv, values)
|
| 370 |
+
|
| 371 |
+
###################################################################################################
|
| 372 |
+
#
|
| 373 |
+
# Generator functions for all layouts
|
| 374 |
+
#
|
| 375 |
+
###################################################################################################
|
| 376 |
+
|
| 377 |
+
#
|
| 378 |
+
def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
|
| 379 |
+
_LOGGER.debug("*** GenerateConv2dTensorOp")
|
| 380 |
+
|
| 381 |
+
for tile in tile_descriptions:
|
| 382 |
+
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
| 383 |
+
|
| 384 |
+
if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
|
| 385 |
+
|
| 386 |
+
#
|
| 387 |
+
output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
|
| 388 |
+
if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
|
| 389 |
+
else [tile.math_instruction.element_accumulator,]
|
| 390 |
+
|
| 391 |
+
for output_type in output_types:
|
| 392 |
+
A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
|
| 393 |
+
B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
|
| 394 |
+
C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type])))
|
| 395 |
+
|
| 396 |
+
manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
|
| 397 |
+
|
| 398 |
+
class EmitConv2dIncludes:
|
| 399 |
+
'''Emit includes that are specific to the operation.'''
|
| 400 |
+
|
| 401 |
+
def __init__(self):
|
| 402 |
+
self.includes = ['conv2d_operation.h']
|
| 403 |
+
self.emitter_3x = EmitConv3xIncludes()
|
| 404 |
+
|
| 405 |
+
def operation_is_3x(self, operation) -> bool:
|
| 406 |
+
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 407 |
+
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 408 |
+
|
| 409 |
+
def emit(self, operation) -> str:
|
| 410 |
+
if self.operation_is_3x(operation):
|
| 411 |
+
return self.emitter_3x.emit(operation)
|
| 412 |
+
|
| 413 |
+
return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
|
| 414 |
+
"\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
|
| 415 |
+
|
| 416 |
+
###################################################################################################
|
| 417 |
+
#
|
| 418 |
+
# Emitters functions for all targets
|
| 419 |
+
#
|
| 420 |
+
###################################################################################################
|
| 421 |
+
|
| 422 |
+
class EmitConv2dConfigurationLibrary:
|
| 423 |
+
def __init__(self, operation_path, configuration_name):
|
| 424 |
+
self.configuration_name = configuration_name
|
| 425 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
|
| 426 |
+
|
| 427 |
+
self.instance_emitter = EmitConv2dInstance()
|
| 428 |
+
self.includes_emitter = EmitConv2dIncludes()
|
| 429 |
+
|
| 430 |
+
self.header_template = """
|
| 431 |
+
/*
|
| 432 |
+
Generated by conv2d_operation.py - Do not edit.
|
| 433 |
+
*/
|
| 434 |
+
|
| 435 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 436 |
+
|
| 437 |
+
#include "cutlass/cutlass.h"
|
| 438 |
+
#include "cutlass/library/library.h"
|
| 439 |
+
#include "cutlass/library/manifest.h"
|
| 440 |
+
|
| 441 |
+
#include "library_internal.h"
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
self.instance_template = """
|
| 445 |
+
${stub_begin}
|
| 446 |
+
${operation_instance}
|
| 447 |
+
// Derived class
|
| 448 |
+
struct ${operation_name} :
|
| 449 |
+
public ${operation_name}_base { };
|
| 450 |
+
${stub_end}
|
| 451 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 452 |
+
|
| 453 |
+
"""
|
| 454 |
+
|
| 455 |
+
self.configuration_header = """
|
| 456 |
+
|
| 457 |
+
namespace cutlass {
|
| 458 |
+
namespace library {
|
| 459 |
+
|
| 460 |
+
// Initialize all instances
|
| 461 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 462 |
+
"""
|
| 463 |
+
|
| 464 |
+
self.configuration_instance = """${stub_begin}
|
| 465 |
+
using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
|
| 466 |
+
${operation_name}>;
|
| 467 |
+
|
| 468 |
+
manifest.append(new cutlass::library::${operation_wrapper}<
|
| 469 |
+
Operation_${operation_name}
|
| 470 |
+
>(
|
| 471 |
+
"${operation_name}"
|
| 472 |
+
));
|
| 473 |
+
${stub_end}
|
| 474 |
+
"""
|
| 475 |
+
|
| 476 |
+
self.configuration_epilogue = "}\n"
|
| 477 |
+
|
| 478 |
+
self.epilogue_template = """
|
| 479 |
+
|
| 480 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 481 |
+
|
| 482 |
+
} // namespace library
|
| 483 |
+
} // namespace cutlass
|
| 484 |
+
|
| 485 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 486 |
+
|
| 487 |
+
"""
|
| 488 |
+
|
| 489 |
+
def operation_is_3x(self, operation):
|
| 490 |
+
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 491 |
+
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 492 |
+
|
| 493 |
+
def __enter__(self):
|
| 494 |
+
"""
|
| 495 |
+
Open the configuration_file, and write the "header" C++ code to it.
|
| 496 |
+
|
| 497 |
+
The "header" consists of a comment (that this is generated code,
|
| 498 |
+
so it should not be edited), and includes that are common
|
| 499 |
+
to all kinds of kernels.
|
| 500 |
+
"""
|
| 501 |
+
_LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__')
|
| 502 |
+
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 503 |
+
str(self.configuration_path))
|
| 504 |
+
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 505 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 506 |
+
|
| 507 |
+
self.configuration_file.write(SubstituteTemplate(self.header_template, {
|
| 508 |
+
'configuration_name': self.configuration_name
|
| 509 |
+
}))
|
| 510 |
+
self.operations = []
|
| 511 |
+
return self
|
| 512 |
+
|
| 513 |
+
def emit(self, operation):
|
| 514 |
+
"""
|
| 515 |
+
Write three pieces of C++ code to the configuration_file
|
| 516 |
+
(that was opened by the __enter__ method above):
|
| 517 |
+
|
| 518 |
+
1. the header includes that are specific to the operation
|
| 519 |
+
(CUTLASS 2 vs. CUTLASS 3);
|
| 520 |
+
|
| 521 |
+
2. the "operation instance" (a "using" declaration ending in "_base"); and
|
| 522 |
+
|
| 523 |
+
3. the "operation name" (declaration and definition of a derived class
|
| 524 |
+
of the above operation instance).
|
| 525 |
+
|
| 526 |
+
The "using" declaration turns a C++ class name, possibly namespace-qualified,
|
| 527 |
+
possibly also with angle brackets, into a C-style, easily demangled identifier.
|
| 528 |
+
"""
|
| 529 |
+
_LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit')
|
| 530 |
+
_LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
|
| 531 |
+
self.operations.append(operation)
|
| 532 |
+
|
| 533 |
+
self.configuration_file.write(self.includes_emitter.emit(operation))
|
| 534 |
+
|
| 535 |
+
stub_begin = ''
|
| 536 |
+
stub_end = ''
|
| 537 |
+
# It can be useful to stub (comment) out instantiations for testing.
|
| 538 |
+
# In this case, one need only set is_stub to True.
|
| 539 |
+
is_stub = False
|
| 540 |
+
if is_stub:
|
| 541 |
+
stub_begin = "// STUB for now\n#if 0"
|
| 542 |
+
stub_end = '#endif // 0'
|
| 543 |
+
|
| 544 |
+
self.configuration_file.write(Template(self.instance_template).substitute({
|
| 545 |
+
'configuration_name': self.configuration_name,
|
| 546 |
+
'operation_name': operation.procedural_name(),
|
| 547 |
+
'operation_instance': self.instance_emitter.emit(operation),
|
| 548 |
+
'stub_begin': stub_begin,
|
| 549 |
+
'stub_end': stub_end
|
| 550 |
+
}))
|
| 551 |
+
|
| 552 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 553 |
+
"""
|
| 554 |
+
Write the rest of the C++ code to the configuration_file, and close the file.
|
| 555 |
+
|
| 556 |
+
The "rest of the C++ code" has the following components.
|
| 557 |
+
|
| 558 |
+
1. Configuration header: Open the namespace(s), and open the definition
|
| 559 |
+
of the "initialize_${configuration_name}" registration function
|
| 560 |
+
that registers the operation with the Manifest.
|
| 561 |
+
("Registration" helps turn C++ compile-time polymorphism
|
| 562 |
+
(via template parameters) into a run-time choice of parameters.)
|
| 563 |
+
|
| 564 |
+
2. Configuration instance: In the body of the registration function,
|
| 565 |
+
make a "using" declaration Operation_${operation_name} for the
|
| 566 |
+
operation type (which uses operation_name as its template argument).
|
| 567 |
+
Then, tell the manifest about the operation via a "manifest.append" call.
|
| 568 |
+
The argument of the call is a new instance of
|
| 569 |
+
"SomethingOperation<Operation_${operation_name}>"
|
| 570 |
+
(replace Something with a specific name).
|
| 571 |
+
|
| 572 |
+
3. Configuration epilogue: Close the definition of the registration function.
|
| 573 |
+
|
| 574 |
+
4. Epilogue template: Close the namespace(s).
|
| 575 |
+
"""
|
| 576 |
+
|
| 577 |
+
_LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__')
|
| 578 |
+
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 579 |
+
str(self.configuration_path))
|
| 580 |
+
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 581 |
+
|
| 582 |
+
self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
|
| 583 |
+
'configuration_name': self.configuration_name
|
| 584 |
+
}))
|
| 585 |
+
|
| 586 |
+
for operation in self.operations:
|
| 587 |
+
stub_begin = ''
|
| 588 |
+
stub_end = ''
|
| 589 |
+
# It can be useful to stub (comment) out instantiations for testing.
|
| 590 |
+
# In this case, one need only set is_stub to True.
|
| 591 |
+
is_stub = False
|
| 592 |
+
if is_stub:
|
| 593 |
+
stub_begin = "// STUB for now\n#if 0"
|
| 594 |
+
stub_end = "#endif // 0"
|
| 595 |
+
|
| 596 |
+
if operation.group_mode == GroupMode.Depthwise:
|
| 597 |
+
kernel_name = 'DirectConvolution'
|
| 598 |
+
operation_wrapper = 'DirectConv2dOperation'
|
| 599 |
+
else:
|
| 600 |
+
kernel_name = 'ImplicitGemmConvolution'
|
| 601 |
+
operation_wrapper = 'Conv2dOperation'
|
| 602 |
+
if self.operation_is_3x(operation):
|
| 603 |
+
kernel_name = 'ConvUniversalAdapter'
|
| 604 |
+
operation_wrapper = 'ConvOperation3x'
|
| 605 |
+
|
| 606 |
+
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
|
| 607 |
+
'configuration_name': self.configuration_name,
|
| 608 |
+
'operation_name': operation.procedural_name(),
|
| 609 |
+
'kernel_name': kernel_name,
|
| 610 |
+
'operation_wrapper': operation_wrapper,
|
| 611 |
+
'stub_begin': stub_begin,
|
| 612 |
+
'stub_end': stub_end
|
| 613 |
+
}))
|
| 614 |
+
|
| 615 |
+
self.configuration_file.write(self.configuration_epilogue)
|
| 616 |
+
self.configuration_file.write(self.epilogue_template)
|
| 617 |
+
self.configuration_file.close()
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
###################################################################################################
|
| 621 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting Conv3d kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import logging
|
| 39 |
+
import os.path
|
| 40 |
+
import shutil
|
| 41 |
+
from string import Template
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 49 |
+
except ImportError:
|
| 50 |
+
from library import *
|
| 51 |
+
from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 52 |
+
|
| 53 |
+
_LOGGER = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
###################################################################################################
|
| 56 |
+
|
| 57 |
+
#
|
| 58 |
+
class Conv3dOperation:
|
| 59 |
+
#
|
| 60 |
+
def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
|
| 61 |
+
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
|
| 62 |
+
|
| 63 |
+
self.operation_kind = OperationKind.Conv3d
|
| 64 |
+
self.arch = arch
|
| 65 |
+
self.tile_description = tile_description
|
| 66 |
+
self.conv_kind = conv_kind
|
| 67 |
+
self.A = A
|
| 68 |
+
self.B = B
|
| 69 |
+
self.C = C
|
| 70 |
+
self.element_epilogue = element_epilogue
|
| 71 |
+
self.epilogue_functor = epilogue_functor
|
| 72 |
+
self.iterator_algorithm = iterator_algorithm
|
| 73 |
+
self.stride_support = stride_support
|
| 74 |
+
self.swizzling_functor = swizzling_functor
|
| 75 |
+
|
| 76 |
+
#
|
| 77 |
+
def is_mixed_input(self):
|
| 78 |
+
return self.A.element != self.B.element
|
| 79 |
+
|
| 80 |
+
#
|
| 81 |
+
def core_name(self):
|
| 82 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 83 |
+
|
| 84 |
+
intermediate_type = ''
|
| 85 |
+
|
| 86 |
+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
|
| 87 |
+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 88 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 89 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 90 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 91 |
+
else:
|
| 92 |
+
inst_shape = ''
|
| 93 |
+
|
| 94 |
+
return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \
|
| 95 |
+
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
|
| 96 |
+
|
| 97 |
+
#
|
| 98 |
+
def extended_name(self):
|
| 99 |
+
''' Append data types if they differ from compute type. '''
|
| 100 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 101 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 102 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 103 |
+
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 104 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 105 |
+
extended_name = "${core_name}_${element_a}"
|
| 106 |
+
else:
|
| 107 |
+
extended_name = "${core_name}"
|
| 108 |
+
|
| 109 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 110 |
+
'element_a': DataTypeNames[self.A.element],
|
| 111 |
+
'element_c': DataTypeNames[self.C.element],
|
| 112 |
+
'core_name': self.core_name()
|
| 113 |
+
})
|
| 114 |
+
|
| 115 |
+
return extended_name
|
| 116 |
+
|
| 117 |
+
#
|
| 118 |
+
def configuration_name(self):
|
| 119 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 120 |
+
|
| 121 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 122 |
+
|
| 123 |
+
threadblock = "%dx%d_%dx%d" % (
|
| 124 |
+
self.tile_description.threadblock_shape[0],
|
| 125 |
+
self.tile_description.threadblock_shape[1],
|
| 126 |
+
self.tile_description.threadblock_shape[2],
|
| 127 |
+
self.tile_description.stages
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if self.stride_support == StrideSupport.Unity:
|
| 131 |
+
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride"
|
| 132 |
+
else:
|
| 133 |
+
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}"
|
| 134 |
+
|
| 135 |
+
return SubstituteTemplate(
|
| 136 |
+
configuration_name,
|
| 137 |
+
{
|
| 138 |
+
'opcode_class': opcode_class_name,
|
| 139 |
+
'extended_name': self.extended_name(),
|
| 140 |
+
'threadblock': threadblock,
|
| 141 |
+
}
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
#
|
| 145 |
+
def procedural_name(self):
|
| 146 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 147 |
+
return self.configuration_name()
|
| 148 |
+
|
| 149 |
+
###################################################################################################
|
| 150 |
+
#
|
| 151 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 152 |
+
#
|
| 153 |
+
###################################################################################################
|
| 154 |
+
|
| 155 |
+
class EmitConv3dInstance:
|
| 156 |
+
def __init__(self):
|
| 157 |
+
# Emitter for CUTLASS 3 convolution operations
|
| 158 |
+
self.conv3x_emitter = EmitConv3xInstance()
|
| 159 |
+
self.template = """
|
| 160 |
+
// Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 161 |
+
using ${operation_name}_base =
|
| 162 |
+
typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}<
|
| 163 |
+
${element_a},
|
| 164 |
+
cutlass::layout::TensorNDHWC,
|
| 165 |
+
${element_b},
|
| 166 |
+
cutlass::layout::TensorNDHWC,
|
| 167 |
+
${element_c},
|
| 168 |
+
cutlass::layout::TensorNDHWC,
|
| 169 |
+
${element_accumulator},
|
| 170 |
+
${opcode_class},
|
| 171 |
+
${arch},
|
| 172 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 173 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
| 174 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 175 |
+
${epilogue_functor}<
|
| 176 |
+
${element_c},
|
| 177 |
+
${epilogue_vector_length},
|
| 178 |
+
${element_accumulator},
|
| 179 |
+
${element_epilogue}
|
| 180 |
+
>,
|
| 181 |
+
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
| 182 |
+
${stages},
|
| 183 |
+
cutlass::arch::OpMultiplyAdd,
|
| 184 |
+
${iterator_algorithm},
|
| 185 |
+
${stride_support}
|
| 186 |
+
>::Kernel;
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def emit(self, operation):
|
| 190 |
+
_LOGGER.debug("*** EmitConv3dInstance::emit")
|
| 191 |
+
_LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
|
| 192 |
+
|
| 193 |
+
if hasattr(operation, 'is_3x') and operation.is_3x:
|
| 194 |
+
_LOGGER.debug("*** CUTLASS 3 operation")
|
| 195 |
+
return self.conv3x_emitter.emit(operation)
|
| 196 |
+
|
| 197 |
+
_LOGGER.debug("*** CUTLASS 2 operation")
|
| 198 |
+
|
| 199 |
+
warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
|
| 200 |
+
|
| 201 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 202 |
+
|
| 203 |
+
values = {
|
| 204 |
+
'operation_name': operation.procedural_name(),
|
| 205 |
+
'conv_kind': ConvKindTag[operation.conv_kind],
|
| 206 |
+
'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
|
| 207 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 208 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 209 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 210 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 211 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 212 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 213 |
+
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
| 214 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 215 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 216 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 217 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 218 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 219 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 220 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 221 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 222 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 223 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 224 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 225 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 226 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 227 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 228 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 229 |
+
'stages': str(operation.tile_description.stages),
|
| 230 |
+
'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
|
| 231 |
+
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
| 232 |
+
'stride_support': StrideSupportTag[operation.stride_support]
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
return SubstituteTemplate(self.template, values)
|
| 236 |
+
|
| 237 |
+
###################################################################################################
|
| 238 |
+
#
|
| 239 |
+
# Generator functions for all layouts
|
| 240 |
+
#
|
| 241 |
+
###################################################################################################
|
| 242 |
+
|
| 243 |
+
#
|
| 244 |
+
def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
|
| 245 |
+
|
| 246 |
+
for tile in tile_descriptions:
|
| 247 |
+
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
| 248 |
+
|
| 249 |
+
if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
|
| 250 |
+
|
| 251 |
+
#
|
| 252 |
+
output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
|
| 253 |
+
if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
|
| 254 |
+
else [tile.math_instruction.element_accumulator,]
|
| 255 |
+
|
| 256 |
+
for output_type in output_types:
|
| 257 |
+
A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
|
| 258 |
+
B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
|
| 259 |
+
C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type])))
|
| 260 |
+
|
| 261 |
+
manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
|
| 262 |
+
|
| 263 |
+
class EmitConv3dIncludes:
|
| 264 |
+
'''Emit includes that are specific to the operation.'''
|
| 265 |
+
|
| 266 |
+
def __init__(self):
|
| 267 |
+
self.includes = ['conv3d_operation.h']
|
| 268 |
+
self.emitter_3x = EmitConv3xIncludes()
|
| 269 |
+
|
| 270 |
+
def operation_is_3x(self, operation) -> bool:
|
| 271 |
+
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 272 |
+
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 273 |
+
|
| 274 |
+
def emit(self, operation) -> str:
|
| 275 |
+
if self.operation_is_3x(operation):
|
| 276 |
+
return self.emitter_3x.emit(operation)
|
| 277 |
+
|
| 278 |
+
return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
|
| 279 |
+
"\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
|
| 280 |
+
|
| 281 |
+
###################################################################################################
|
| 282 |
+
#
|
| 283 |
+
# Emitters functions for all targets
|
| 284 |
+
#
|
| 285 |
+
###################################################################################################
|
| 286 |
+
|
| 287 |
+
class EmitConv3dConfigurationLibrary:
|
| 288 |
+
def __init__(self, operation_path, configuration_name):
|
| 289 |
+
self.configuration_name = configuration_name
|
| 290 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
|
| 291 |
+
|
| 292 |
+
self.instance_emitter = EmitConv3dInstance()
|
| 293 |
+
self.includes_emitter = EmitConv3dIncludes()
|
| 294 |
+
|
| 295 |
+
self.header_template = """
|
| 296 |
+
/*
|
| 297 |
+
Generated by conv3d_operation.py - Do not edit.
|
| 298 |
+
*/
|
| 299 |
+
|
| 300 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 301 |
+
|
| 302 |
+
#include "cutlass/cutlass.h"
|
| 303 |
+
#include "cutlass/library/library.h"
|
| 304 |
+
#include "cutlass/library/manifest.h"
|
| 305 |
+
|
| 306 |
+
#include "library_internal.h"
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
self.instance_template = """
|
| 310 |
+
${stub_begin}
|
| 311 |
+
${operation_instance}
|
| 312 |
+
// Derived class
|
| 313 |
+
struct ${operation_name} :
|
| 314 |
+
public ${operation_name}_base { };
|
| 315 |
+
${stub_end}
|
| 316 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 317 |
+
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
self.configuration_header = """
|
| 321 |
+
|
| 322 |
+
namespace cutlass {
|
| 323 |
+
namespace library {
|
| 324 |
+
|
| 325 |
+
// Initialize all instances
|
| 326 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
self.configuration_instance = """${stub_begin}
|
| 330 |
+
using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
|
| 331 |
+
${operation_name}>;
|
| 332 |
+
|
| 333 |
+
manifest.append(new cutlass::library::${operation_wrapper}<
|
| 334 |
+
Operation_${operation_name}
|
| 335 |
+
>(
|
| 336 |
+
"${operation_name}"
|
| 337 |
+
));
|
| 338 |
+
${stub_end}
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
self.configuration_epilogue = "}\n"
|
| 342 |
+
|
| 343 |
+
self.epilogue_template = """
|
| 344 |
+
|
| 345 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 346 |
+
|
| 347 |
+
} // namespace library
|
| 348 |
+
} // namespace cutlass
|
| 349 |
+
|
| 350 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 351 |
+
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
def operation_is_3x(self, operation):
|
| 355 |
+
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 356 |
+
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 357 |
+
|
| 358 |
+
def __enter__(self):
|
| 359 |
+
"""
|
| 360 |
+
Open the configuration_file, and write the "header" C++ code to it.
|
| 361 |
+
|
| 362 |
+
The "header" consists of a comment (that this is generated code,
|
| 363 |
+
so it should not be edited), and includes that are common
|
| 364 |
+
to both the CUTLASS 2 and the CUTLASS 3 cases.
|
| 365 |
+
"""
|
| 366 |
+
_LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__')
|
| 367 |
+
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 368 |
+
str(self.configuration_path))
|
| 369 |
+
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 370 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 371 |
+
|
| 372 |
+
self.configuration_file.write(SubstituteTemplate(self.header_template, {
|
| 373 |
+
'configuration_name': self.configuration_name
|
| 374 |
+
}))
|
| 375 |
+
self.operations = []
|
| 376 |
+
return self
|
| 377 |
+
|
| 378 |
+
def emit(self, operation):
|
| 379 |
+
"""
|
| 380 |
+
Write three pieces of C++ code to the configuration_file
|
| 381 |
+
(that was opened by the __enter__ method above):
|
| 382 |
+
|
| 383 |
+
1. the header includes that are specific to the operation
|
| 384 |
+
(CUTLASS 2 vs. CUTLASS 3);
|
| 385 |
+
|
| 386 |
+
2. the "operation instance" (a "using" declaration ending in "_base"); and
|
| 387 |
+
|
| 388 |
+
3. the "operation name" (declaration and definition of a derived class
|
| 389 |
+
of the above operation instance).
|
| 390 |
+
|
| 391 |
+
The "using" declaration turns a C++ class name, possibly namespace-qualified,
|
| 392 |
+
possibly also with angle brackets, into a C-style, easily demangled identifier.
|
| 393 |
+
"""
|
| 394 |
+
_LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit')
|
| 395 |
+
_LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
|
| 396 |
+
self.operations.append(operation)
|
| 397 |
+
|
| 398 |
+
self.configuration_file.write(self.includes_emitter.emit(operation))
|
| 399 |
+
|
| 400 |
+
stub_begin = ''
|
| 401 |
+
stub_end = ''
|
| 402 |
+
# It can be useful to stub (comment) out instantiations for testing.
|
| 403 |
+
# In this case, one need only set is_stub to True.
|
| 404 |
+
is_stub = False
|
| 405 |
+
if is_stub:
|
| 406 |
+
stub_begin = "// STUB for now\n#if 0"
|
| 407 |
+
stub_end = '#endif // 0'
|
| 408 |
+
|
| 409 |
+
self.configuration_file.write(Template(self.instance_template).substitute({
|
| 410 |
+
'configuration_name': self.configuration_name,
|
| 411 |
+
'operation_name': operation.procedural_name(),
|
| 412 |
+
'operation_instance': self.instance_emitter.emit(operation),
|
| 413 |
+
'stub_begin': stub_begin,
|
| 414 |
+
'stub_end': stub_end
|
| 415 |
+
}))
|
| 416 |
+
|
| 417 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 418 |
+
"""
|
| 419 |
+
Write the rest of the C++ code to the configuration_file, and close the file.
|
| 420 |
+
|
| 421 |
+
The "rest of the C++ code" has the following components.
|
| 422 |
+
|
| 423 |
+
1. Configuration header: Open the namespace(s), and open the definition
|
| 424 |
+
of the "initialize_${configuration_name}" registration function
|
| 425 |
+
that registers the operation with the Manifest.
|
| 426 |
+
("Registration" helps turn C++ compile-time polymorphism
|
| 427 |
+
(via template parameters) into a run-time choice of parameters.)
|
| 428 |
+
|
| 429 |
+
2. Configuration instance: In the body of the registration function,
|
| 430 |
+
make a "using" declaration Operation_${operation_name} for the
|
| 431 |
+
operation type (which uses operation_name as its template argument).
|
| 432 |
+
Then, tell the manifest about the operation via a "manifest.append" call.
|
| 433 |
+
The argument of the call is a new instance of
|
| 434 |
+
"SomethingOperation<Operation_${operation_name}>"
|
| 435 |
+
(replace Something with a specific name).
|
| 436 |
+
|
| 437 |
+
3. Configuration epilogue: Close the definition of the registration function.
|
| 438 |
+
|
| 439 |
+
4. Epilogue template: Close the namespace(s).
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
_LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__')
|
| 443 |
+
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 444 |
+
str(self.configuration_path))
|
| 445 |
+
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 446 |
+
|
| 447 |
+
self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
|
| 448 |
+
'configuration_name': self.configuration_name
|
| 449 |
+
}))
|
| 450 |
+
|
| 451 |
+
for operation in self.operations:
|
| 452 |
+
stub_begin = ''
|
| 453 |
+
stub_end = ''
|
| 454 |
+
# It can be useful to stub (comment) out instantiations for testing.
|
| 455 |
+
# In this case, one need only set is_stub to True.
|
| 456 |
+
is_stub = False
|
| 457 |
+
if is_stub:
|
| 458 |
+
stub_begin = "// STUB for now\n#if 0"
|
| 459 |
+
stub_end = "#endif // 0"
|
| 460 |
+
|
| 461 |
+
kernel_name = 'ImplicitGemmConvolution'
|
| 462 |
+
operation_wrapper = 'Conv3dOperation'
|
| 463 |
+
if self.operation_is_3x(operation):
|
| 464 |
+
kernel_name = 'ConvUniversalAdapter'
|
| 465 |
+
operation_wrapper = 'ConvOperation3x'
|
| 466 |
+
|
| 467 |
+
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
|
| 468 |
+
'configuration_name': self.configuration_name,
|
| 469 |
+
'operation_name': operation.procedural_name(),
|
| 470 |
+
'kernel_name': kernel_name,
|
| 471 |
+
'operation_wrapper': operation_wrapper,
|
| 472 |
+
'stub_begin': stub_begin,
|
| 473 |
+
'stub_end': stub_end
|
| 474 |
+
}))
|
| 475 |
+
|
| 476 |
+
self.configuration_file.write(self.configuration_epilogue)
|
| 477 |
+
self.configuration_file.write(self.epilogue_template)
|
| 478 |
+
self.configuration_file.close()
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
###################################################################################################
|
| 482 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting CUTLASS >= 3 convolution kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import os.path
|
| 39 |
+
import shutil
|
| 40 |
+
import logging
|
| 41 |
+
from string import Template
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
except ImportError:
|
| 49 |
+
from library import *
|
| 50 |
+
|
| 51 |
+
_LOGGER = logging.getLogger(__name__)
|
| 52 |
+
|
| 53 |
+
###################################################################################################
|
| 54 |
+
#
|
| 55 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 56 |
+
#
|
| 57 |
+
###################################################################################################
|
| 58 |
+
|
| 59 |
+
class EmitConv3xInstance:
|
| 60 |
+
def __init__(self):
|
| 61 |
+
_LOGGER.debug("*** EmitConv3xInstance::__init__")
|
| 62 |
+
|
| 63 |
+
# Define epilogue type first, so that the mainloop type
|
| 64 |
+
# can use it with StageCountAutoCarveout.
|
| 65 |
+
self.template = """
|
| 66 |
+
|
| 67 |
+
// CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}"
|
| 68 |
+
using ${operation_name}_epilogue =
|
| 69 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 70 |
+
${arch},
|
| 71 |
+
${opcode_class_epi},
|
| 72 |
+
${mma_tile_shape}, // mma tile shape
|
| 73 |
+
${cluster_shape}, // cluster shape
|
| 74 |
+
${epi_tile_mn},
|
| 75 |
+
${element_accumulator},
|
| 76 |
+
${element_compute},
|
| 77 |
+
${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>,
|
| 78 |
+
${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>,
|
| 79 |
+
${epilogue_schedule}
|
| 80 |
+
// , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute>
|
| 81 |
+
>::CollectiveOp;
|
| 82 |
+
|
| 83 |
+
using ${operation_name}_mainloop =
|
| 84 |
+
typename cutlass::conv::collective::CollectiveBuilder<
|
| 85 |
+
${arch},
|
| 86 |
+
${opcode_class_main},
|
| 87 |
+
${conv_kind}, // kFprop, kDgrad, or kWgrad
|
| 88 |
+
${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>,
|
| 89 |
+
${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>,
|
| 90 |
+
${element_accumulator},
|
| 91 |
+
${mma_tile_shape}, // mma tile shape
|
| 92 |
+
${cluster_shape}, // cluster shape
|
| 93 |
+
${stages},
|
| 94 |
+
${kernel_schedule}
|
| 95 |
+
>::CollectiveOp;
|
| 96 |
+
|
| 97 |
+
using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>;
|
| 98 |
+
|
| 99 |
+
// Unit tests call this "ConvKernel".
|
| 100 |
+
// Conv operator ${operation_name}
|
| 101 |
+
using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
| 102 |
+
${operation_name}_problem_shape,
|
| 103 |
+
${operation_name}_mainloop,
|
| 104 |
+
${operation_name}_epilogue,
|
| 105 |
+
${tile_scheduler}
|
| 106 |
+
>;
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
def arch_number_to_type(self, arch: int) -> str:
|
| 110 |
+
return f"cutlass::arch::Sm{arch}"
|
| 111 |
+
|
| 112 |
+
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
| 113 |
+
mma_m = cta_m
|
| 114 |
+
mma_n = cta_n
|
| 115 |
+
mma_k = cta_k
|
| 116 |
+
|
| 117 |
+
if operation.arch >= 100:
|
| 118 |
+
# MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where
|
| 119 |
+
# mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version.
|
| 120 |
+
# If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated,
|
| 121 |
+
# otherwise 1sm kernel is allocated.
|
| 122 |
+
cta_m_per_mma_instruction = 1
|
| 123 |
+
if "2sm" in operation.procedural_name() :
|
| 124 |
+
cta_m_per_mma_instruction = 2
|
| 125 |
+
elif "1sm" in operation.procedural_name() :
|
| 126 |
+
cta_m_per_mma_instruction = 1
|
| 127 |
+
elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 :
|
| 128 |
+
cta_m_per_mma_instruction = 2
|
| 129 |
+
mma_m = cta_m * cta_m_per_mma_instruction
|
| 130 |
+
|
| 131 |
+
# For all three kinds of convolutions, the tile shape's K mode
|
| 132 |
+
# differs from GEMM in that needs to be wrapped in a Shape.
|
| 133 |
+
# For Wgrad convolutions specifically,
|
| 134 |
+
# the N tile shape also needs to be wrapped in a Shape.
|
| 135 |
+
m_template = 'cute::_${mma_m}'
|
| 136 |
+
if operation.conv_kind == ConvKind.Wgrad:
|
| 137 |
+
n_template = 'cute::Shape<cute::_${mma_n}>'
|
| 138 |
+
else:
|
| 139 |
+
n_template = 'cute::_${mma_n}'
|
| 140 |
+
k_template = 'cute::Shape<cute::_${mma_k}>'
|
| 141 |
+
|
| 142 |
+
mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
| 143 |
+
values = {
|
| 144 |
+
'mma_m': mma_m,
|
| 145 |
+
'mma_n': mma_n,
|
| 146 |
+
'mma_k': mma_k
|
| 147 |
+
}
|
| 148 |
+
return Template(mma_tile_shape_template).substitute(values)
|
| 149 |
+
|
| 150 |
+
def cluster_shape(self, operation) -> str:
|
| 151 |
+
m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)'
|
| 152 |
+
n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)'
|
| 153 |
+
k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)'
|
| 154 |
+
cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
| 155 |
+
values = {
|
| 156 |
+
'cluster_shape_m': operation.tile_description.cluster_shape[0],
|
| 157 |
+
'cluster_shape_n': operation.tile_description.cluster_shape[1],
|
| 158 |
+
'cluster_shape_k': operation.tile_description.cluster_shape[2],
|
| 159 |
+
}
|
| 160 |
+
return Template(cluster_shape_template).substitute(values)
|
| 161 |
+
|
| 162 |
+
def stage_count(self, operation) -> str:
|
| 163 |
+
# stages == 0 tells builder to pick the number of stages automatically
|
| 164 |
+
namespace_prefix = 'cutlass::conv::collective::'
|
| 165 |
+
if operation.tile_description.stages > 0:
|
| 166 |
+
return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>"
|
| 167 |
+
else:
|
| 168 |
+
return f"{namespace_prefix}StageCountAutoCarveout<sizeof(typename {operation.procedural_name()}_epilogue::SharedStorage)>"
|
| 169 |
+
|
| 170 |
+
def emit(self, operation) -> str:
|
| 171 |
+
_LOGGER.debug("*** EmitConv3xInstance::emit")
|
| 172 |
+
_LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
|
| 173 |
+
|
| 174 |
+
# Identify the operation as CUTLASS 3 by its is_3x field
|
| 175 |
+
if (not hasattr(operation, 'is_3x')) or (not operation.is_3x):
|
| 176 |
+
raise RuntimeError("operation must be a CUTLASS 3 operation")
|
| 177 |
+
|
| 178 |
+
epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
|
| 179 |
+
opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
|
| 180 |
+
opcode_class_epi = opcode_class_main
|
| 181 |
+
|
| 182 |
+
tile_shape = operation.tile_description.tile_shape
|
| 183 |
+
cluster_m = operation.tile_description.cluster_shape[0]
|
| 184 |
+
cluster_n = operation.tile_description.cluster_shape[1]
|
| 185 |
+
|
| 186 |
+
cta_m, cta_n, cta_k = tile_shape
|
| 187 |
+
# account for static/dynamic cluster shapes
|
| 188 |
+
if operation.arch >= 100:
|
| 189 |
+
cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m
|
| 190 |
+
cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n
|
| 191 |
+
|
| 192 |
+
warp_count = operation.tile_description.warp_count
|
| 193 |
+
epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule]
|
| 194 |
+
|
| 195 |
+
# KernelScheduleTag and TileSchedulerTag both hard-code the
|
| 196 |
+
# namespace qualification of KernelScheduleAuto as
|
| 197 |
+
# "cutlass::gemm::collective::" (unless the tag is 'void').
|
| 198 |
+
#
|
| 199 |
+
# For TileSchedulerTag, this namespace is fine, since CUTLASS 3
|
| 200 |
+
# convolutions use the same tile schedulers (from the same
|
| 201 |
+
# cutlass::gemm::collective namespace) as GEMMs.
|
| 202 |
+
kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::')
|
| 203 |
+
tile_scheduler = TileSchedulerTag[operation.tile_scheduler]
|
| 204 |
+
opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
|
| 205 |
+
|
| 206 |
+
values = {
|
| 207 |
+
'operation_name': operation.procedural_name(),
|
| 208 |
+
'conv_kind': ConvKindTag[operation.conv_kind],
|
| 209 |
+
'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
|
| 210 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 211 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 212 |
+
'align_a': int(operation.A.alignment),
|
| 213 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 214 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 215 |
+
'align_b': int(operation.B.alignment),
|
| 216 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 217 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 218 |
+
'align_c': int(operation.C.alignment),
|
| 219 |
+
'element_d': DataTypeTag[operation.D.element],
|
| 220 |
+
'layout_d': LayoutTag[operation.D.layout],
|
| 221 |
+
'align_d': int(operation.D.alignment),
|
| 222 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 223 |
+
'opcode_class': opcode_class,
|
| 224 |
+
'arch': self.arch_number_to_type(operation.arch),
|
| 225 |
+
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
|
| 226 |
+
'cluster_shape': self.cluster_shape(operation),
|
| 227 |
+
'opcode_class_epi': opcode_class_epi,
|
| 228 |
+
'opcode_class_main': opcode_class_main,
|
| 229 |
+
'epi_tile_mn': epi_tile_mn,
|
| 230 |
+
'stages': self.stage_count(operation),
|
| 231 |
+
'kernel_schedule': kernel_schedule,
|
| 232 |
+
'epilogue_schedule': epilogue_schedule,
|
| 233 |
+
'tile_scheduler': tile_scheduler,
|
| 234 |
+
'element_compute': DataTypeTag[operation.element_compute]
|
| 235 |
+
}
|
| 236 |
+
return Template(self.template).substitute(values)
|
| 237 |
+
|
| 238 |
+
class EmitConv3xIncludes:
|
| 239 |
+
def __init__(self):
|
| 240 |
+
_LOGGER.debug("*** EmitConv3xIncludes::__init__")
|
| 241 |
+
self.includes = ['conv_operation_3x.hpp',
|
| 242 |
+
'cutlass/conv/device/conv_universal_adapter.hpp',
|
| 243 |
+
'cutlass/conv/kernel/conv_universal.hpp',
|
| 244 |
+
'cutlass/conv/collective/collective_builder.hpp',
|
| 245 |
+
'cutlass/epilogue/collective/collective_builder.hpp']
|
| 246 |
+
|
| 247 |
+
def emit(self, operation) -> str:
|
| 248 |
+
_LOGGER.debug("*** EmitConv3xIncludes::emit")
|
| 249 |
+
return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
|
| 250 |
+
"\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py
ADDED
|
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
#
|
| 34 |
+
#
|
| 35 |
+
# \brief Generates the CUTLASS kernel listing with kernel filtering
|
| 36 |
+
#
|
| 37 |
+
|
| 38 |
+
#
|
| 39 |
+
|
| 40 |
+
###############################################################################
|
| 41 |
+
# Example usage:
|
| 42 |
+
# generator.py --operations all --generator-target kernel_listing \
|
| 43 |
+
# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports
|
| 44 |
+
###############################################################################
|
| 45 |
+
|
| 46 |
+
import collections
|
| 47 |
+
import csv
|
| 48 |
+
import json
|
| 49 |
+
import math
|
| 50 |
+
import os
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
import builtins
|
| 54 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 55 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 56 |
+
from cutlass_library.library import *
|
| 57 |
+
except ImportError:
|
| 58 |
+
from library import *
|
| 59 |
+
|
| 60 |
+
audit_csv_fields = [
|
| 61 |
+
"KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD",
|
| 62 |
+
"Layout_A", "Layout_B", "Layout_C", "Layout_D",
|
| 63 |
+
"Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D",
|
| 64 |
+
"1SM/2SM",
|
| 65 |
+
"StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types",
|
| 66 |
+
"Test Counts"
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
audit_csv_runtime_fields = [
|
| 70 |
+
"KerneIndex", "KernelName",
|
| 71 |
+
"Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K",
|
| 72 |
+
"Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K",
|
| 73 |
+
"M", "N", "K", "L", "Alpha_val", "Beta_val",
|
| 74 |
+
"Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled"
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def hash_cutlass_string(input_string):
|
| 78 |
+
mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
|
| 79 |
+
|
| 80 |
+
# Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
|
| 81 |
+
output = re.sub(mma_cluster_shape_pattern, "", input_string)
|
| 82 |
+
|
| 83 |
+
return output
|
| 84 |
+
|
| 85 |
+
def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
|
| 86 |
+
# Define a dictionary mapping the detected types to runtime values
|
| 87 |
+
datatype_map = {
|
| 88 |
+
'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 89 |
+
'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 90 |
+
'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 91 |
+
'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 92 |
+
'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 93 |
+
'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 94 |
+
'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 95 |
+
'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 96 |
+
'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
|
| 97 |
+
'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 98 |
+
'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
|
| 99 |
+
'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 100 |
+
'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 101 |
+
'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 102 |
+
'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 103 |
+
'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 104 |
+
'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 105 |
+
'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
# Regular expression to detect all the keys in datatype_map
|
| 109 |
+
pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')
|
| 110 |
+
|
| 111 |
+
# Replace detected patterns using the dictionary
|
| 112 |
+
updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)
|
| 113 |
+
|
| 114 |
+
return updated_kernel_name
|
| 115 |
+
|
| 116 |
+
# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
|
| 117 |
+
def get_kernel_features(operation, kernel_name,
|
| 118 |
+
dynamic_datatype, runtime_input_datatype):
|
| 119 |
+
numcta_inst = "2sm" if "2sm" in kernel_name else "1sm"
|
| 120 |
+
math_inst = operation.tile_description.math_instruction
|
| 121 |
+
|
| 122 |
+
if dynamic_datatype:
|
| 123 |
+
dtype_name_A = runtime_input_datatype[0]
|
| 124 |
+
dtype_name_B = runtime_input_datatype[1]
|
| 125 |
+
else:
|
| 126 |
+
dtype_name_A = DataTypeNames[operation.A.element]
|
| 127 |
+
dtype_name_B = DataTypeNames[operation.B.element]
|
| 128 |
+
|
| 129 |
+
layout_name_A = ShortLayoutTypeNames[operation.A.layout]
|
| 130 |
+
layout_name_B = ShortLayoutTypeNames[operation.B.layout]
|
| 131 |
+
layout_name_C = ShortLayoutTypeNames[operation.C.layout]
|
| 132 |
+
layout_name_D = ShortLayoutTypeNames[operation.D.layout]
|
| 133 |
+
|
| 134 |
+
scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void
|
| 135 |
+
scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void)
|
| 136 |
+
audit_vals = [
|
| 137 |
+
"BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM",
|
| 138 |
+
kernel_name,
|
| 139 |
+
dtype_name_A,
|
| 140 |
+
dtype_name_B,
|
| 141 |
+
DataTypeNames[operation.C.element],
|
| 142 |
+
DataTypeNames[operation.tile_description.math_instruction.element_accumulator],
|
| 143 |
+
DataTypeNames[operation.element_epilogue],
|
| 144 |
+
DataTypeNames[operation.D.element],
|
| 145 |
+
DataTypeNames[scale_factor_D_type],
|
| 146 |
+
DataTypeNames[scale_factor_A_type],
|
| 147 |
+
layout_name_A,
|
| 148 |
+
layout_name_B,
|
| 149 |
+
layout_name_C,
|
| 150 |
+
layout_name_D,
|
| 151 |
+
str(operation.A.alignment),
|
| 152 |
+
str(operation.B.alignment),
|
| 153 |
+
str(operation.C.alignment),
|
| 154 |
+
str(operation.D.alignment),
|
| 155 |
+
numcta_inst,
|
| 156 |
+
"Y" if 'stream_k' in kernel_name else "N",
|
| 157 |
+
]
|
| 158 |
+
return audit_vals
|
| 159 |
+
|
| 160 |
+
# This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta.
|
| 161 |
+
def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster):
|
| 162 |
+
math_inst = operation.tile_description.math_instruction
|
| 163 |
+
audit_vals = [
|
| 164 |
+
str(math_inst.instruction_shape[0]),
|
| 165 |
+
str(math_inst.instruction_shape[1]),
|
| 166 |
+
str(math_inst.instruction_shape[2]),
|
| 167 |
+
str(operation.tile_description.threadblock_shape[0]),
|
| 168 |
+
str(operation.tile_description.threadblock_shape[1]),
|
| 169 |
+
str(operation.tile_description.threadblock_shape[2]),
|
| 170 |
+
str(operation.tile_description.cluster_shape[0]),
|
| 171 |
+
str(operation.tile_description.cluster_shape[1]),
|
| 172 |
+
str(operation.tile_description.cluster_shape[2]),
|
| 173 |
+
str(cluster_shape[0]),
|
| 174 |
+
str(cluster_shape[1]),
|
| 175 |
+
str(cluster_shape[2]),
|
| 176 |
+
str(fallback_cluster_shape[0]),
|
| 177 |
+
str(fallback_cluster_shape[1]),
|
| 178 |
+
str(fallback_cluster_shape[2]),
|
| 179 |
+
str(problem_shape[0]),
|
| 180 |
+
str(problem_shape[1]),
|
| 181 |
+
str(problem_shape[2]),
|
| 182 |
+
str(problem_shape[3]),
|
| 183 |
+
str(alpha),
|
| 184 |
+
str(beta),
|
| 185 |
+
"Y" if dynamic_datatype else "N",
|
| 186 |
+
"Y" if dynamic_cluster else "N",
|
| 187 |
+
]
|
| 188 |
+
return audit_vals
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _getSubOperationType(kernel):
|
| 192 |
+
|
| 193 |
+
if kernel.operation_kind == OperationKind.Gemm:
|
| 194 |
+
return GemmKindNames[kernel.gemm_kind]
|
| 195 |
+
elif kernel.operation_kind == OperationKind.Conv2d:
|
| 196 |
+
return "conv_" + ConvKindNames[kernel.conv_kind]
|
| 197 |
+
elif kernel.operation_kind == OperationKind.Syrk:
|
| 198 |
+
return "syrk_" + SyrkKindNames[kernel.syrk_kind]
|
| 199 |
+
elif kernel.operation_kind == OperationKind.Trmm:
|
| 200 |
+
return "trmm_" + TrmmKindNames[kernel.trmm_kind]
|
| 201 |
+
elif kernel.operation_kind == OperationKind.Symm:
|
| 202 |
+
return "symm_" + SymmKindNames[kernel.symm_kind]
|
| 203 |
+
else:
|
| 204 |
+
raise Exception("Unsupported kernel type")
|
| 205 |
+
|
| 206 |
+
def _get_inst_shape(math_instruction):
|
| 207 |
+
return "".join(str(x) for x in math_instruction.instruction_shape)
|
| 208 |
+
|
| 209 |
+
def _is_simt_inst(math_instruction):
|
| 210 |
+
return _get_inst_shape(math_instruction) in ["111","114"]
|
| 211 |
+
|
| 212 |
+
def _getInstType(input_precision, accumulate_precision, math_instruction):
|
| 213 |
+
|
| 214 |
+
# inst_shape
|
| 215 |
+
inst_shape = _get_inst_shape(math_instruction)
|
| 216 |
+
|
| 217 |
+
# input precision
|
| 218 |
+
if input_precision == "fp32" and inst_shape != "111":
|
| 219 |
+
inp = "tf32"
|
| 220 |
+
else:
|
| 221 |
+
inp = input_precision
|
| 222 |
+
|
| 223 |
+
# Handle SIMT op types first
|
| 224 |
+
if _is_simt_inst(math_instruction):
|
| 225 |
+
|
| 226 |
+
simt_input_precision_to_inst = {
|
| 227 |
+
"fp32": "FFMA",
|
| 228 |
+
"fp64": "DFMA",
|
| 229 |
+
"fp16": "HFMA",
|
| 230 |
+
"int8": "IDP4A",
|
| 231 |
+
}
|
| 232 |
+
inst = simt_input_precision_to_inst[input_precision]
|
| 233 |
+
|
| 234 |
+
else: # Tensor op instructions
|
| 235 |
+
|
| 236 |
+
if accumulate_precision == "cf64":
|
| 237 |
+
fp64_acc_map = {
|
| 238 |
+
MathOperation.multiply_add_complex_gaussian : "gz",
|
| 239 |
+
MathOperation.multiply_add_complex : "z",
|
| 240 |
+
}
|
| 241 |
+
acc = fp64_acc_map[math_instruction.math_operation]
|
| 242 |
+
else:
|
| 243 |
+
tensor_op_acc_map = {
|
| 244 |
+
"fp32" : "s",
|
| 245 |
+
"cf32" : "s",
|
| 246 |
+
"fp16" : "h",
|
| 247 |
+
"int32": "i",
|
| 248 |
+
"fp64" : "d",
|
| 249 |
+
}
|
| 250 |
+
acc = tensor_op_acc_map[accumulate_precision]
|
| 251 |
+
|
| 252 |
+
inst = "{}{}{}".format(acc, inst_shape, inp)
|
| 253 |
+
|
| 254 |
+
return inst
|
| 255 |
+
# TODO: Computes FLOps/Bytes for GEMM - revisit for conv
|
| 256 |
+
def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1):
|
| 257 |
+
assert not (batch_count > 1 and num_groups > 1)
|
| 258 |
+
|
| 259 |
+
# TODO: adjust for sparsity
|
| 260 |
+
gmem_bytes = (
|
| 261 |
+
(DataTypeSize[operation.A.element] * m // 8) * k +
|
| 262 |
+
(DataTypeSize[operation.B.element] * n // 8) * k +
|
| 263 |
+
(DataTypeSize[operation.C.element] * m // 8) * n
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# TODO: complex-valued support
|
| 267 |
+
flops = 2 * (m * n * k)
|
| 268 |
+
|
| 269 |
+
if bool(beta):
|
| 270 |
+
gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n
|
| 271 |
+
flops += 2 * m * n
|
| 272 |
+
|
| 273 |
+
multiplier = max(batch_count, num_groups)
|
| 274 |
+
gmem_bytes *= multiplier
|
| 275 |
+
flops *= multiplier
|
| 276 |
+
|
| 277 |
+
return flops / gmem_bytes
|
| 278 |
+
|
| 279 |
+
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
| 280 |
+
):
|
| 281 |
+
# For functional testing, we prefer to run reference computing on device if any
|
| 282 |
+
reference_device_archs = ["100a", "103a"]
|
| 283 |
+
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
|
| 284 |
+
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
|
| 285 |
+
|
| 286 |
+
# beta values for L0 and L1
|
| 287 |
+
# TODO: randomize beta values for wider coverage
|
| 288 |
+
beta_values = [0.5]
|
| 289 |
+
|
| 290 |
+
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
|
| 291 |
+
|
| 292 |
+
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
| 293 |
+
|
| 294 |
+
if (mode == "functional_L0") and is_supported_arch:
|
| 295 |
+
problem_waves = [0.5, 1.25, 2.5]
|
| 296 |
+
|
| 297 |
+
#
|
| 298 |
+
# Dense Gemm
|
| 299 |
+
#
|
| 300 |
+
|
| 301 |
+
sm100_mma_data_type_general = [
|
| 302 |
+
'gemm_f16_f16_f16_f16_f16',
|
| 303 |
+
'gemm_f16_f16_f16_void_f16',
|
| 304 |
+
#'gemm_f16_f16_f32_f16_f16',
|
| 305 |
+
'tf32gemm_f32_f32_f32_f32_f32',
|
| 306 |
+
'bf16gemm_f32_f32_f32_f32_f32',
|
| 307 |
+
]
|
| 308 |
+
|
| 309 |
+
exclude_archs = arch not in ("103a")
|
| 310 |
+
if exclude_archs:
|
| 311 |
+
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
|
| 312 |
+
|
| 313 |
+
sm100_mma_data_type_runtime_dtype = [
|
| 314 |
+
'gemm.*f4_f4_f32_f32_f32',
|
| 315 |
+
'gemm.*f6_f6_f32_f32_f32',
|
| 316 |
+
'gemm.*f8_f8_f32_f32_f32',
|
| 317 |
+
]
|
| 318 |
+
|
| 319 |
+
sm100_mma_cluster_size = [
|
| 320 |
+
'8x1x1',
|
| 321 |
+
'4x4x1', '2x1x1',
|
| 322 |
+
'0x0x1' # dynamic cluster
|
| 323 |
+
]
|
| 324 |
+
|
| 325 |
+
# Restrict to two layouts to reduce L0 build and test time.
|
| 326 |
+
sm100_mma_layouts = [
|
| 327 |
+
'tnt',
|
| 328 |
+
'ntn'
|
| 329 |
+
]
|
| 330 |
+
|
| 331 |
+
# regex list must be in kernel procedural name order
|
| 332 |
+
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
| 333 |
+
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
| 334 |
+
|
| 335 |
+
sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
| 336 |
+
sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
| 337 |
+
|
| 338 |
+
#
|
| 339 |
+
# Block Scale Gemm
|
| 340 |
+
#
|
| 341 |
+
|
| 342 |
+
block_scaled_data_type = [
|
| 343 |
+
# runtime datatypes
|
| 344 |
+
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
| 345 |
+
'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
|
| 346 |
+
'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
|
| 347 |
+
#'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
| 348 |
+
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
| 349 |
+
]
|
| 350 |
+
|
| 351 |
+
block_scaled_tile_k = ['x128_', 'x256_']
|
| 352 |
+
|
| 353 |
+
sm103_block_scaled_data_type = [
|
| 354 |
+
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
| 355 |
+
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
sm103_block_scaled_tile_k = ['x768_']
|
| 359 |
+
|
| 360 |
+
block_scaled_cluster_size = [
|
| 361 |
+
'4x4x1', '2x1x1',
|
| 362 |
+
'0x0x1' # dynamic cluster
|
| 363 |
+
]
|
| 364 |
+
|
| 365 |
+
block_scaled_layouts = ['tnt']
|
| 366 |
+
# regex list must be in kernel procedural name order
|
| 367 |
+
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
| 368 |
+
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
| 369 |
+
|
| 370 |
+
sm103_block_scaled_prefetch_policy = ['tmapf']
|
| 371 |
+
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*"
|
| 372 |
+
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*"
|
| 373 |
+
|
| 374 |
+
if arch in ["100a", "100f"]:
|
| 375 |
+
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
| 376 |
+
f"({sm100_mma_filter_regex_2sm})|" \
|
| 377 |
+
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
| 378 |
+
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
| 379 |
+
f"({block_scaled_filter_regex_1sm})|" \
|
| 380 |
+
f"({block_scaled_filter_regex_2sm})"
|
| 381 |
+
elif arch in ["101a", "101f", "110a", "110f"]:
|
| 382 |
+
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
| 383 |
+
f"({sm100_mma_filter_regex_2sm})|" \
|
| 384 |
+
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
| 385 |
+
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
| 386 |
+
f"({block_scaled_filter_regex_1sm})|" \
|
| 387 |
+
f"({block_scaled_filter_regex_2sm})"
|
| 388 |
+
elif arch in ["103a"]:
|
| 389 |
+
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
| 390 |
+
f"({sm100_mma_filter_regex_2sm})|" \
|
| 391 |
+
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
| 392 |
+
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
| 393 |
+
f"({block_scaled_filter_regex_1sm})|" \
|
| 394 |
+
f"({block_scaled_filter_regex_2sm})|" \
|
| 395 |
+
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
| 396 |
+
f"({sm103_block_scaled_filter_regex_2sm})"
|
| 397 |
+
elif arch in ["120a", "120f", "121a", "121f"]:
|
| 398 |
+
|
| 399 |
+
# blockscaled sm120_mma kernels
|
| 400 |
+
blockscaled_sm120_mma_kernel_cta_tiles = [
|
| 401 |
+
[ '128x128' ]
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
# Restrict to two layouts to reduce L0 build and test time.
|
| 405 |
+
blockscaled_sm120_mma_layouts = [ 'tn' ]
|
| 406 |
+
filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
|
| 407 |
+
|
| 408 |
+
problem_waves = [0.5, 1.25, 2.5]
|
| 409 |
+
|
| 410 |
+
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
| 411 |
+
else:
|
| 412 |
+
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
|
| 413 |
+
raise Exception(error_message)
|
| 414 |
+
|
| 415 |
+
elif mode == "functional_L1":
|
| 416 |
+
sm100_mma_cluster_size = [
|
| 417 |
+
'0x0x1' # dynamic cluster
|
| 418 |
+
]
|
| 419 |
+
# Restrict to two layouts to reduce L1 build and test time.
|
| 420 |
+
sm100_mma_layouts = ['tnt', 'ntn']
|
| 421 |
+
sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
|
| 422 |
+
sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
|
| 423 |
+
block_scaled_data_type = [
|
| 424 |
+
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
| 425 |
+
'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
|
| 426 |
+
'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2',
|
| 427 |
+
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
| 428 |
+
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
| 429 |
+
]
|
| 430 |
+
|
| 431 |
+
sm103_block_scaled_data_type = [
|
| 432 |
+
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
| 433 |
+
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
| 434 |
+
]
|
| 435 |
+
|
| 436 |
+
block_scaled_cluster_size = ['0x0x1']
|
| 437 |
+
block_scaled_layouts = ['tnt']
|
| 438 |
+
|
| 439 |
+
# regex list must be in kernel procedural name order
|
| 440 |
+
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
| 441 |
+
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
| 442 |
+
|
| 443 |
+
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
| 444 |
+
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
| 445 |
+
|
| 446 |
+
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
| 447 |
+
f"({sm100_mma_filter_regex_2sm})|" \
|
| 448 |
+
f"({block_scaled_filter_regex_1sm})|" \
|
| 449 |
+
f"({block_scaled_filter_regex_2sm})" \
|
| 450 |
+
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
| 451 |
+
f"({sm103_block_scaled_filter_regex_2sm})"
|
| 452 |
+
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
|
| 453 |
+
sm120_mma_kernel_cta_tiles = [
|
| 454 |
+
# h1688, s1688, i16832, i8816
|
| 455 |
+
[ '256x128' ],
|
| 456 |
+
# d884, c1688,
|
| 457 |
+
[ '128x128' ],
|
| 458 |
+
# c1688, z884
|
| 459 |
+
[ '128x64' ],
|
| 460 |
+
# gz884
|
| 461 |
+
[ '64x64' ]
|
| 462 |
+
]
|
| 463 |
+
|
| 464 |
+
# sm120 MMA instruction shapes, planar complex type excluded as they are not required
|
| 465 |
+
sm120_mma_instruction_shapes = [
|
| 466 |
+
[ 'h1688gemm_(?!planar_complex)',
|
| 467 |
+
's1688gemm_f16',
|
| 468 |
+
's1688gemm_bf16',
|
| 469 |
+
's1688gemm_tf32',
|
| 470 |
+
'i16832gemm',
|
| 471 |
+
'i8816gemm' ],
|
| 472 |
+
[ 'd884gemm', 'c1688tf32gemm' ] ,
|
| 473 |
+
[ 'c1688gemm',
|
| 474 |
+
'z884gemm' ],
|
| 475 |
+
[ 'gz884gemm']
|
| 476 |
+
]
|
| 477 |
+
|
| 478 |
+
# It's not pretty, but not sure why different instructions support different tile sizes.
|
| 479 |
+
filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*"
|
| 480 |
+
filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*"
|
| 481 |
+
filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*"
|
| 482 |
+
filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*"
|
| 483 |
+
|
| 484 |
+
filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})"
|
| 485 |
+
|
| 486 |
+
problem_waves = [0.5, 1.25, 2.5]
|
| 487 |
+
|
| 488 |
+
if arch in ["120a", "120f", "121a", "121f"]:
|
| 489 |
+
kernel_filter = f"({filter_regex_sm120_mma})"
|
| 490 |
+
else:
|
| 491 |
+
kernel_filter = f"({filter_regex_sm100_mma})"
|
| 492 |
+
else:
|
| 493 |
+
raise ValueError()
|
| 494 |
+
|
| 495 |
+
outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
|
| 496 |
+
|
| 497 |
+
audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv")
|
| 498 |
+
|
| 499 |
+
audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")
|
| 500 |
+
|
| 501 |
+
kernel_filter_re = re.compile(kernel_filter)
|
| 502 |
+
testcase_counter = 0
|
| 503 |
+
kernels_emitted = 0
|
| 504 |
+
kernels_total = 0
|
| 505 |
+
|
| 506 |
+
perf_json_list = []
|
| 507 |
+
kernel_name_set = set()
|
| 508 |
+
|
| 509 |
+
testlist_csv_fields = ["testcase", "metadata"]
|
| 510 |
+
testlist_csv_rows = []
|
| 511 |
+
auditlist_csv_map = {}
|
| 512 |
+
auditlist_csv_params_map = {}
|
| 513 |
+
|
| 514 |
+
kernel_features = {}
|
| 515 |
+
|
| 516 |
+
for cc in manifest.operations[OperationKind.Gemm].keys():
|
| 517 |
+
for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items():
|
| 518 |
+
assert(len(operation_l) == 1)
|
| 519 |
+
kernels_total += 1
|
| 520 |
+
if len(kernel_filter_re.findall(kernel_name)) == 0:
|
| 521 |
+
continue
|
| 522 |
+
# Only test f16 I/O void C kernels in void C kernel set
|
| 523 |
+
# Exception: Use void C kernels for more accurate perf testing
|
| 524 |
+
if '_void_' in kernel_name and 'perf_' not in mode:
|
| 525 |
+
if 'f16_f16_f16_void_f16' not in kernel_name :
|
| 526 |
+
continue
|
| 527 |
+
|
| 528 |
+
kernels_emitted += 1
|
| 529 |
+
kernel_name_set.add(kernel_name)
|
| 530 |
+
hashed_kernel_name = hash_cutlass_string(kernel_name)
|
| 531 |
+
operation = operation_l[0]
|
| 532 |
+
|
| 533 |
+
dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0
|
| 534 |
+
or operation.tile_description.cluster_shape[1] == 0)
|
| 535 |
+
|
| 536 |
+
dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name
|
| 537 |
+
|
| 538 |
+
runtime_input_datatypes = [None]
|
| 539 |
+
|
| 540 |
+
if dynamic_datatype:
|
| 541 |
+
if "f4_f4" in kernel_name:
|
| 542 |
+
runtime_input_datatypes = [['e2m1','e2m1']]
|
| 543 |
+
elif "f4_f6" in kernel_name:
|
| 544 |
+
runtime_input_datatypes = [['e2m1','e3m2']]
|
| 545 |
+
elif "f4_f8" in kernel_name:
|
| 546 |
+
runtime_input_datatypes = [['e2m1','e4m3']]
|
| 547 |
+
|
| 548 |
+
elif "f6_f4" in kernel_name:
|
| 549 |
+
runtime_input_datatypes = [['e3m2','e2m1']]
|
| 550 |
+
elif "f6_f6" in kernel_name:
|
| 551 |
+
runtime_input_datatypes = [['e3m2','e3m2']]
|
| 552 |
+
elif "f6_f8" in kernel_name:
|
| 553 |
+
runtime_input_datatypes = [['e3m2','e4m3']]
|
| 554 |
+
|
| 555 |
+
elif "f8_f4" in kernel_name:
|
| 556 |
+
runtime_input_datatypes = [['e4m3','e2m1']]
|
| 557 |
+
elif "f8_f6" in kernel_name:
|
| 558 |
+
runtime_input_datatypes = [['e4m3','e3m2']]
|
| 559 |
+
elif "f8_f8" in kernel_name:
|
| 560 |
+
runtime_input_datatypes = [
|
| 561 |
+
# mask out those not covered in statically encoded test cases
|
| 562 |
+
# ['e5m2','e4m3'],
|
| 563 |
+
# ['e4m3','e5m2'],
|
| 564 |
+
['e4m3','e4m3']
|
| 565 |
+
]
|
| 566 |
+
|
| 567 |
+
# block scaled kernels
|
| 568 |
+
elif "ue8m0xf4_ue8m0xf4" in kernel_name:
|
| 569 |
+
runtime_input_datatypes = [['e2m1','e2m1']]
|
| 570 |
+
elif "ue4m3xf4_ue4m3xf4" in kernel_name:
|
| 571 |
+
runtime_input_datatypes = [['e2m1','e2m1']]
|
| 572 |
+
elif "ue8m0xf4_ue8m0xf6" in kernel_name:
|
| 573 |
+
runtime_input_datatypes = [['e2m1','e2m3']]
|
| 574 |
+
elif "ue8m0xf4_ue8m0xf8" in kernel_name:
|
| 575 |
+
runtime_input_datatypes = [['e2m1','e4m3']]
|
| 576 |
+
|
| 577 |
+
elif "ue8m0xf6_ue8m0xf4" in kernel_name:
|
| 578 |
+
runtime_input_datatypes = [['e2m3','e2m1']]
|
| 579 |
+
elif "ue8m0xf6_ue8m0xf6" in kernel_name:
|
| 580 |
+
runtime_input_datatypes = [['e2m3','e2m3']]
|
| 581 |
+
elif "ue8m0xf8_ue8m0xf4" in kernel_name:
|
| 582 |
+
runtime_input_datatypes = [['e4m3','e2m1']]
|
| 583 |
+
|
| 584 |
+
elif "ue8m0xf8_ue8m0xf4" in kernel_name:
|
| 585 |
+
runtime_input_datatypes = [['e4m3','e2m1']]
|
| 586 |
+
elif "ue8m0xf8_ue8m0xf6" in kernel_name:
|
| 587 |
+
runtime_input_datatypes = [['e4m3','e2m3']]
|
| 588 |
+
elif "ue8m0xf8_ue8m0xf8" in kernel_name:
|
| 589 |
+
runtime_input_datatypes = [['e4m3','e4m3']]
|
| 590 |
+
|
| 591 |
+
if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
|
| 592 |
+
profiler_flags_for_verification = "host"
|
| 593 |
+
|
| 594 |
+
# reduce L1 test runtime if reference kernel is not running on device.
|
| 595 |
+
if mode == "functional_L1" and profiler_flags_for_verification == "host" :
|
| 596 |
+
problem_waves = [0.5, 2.5]
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
if dynamic_cluster:
|
| 600 |
+
if mode == "functional_L0":
|
| 601 |
+
runtime_cluster_shapes = [[1,1,1], [2,2,1]]
|
| 602 |
+
else:
|
| 603 |
+
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
|
| 604 |
+
# reduce L1 test runtime if reference kernel is not running on device.
|
| 605 |
+
if profiler_flags_for_verification == "host":
|
| 606 |
+
runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]]
|
| 607 |
+
cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
|
| 608 |
+
else:
|
| 609 |
+
runtime_cluster_shapes = [operation.tile_description.cluster_shape]
|
| 610 |
+
cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0])
|
| 611 |
+
cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1])
|
| 612 |
+
cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2])
|
| 613 |
+
|
| 614 |
+
alignment_a = operation.A.alignment
|
| 615 |
+
alignment_b = operation.B.alignment
|
| 616 |
+
alignment_c = operation.C.alignment
|
| 617 |
+
alignment_ab_max = max(alignment_a, alignment_b)
|
| 618 |
+
|
| 619 |
+
layout3x = operation.layout_name_3x()
|
| 620 |
+
data_types = operation.datatype_name_3x()
|
| 621 |
+
|
| 622 |
+
ctas_per_mma_instruction = 1
|
| 623 |
+
if '_2sm' in kernel_name:
|
| 624 |
+
ctas_per_mma_instruction = 2
|
| 625 |
+
valid_cluster_shapes = []
|
| 626 |
+
|
| 627 |
+
# Remove any cluster shapes that have cluster_m that is not divisible by 2
|
| 628 |
+
for cs in runtime_cluster_shapes:
|
| 629 |
+
if cs[0] % 2 == 0:
|
| 630 |
+
valid_cluster_shapes.append(cs)
|
| 631 |
+
runtime_cluster_shapes = valid_cluster_shapes
|
| 632 |
+
|
| 633 |
+
kernel_problem_waves = problem_waves
|
| 634 |
+
if mode == "functional_L0" or mode == "functional_L1":
|
| 635 |
+
# for functional testing, we want to perturb just a little from even shapes
|
| 636 |
+
# large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not
|
| 637 |
+
# -16 ensures that we are TMA aligned even for FP8/Int8
|
| 638 |
+
min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max
|
| 639 |
+
max_k = (cta_tile_shape_k*8) - alignment_ab_max
|
| 640 |
+
problem_shapes_k = [min_k, max_k]
|
| 641 |
+
sm_count = 16
|
| 642 |
+
swizzle_sizes = [0]
|
| 643 |
+
# Larger k and less than half wave trigger streamk +separate reduction case to be generated
|
| 644 |
+
if 'stream_k' in kernel_name:
|
| 645 |
+
problem_shapes_k = [max_k, cta_tile_shape_k*32]
|
| 646 |
+
kernel_problem_waves = [0.125, 1.25, 2.5]
|
| 647 |
+
else:
|
| 648 |
+
raise ValueError
|
| 649 |
+
|
| 650 |
+
if "void" in kernel_name:
|
| 651 |
+
beta_values = [0]
|
| 652 |
+
|
| 653 |
+
alignment_shift_m = max(alignment_c, alignment_a)
|
| 654 |
+
alignment_shift_n = max(alignment_c, alignment_b)
|
| 655 |
+
|
| 656 |
+
is_first_line = True
|
| 657 |
+
for index_waves, waves in enumerate(kernel_problem_waves):
|
| 658 |
+
for index_k, k in enumerate(problem_shapes_k):
|
| 659 |
+
for beta in beta_values:
|
| 660 |
+
for cluster_shape in runtime_cluster_shapes:
|
| 661 |
+
for runtime_input_datatype in runtime_input_datatypes:
|
| 662 |
+
for swizzle_size in swizzle_sizes:
|
| 663 |
+
grid_size = waves * sm_count
|
| 664 |
+
cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
|
| 665 |
+
if cluster_shape_m >= cluster_shape_n:
|
| 666 |
+
grid_m = cluster_shape_m
|
| 667 |
+
grid_n = grid_size / grid_m
|
| 668 |
+
grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
|
| 669 |
+
else:
|
| 670 |
+
grid_n = cluster_shape_n
|
| 671 |
+
grid_m = grid_size / grid_n
|
| 672 |
+
grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)
|
| 673 |
+
|
| 674 |
+
verification_required = False
|
| 675 |
+
if mode == "functional_L0" or mode == "functional_L1":
|
| 676 |
+
if '_void_' not in kernel_name:
|
| 677 |
+
verification_required = True
|
| 678 |
+
|
| 679 |
+
m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
|
| 680 |
+
n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
|
| 681 |
+
k = int(k)
|
| 682 |
+
|
| 683 |
+
# For functional testing, we want to perturb just a little from even shapes.
|
| 684 |
+
# Only do this if the perturbation does not cause one of the dimensions of the
|
| 685 |
+
# problem size to go to zero. This can occur for blockscaling kernels for which
|
| 686 |
+
# the alignment requirements for A and B can be quite large (e.g., 256).
|
| 687 |
+
if m > alignment_shift_m:
|
| 688 |
+
m -= alignment_shift_m
|
| 689 |
+
if n > alignment_shift_n:
|
| 690 |
+
n -= alignment_shift_n
|
| 691 |
+
|
| 692 |
+
if '_n32t32_' in kernel_name:
|
| 693 |
+
continue
|
| 694 |
+
batch_count = 1
|
| 695 |
+
if mode == "functional_L0" or mode == "functional_L1" :
|
| 696 |
+
if index_waves == 0 and index_k == 0 :
|
| 697 |
+
batch_count = 3 if mode == "functional_L0" else 5
|
| 698 |
+
gemm_op = "gemm"
|
| 699 |
+
|
| 700 |
+
grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
|
| 701 |
+
num_groups = 1
|
| 702 |
+
if grouped:
|
| 703 |
+
gemm_op = "grouped_gemm"
|
| 704 |
+
num_groups = 3 # small to limit test time in host block-scaled reference kernels
|
| 705 |
+
batch_count = 1
|
| 706 |
+
elif "bstensorop" in kernel_name:
|
| 707 |
+
gemm_op = "block_scaled_gemm"
|
| 708 |
+
elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
|
| 709 |
+
gemm_op = "blockwise_gemm"
|
| 710 |
+
|
| 711 |
+
problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
|
| 712 |
+
|
| 713 |
+
assert m > 0 and n > 0 and k > 0
|
| 714 |
+
|
| 715 |
+
# Emit per-testcase metadata for perf testing usage, eventually in perf database
|
| 716 |
+
metadata_dict = {
|
| 717 |
+
"input_params": {
|
| 718 |
+
'problem_size_category' : problem_size_category,
|
| 719 |
+
'operation' : _getSubOperationType(operation),
|
| 720 |
+
'datatype' : data_types,
|
| 721 |
+
'layout' : layout3x,
|
| 722 |
+
'm' : m,
|
| 723 |
+
'n' : n,
|
| 724 |
+
'k' : k,
|
| 725 |
+
'beta' : beta,
|
| 726 |
+
'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups)
|
| 727 |
+
},
|
| 728 |
+
"runtime_params": {
|
| 729 |
+
'ctas_per_mma_instruction' : ctas_per_mma_instruction,
|
| 730 |
+
'tilesize_m' : cta_tile_shape_m,
|
| 731 |
+
'tilesize_n' : cta_tile_shape_n,
|
| 732 |
+
'tilesize_k' : cta_tile_shape_k,
|
| 733 |
+
'cluster_shape_m' : cluster_shape_m,
|
| 734 |
+
'cluster_shape_n' : cluster_shape_n,
|
| 735 |
+
}
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
|
| 739 |
+
cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
|
| 740 |
+
cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
if dynamic_datatype:
|
| 744 |
+
runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
|
| 745 |
+
metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
|
| 746 |
+
metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
|
| 747 |
+
|
| 748 |
+
testcase_metadata = [
|
| 749 |
+
f"cutlass_profiler --operation={gemm_op}" +
|
| 750 |
+
(f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") +
|
| 751 |
+
f" --error-on-no-match --error-if-nothing-is-profiled" +
|
| 752 |
+
f" --kernels={kernel_name}" +
|
| 753 |
+
f" --m={str(m)}" +
|
| 754 |
+
f" --n={str(n)}" +
|
| 755 |
+
f" --k={str(k)}" +
|
| 756 |
+
(f" --num_groups={str(num_groups)}" if grouped else "") +
|
| 757 |
+
f" --cluster_m={str(cluster_shape_m)}" +
|
| 758 |
+
f" --cluster_n={str(cluster_shape_n)}" +
|
| 759 |
+
f" --cluster_k={str(cluster_shape_k)}" +
|
| 760 |
+
f" --cluster_m_fallback={str(cluster_m_fallback)}" +
|
| 761 |
+
f" --cluster_n_fallback={str(cluster_n_fallback)}" +
|
| 762 |
+
f" --cluster_k_fallback={str(cluster_k_fallback)}" +
|
| 763 |
+
f" --beta={str(beta)}" +
|
| 764 |
+
("" if grouped else f" --batch_count={str(batch_count)}") +
|
| 765 |
+
f" --swizzle_size={str(swizzle_size)}" +
|
| 766 |
+
f" --verification-required={str(verification_required).lower()}"
|
| 767 |
+
] \
|
| 768 |
+
|
| 769 |
+
output_dynamic_datatype = dynamic_datatype
|
| 770 |
+
if output_dynamic_datatype:
|
| 771 |
+
testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
|
| 772 |
+
f" --runtime_input_datatype_b={runtime_datatype_b}")
|
| 773 |
+
|
| 774 |
+
testcase_metadata.append(json.dumps(metadata_dict))
|
| 775 |
+
testlist_csv_rows.append(testcase_metadata)
|
| 776 |
+
testcase_counter += 1
|
| 777 |
+
|
| 778 |
+
alpha = 1.0
|
| 779 |
+
|
| 780 |
+
if dynamic_datatype:
|
| 781 |
+
hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)
|
| 782 |
+
|
| 783 |
+
# If kernel_name is new, initialize its feature set with defaults
|
| 784 |
+
if hashed_kernel_name not in kernel_features:
|
| 785 |
+
kernel_features[hashed_kernel_name] = {
|
| 786 |
+
"is_support_dynamic_cluster": False,
|
| 787 |
+
"is_support_dynamic_datatype": False,
|
| 788 |
+
}
|
| 789 |
+
|
| 790 |
+
# Update features for the hashed kernel name
|
| 791 |
+
kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
|
| 792 |
+
kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype
|
| 793 |
+
|
| 794 |
+
if hashed_kernel_name not in auditlist_csv_params_map:
|
| 795 |
+
auditlist_csv_params_map[hashed_kernel_name] = []
|
| 796 |
+
|
| 797 |
+
audit_row_params = get_kernel_params(
|
| 798 |
+
operation,
|
| 799 |
+
hashed_kernel_name,
|
| 800 |
+
(cluster_shape_m, cluster_shape_n, cluster_shape_k),
|
| 801 |
+
(cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
|
| 802 |
+
(m, n, k, batch_count),
|
| 803 |
+
alpha, beta,
|
| 804 |
+
dynamic_datatype, dynamic_cluster
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)
|
| 808 |
+
|
| 809 |
+
if hashed_kernel_name not in auditlist_csv_map:
|
| 810 |
+
audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
|
| 811 |
+
auditlist_csv_map[hashed_kernel_name] = audit_row
|
| 812 |
+
|
| 813 |
+
with open(outfile_name, 'w') as testlist_csv:
|
| 814 |
+
csv_writer = csv.writer(testlist_csv, delimiter=',')
|
| 815 |
+
csv_writer.writerow(testlist_csv_fields)
|
| 816 |
+
csv_writer.writerows(testlist_csv_rows)
|
| 817 |
+
|
| 818 |
+
with open(audit_file_name, 'w') as auditlist_csv:
|
| 819 |
+
csv_writer = csv.writer(auditlist_csv, delimiter=',')
|
| 820 |
+
csv_writer.writerow(audit_csv_fields)
|
| 821 |
+
for hashed_kernel_name, row in auditlist_csv_map.items():
|
| 822 |
+
# Append the dynamic features as "Y" or "N"
|
| 823 |
+
dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N"
|
| 824 |
+
dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N"
|
| 825 |
+
test_count = len(auditlist_csv_params_map[hashed_kernel_name])
|
| 826 |
+
csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count])
|
| 827 |
+
|
| 828 |
+
with open(audit_file_params_name, 'w') as auditlist_csv:
|
| 829 |
+
csv_writer = csv.writer(auditlist_csv, delimiter=',')
|
| 830 |
+
csv_writer.writerow(audit_csv_runtime_fields)
|
| 831 |
+
for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1):
|
| 832 |
+
for i, row in enumerate(rows):
|
| 833 |
+
if i == 0:
|
| 834 |
+
csv_writer.writerow([kernel_index, hashed_kernel_name] + row)
|
| 835 |
+
else:
|
| 836 |
+
csv_writer.writerow(["", ""] + row)
|
| 837 |
+
|
| 838 |
+
print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.")
|
| 839 |
+
|
| 840 |
+
# Generate a newline separated list of kernel filters
|
| 841 |
+
assert(len(kernel_name_set) == kernels_emitted)
|
| 842 |
+
output_filter_enabled = True
|
| 843 |
+
if output_filter_enabled:
|
| 844 |
+
kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
|
| 845 |
+
with open(kernel_filter_outfile_name, "w") as file:
|
| 846 |
+
kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set))
|
| 847 |
+
for kernel_name in kernel_name_set:
|
| 848 |
+
file.write(kernel_name + "\n")
|
| 849 |
+
|
| 850 |
+
# Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together.
|
| 851 |
+
if mode == "functional_L0" or mode == "functional_L1":
|
| 852 |
+
# Sort the .csv file
|
| 853 |
+
outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
|
| 854 |
+
with open(outfile_name) as file:
|
| 855 |
+
data = file.readlines()
|
| 856 |
+
data.sort()
|
| 857 |
+
with open(outfile_name, 'w') as file:
|
| 858 |
+
for i in range(len(data)):
|
| 859 |
+
file.write(data[i])
|
| 860 |
+
# Sort the kernel list
|
| 861 |
+
kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
|
| 862 |
+
with open(kernel_filter_outfile_name) as file:
|
| 863 |
+
data = file.readlines()
|
| 864 |
+
data.sort()
|
| 865 |
+
with open(kernel_filter_outfile_name, 'w') as file:
|
| 866 |
+
for i in range(len(data)):
|
| 867 |
+
file.write(data[i])
|
| 868 |
+
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py
ADDED
|
@@ -0,0 +1,1613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting GEMM kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import collections
|
| 38 |
+
import enum
|
| 39 |
+
import functools
|
| 40 |
+
import logging
|
| 41 |
+
import operator
|
| 42 |
+
import os.path
|
| 43 |
+
import shutil
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
import builtins
|
| 47 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 48 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 49 |
+
from cutlass_library.library import *
|
| 50 |
+
except ImportError:
|
| 51 |
+
from library import *
|
| 52 |
+
|
| 53 |
+
_LOGGER = logging.getLogger(__name__)
|
| 54 |
+
|
| 55 |
+
###################################################################################################
|
| 56 |
+
#
|
| 57 |
+
# Data structure modeling a GEMM operation
|
| 58 |
+
#
|
| 59 |
+
###################################################################################################
|
| 60 |
+
|
| 61 |
+
#
|
| 62 |
+
class GemmOperation:
|
| 63 |
+
#
|
| 64 |
+
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
| 65 |
+
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
| 66 |
+
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
| 67 |
+
tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
|
| 68 |
+
ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None,
|
| 69 |
+
ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None):
|
| 70 |
+
|
| 71 |
+
kinds_3x = {
|
| 72 |
+
GemmKind.Universal3x,
|
| 73 |
+
GemmKind.SparseUniversal3x,
|
| 74 |
+
GemmKind.BlockScaledUniversal3x,
|
| 75 |
+
GemmKind.GroupedUniversal3x,
|
| 76 |
+
GemmKind.GroupedBlockScaledUniversal3x,
|
| 77 |
+
GemmKind.BlockwiseUniversal3x,
|
| 78 |
+
GemmKind.GroupedBlockwiseUniversal3x,
|
| 79 |
+
}
|
| 80 |
+
self.is_3x = gemm_kind in kinds_3x
|
| 81 |
+
self.prefix = "3x" if self.is_3x else ""
|
| 82 |
+
self.operation_kind = OperationKind.Gemm
|
| 83 |
+
self.arch = arch
|
| 84 |
+
self.tile_description = tile_description
|
| 85 |
+
self.gemm_kind = gemm_kind
|
| 86 |
+
self.A = A
|
| 87 |
+
self.B = B
|
| 88 |
+
self.C = C
|
| 89 |
+
self.D = D
|
| 90 |
+
|
| 91 |
+
if is_block_scaled(gemm_kind):
|
| 92 |
+
self.ScaleFactorA = ScaleFactorA
|
| 93 |
+
self.ScaleFactorB = ScaleFactorB
|
| 94 |
+
self.ScaleFactorD = ScaleFactorD["tensor"]
|
| 95 |
+
self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
|
| 96 |
+
|
| 97 |
+
if is_blockwise(gemm_kind):
|
| 98 |
+
self.ScaleFactorMVecSize = ScaleFactorMVecSize
|
| 99 |
+
self.ScaleFactorNVecSize = ScaleFactorNVecSize
|
| 100 |
+
self.ScaleFactorKVecSize = ScaleFactorKVecSize
|
| 101 |
+
|
| 102 |
+
if self.D == None:
|
| 103 |
+
self.D = self.C
|
| 104 |
+
|
| 105 |
+
if not self.is_3x:
|
| 106 |
+
assert(kernel_schedule == KernelScheduleType.ScheduleAuto)
|
| 107 |
+
assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto)
|
| 108 |
+
self.kernel_schedule = kernel_schedule
|
| 109 |
+
self.epilogue_schedule = epilogue_schedule
|
| 110 |
+
self.element_epilogue = element_epilogue
|
| 111 |
+
self.epilogue_functor = epilogue_functor
|
| 112 |
+
|
| 113 |
+
if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
|
| 114 |
+
self.epilogue_functor = EpilogueFunctor3x.LinearCombination
|
| 115 |
+
|
| 116 |
+
self.swizzling_functor = swizzling_functor
|
| 117 |
+
self.tile_scheduler = tile_scheduler
|
| 118 |
+
|
| 119 |
+
# Only enable mixed input mode and mixed input shuffle for Hopper
|
| 120 |
+
self.mixed_input_mode = None
|
| 121 |
+
if self.is_mixed_input() and self.arch >= 90 and self.arch < 100:
|
| 122 |
+
self.mixed_input_mode = mixed_input_mode
|
| 123 |
+
self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle
|
| 124 |
+
|
| 125 |
+
#
|
| 126 |
+
def is_complex(self):
|
| 127 |
+
complex_operators = [
|
| 128 |
+
MathOperation.multiply_add_complex,
|
| 129 |
+
MathOperation.multiply_add_complex_gaussian,
|
| 130 |
+
MathOperation.multiply_add_complex_fast_f32
|
| 131 |
+
]
|
| 132 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 133 |
+
|
| 134 |
+
#
|
| 135 |
+
def is_mixed_input(self):
|
| 136 |
+
return self.A.element != self.B.element
|
| 137 |
+
|
| 138 |
+
#
|
| 139 |
+
def is_planar_complex(self):
|
| 140 |
+
return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
|
| 141 |
+
|
| 142 |
+
#
|
| 143 |
+
def accumulator_type(self):
|
| 144 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 145 |
+
|
| 146 |
+
if self.is_complex():
|
| 147 |
+
return get_complex_from_real(accum)
|
| 148 |
+
|
| 149 |
+
return accum
|
| 150 |
+
|
| 151 |
+
#
|
| 152 |
+
def short_math_name(self):
|
| 153 |
+
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 154 |
+
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 155 |
+
return ShortDataTypeNames[self.accumulator_type()]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
#
|
| 159 |
+
def core_name(self):
|
| 160 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 161 |
+
|
| 162 |
+
inst_shape = ''
|
| 163 |
+
inst_operation = ''
|
| 164 |
+
intermediate_type = ''
|
| 165 |
+
|
| 166 |
+
math_operations_map = {
|
| 167 |
+
MathOperation.xor_popc: 'xor',
|
| 168 |
+
MathOperation.and_popc: 'and',
|
| 169 |
+
MathOperation.multiply_add_fast_accum: 'fastaccum',
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
tensor_ops = [
|
| 173 |
+
OpcodeClass.TensorOp,
|
| 174 |
+
OpcodeClass.WmmaTensorOp,
|
| 175 |
+
OpcodeClass.SparseTensorOp,
|
| 176 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops
|
| 180 |
+
|
| 181 |
+
if is_tensor_op:
|
| 182 |
+
|
| 183 |
+
math_op = self.tile_description.math_instruction.math_operation
|
| 184 |
+
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
| 185 |
+
|
| 186 |
+
inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
|
| 187 |
+
|
| 188 |
+
inst_shape += math_op_string
|
| 189 |
+
|
| 190 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 191 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 192 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 193 |
+
|
| 194 |
+
short_math_name = self.short_math_name() if not self.is_3x else ""
|
| 195 |
+
|
| 196 |
+
return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
| 197 |
+
|
| 198 |
+
# Generates a string representing the MMA instruction.
|
| 199 |
+
def extended_name(self):
|
| 200 |
+
''' Append data types if they differ from compute type. '''
|
| 201 |
+
element_sfa = ""
|
| 202 |
+
element_sfb = ""
|
| 203 |
+
if self.is_complex():
|
| 204 |
+
extended_name = "${core_name}"
|
| 205 |
+
else:
|
| 206 |
+
if self.is_mixed_input():
|
| 207 |
+
extended_name = "${core_name}_${element_a}_${element_b}"
|
| 208 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator:
|
| 209 |
+
extended_name = "${element_c}_" + extended_name
|
| 210 |
+
elif is_blockwise(self.gemm_kind):
|
| 211 |
+
extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}"
|
| 212 |
+
element_sfa = DataTypeNames[self.accumulator_type()]
|
| 213 |
+
element_sfb = DataTypeNames[self.accumulator_type()]
|
| 214 |
+
else:
|
| 215 |
+
extended_name = "${core_name}"
|
| 216 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator:
|
| 217 |
+
extended_name = "${element_c}_" + extended_name
|
| 218 |
+
if self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 219 |
+
extended_name += "_${element_a}"
|
| 220 |
+
|
| 221 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 222 |
+
'element_a': DataTypeNames[self.A.element],
|
| 223 |
+
'element_sfa' : element_sfa,
|
| 224 |
+
'element_b': DataTypeNames[self.B.element],
|
| 225 |
+
'element_sfb' : element_sfb,
|
| 226 |
+
'element_c': DataTypeNames[self.C.element],
|
| 227 |
+
'core_name': self.core_name()
|
| 228 |
+
})
|
| 229 |
+
|
| 230 |
+
return extended_name
|
| 231 |
+
|
| 232 |
+
#
|
| 233 |
+
def mixed_input_mode_name(self):
|
| 234 |
+
mode_name_mapping = {
|
| 235 |
+
MixedInputMode.ConvertOnly: "_cvt",
|
| 236 |
+
MixedInputMode.ScaleOnly: "_scl",
|
| 237 |
+
MixedInputMode.ScaleWithZeroPoint: "_sclzr"
|
| 238 |
+
}
|
| 239 |
+
mode_name = mode_name_mapping.get(self.mixed_input_mode, "")
|
| 240 |
+
if self.mixed_input_shuffle:
|
| 241 |
+
mode_name = mode_name + "_shfl"
|
| 242 |
+
return mode_name
|
| 243 |
+
|
| 244 |
+
def extended_name_3x(self):
|
| 245 |
+
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
|
| 246 |
+
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
| 247 |
+
element_a = DataTypeNames[self.A.element],
|
| 248 |
+
element_b = DataTypeNames[self.B.element],
|
| 249 |
+
element_acc = DataTypeNames[self.accumulator_type()],
|
| 250 |
+
element_c = DataTypeNames[self.C.element],
|
| 251 |
+
element_d = DataTypeNames[self.D.element],
|
| 252 |
+
core_name = self.core_name())
|
| 253 |
+
|
| 254 |
+
if is_block_scaled(self.gemm_kind):
|
| 255 |
+
d_type_names = DataTypeNames[self.D.element]
|
| 256 |
+
|
| 257 |
+
if self.ScaleFactorD.element != DataType.void:
|
| 258 |
+
d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
|
| 259 |
+
|
| 260 |
+
extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
| 261 |
+
element_sfa = DataTypeNames[self.ScaleFactorA],
|
| 262 |
+
element_a = DataTypeNames[self.A.element],
|
| 263 |
+
element_sfb = DataTypeNames[self.ScaleFactorB],
|
| 264 |
+
element_b = DataTypeNames[self.B.element],
|
| 265 |
+
element_acc = DataTypeNames[self.accumulator_type()],
|
| 266 |
+
element_c = DataTypeNames[self.C.element],
|
| 267 |
+
element_d = d_type_names,
|
| 268 |
+
core_name = self.core_name())
|
| 269 |
+
|
| 270 |
+
if is_blockwise(self.gemm_kind):
|
| 271 |
+
d_type_names = DataTypeNames[self.D.element]
|
| 272 |
+
|
| 273 |
+
extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
| 274 |
+
element_sfa = DataTypeNames[self.accumulator_type()],
|
| 275 |
+
element_a = DataTypeNames[self.A.element],
|
| 276 |
+
element_sfb = DataTypeNames[self.accumulator_type()],
|
| 277 |
+
element_b = DataTypeNames[self.B.element],
|
| 278 |
+
element_acc = DataTypeNames[self.accumulator_type()],
|
| 279 |
+
element_c = DataTypeNames[self.C.element],
|
| 280 |
+
element_d = d_type_names,
|
| 281 |
+
sfvec_m_size = self.ScaleFactorMVecSize,
|
| 282 |
+
sfvec_n_size = self.ScaleFactorNVecSize,
|
| 283 |
+
sfvec_k_size = self.ScaleFactorKVecSize,
|
| 284 |
+
core_name = self.core_name())
|
| 285 |
+
|
| 286 |
+
if self.mixed_input_mode != None:
|
| 287 |
+
extended_name = extended_name + self.mixed_input_mode_name()
|
| 288 |
+
return extended_name
|
| 289 |
+
|
| 290 |
+
def datatype_name_3x(self):
|
| 291 |
+
'''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
|
| 292 |
+
datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
| 293 |
+
element_a = DataTypeNames[self.A.element],
|
| 294 |
+
element_b = DataTypeNames[self.B.element],
|
| 295 |
+
element_acc = DataTypeNames[self.accumulator_type()],
|
| 296 |
+
element_c = DataTypeNames[self.C.element],
|
| 297 |
+
element_d = DataTypeNames[self.D.element])
|
| 298 |
+
return datatype_name
|
| 299 |
+
|
| 300 |
+
# Generates a short string representing the AB layout tags (e.g. nt or tn)
|
| 301 |
+
def layout_name(self):
|
| 302 |
+
if self.is_complex() or self.is_planar_complex():
|
| 303 |
+
return "%s%s" % (
|
| 304 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 305 |
+
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
|
| 306 |
+
)
|
| 307 |
+
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
|
| 308 |
+
|
| 309 |
+
# Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
|
| 310 |
+
def layout_name_3x(self):
|
| 311 |
+
if self.is_complex() or self.is_planar_complex():
|
| 312 |
+
return "{}{}{}".format(
|
| 313 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 314 |
+
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
|
| 315 |
+
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
|
| 316 |
+
else:
|
| 317 |
+
return "{}{}{}".format(
|
| 318 |
+
ShortLayoutTypeNames[self.A.layout],
|
| 319 |
+
ShortLayoutTypeNames[self.B.layout],
|
| 320 |
+
ShortLayoutTypeNames[self.C.layout])
|
| 321 |
+
|
| 322 |
+
# Generates a short string representing underlying kernel schedule type
|
| 323 |
+
def kernel_schedule_name_3x(self):
|
| 324 |
+
return KernelScheduleSuffixes[self.kernel_schedule]
|
| 325 |
+
|
| 326 |
+
# Generates a short string representing underlying epilogue schedule type
|
| 327 |
+
def epilogue_schedule_name_3x(self):
|
| 328 |
+
|
| 329 |
+
if is_block_scaled(self.gemm_kind):
|
| 330 |
+
if self.ScaleFactorD.element != DataType.void:
|
| 331 |
+
return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
|
| 332 |
+
|
| 333 |
+
return EpilogueScheduleSuffixes[self.epilogue_schedule]
|
| 334 |
+
|
| 335 |
+
# Generate a short string representing the operation class
|
| 336 |
+
def opcode_class_name(self):
|
| 337 |
+
return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 338 |
+
|
| 339 |
+
def get_collective_tile_shape(self):
|
| 340 |
+
"""
|
| 341 |
+
Get the tile shape passed to the collective builder.
|
| 342 |
+
On Blackwell, this is different than the operation.tile_description.tile_shape.
|
| 343 |
+
"""
|
| 344 |
+
is_sm100_kernel = (self.arch == 100 or self.arch == 103)
|
| 345 |
+
if not is_sm100_kernel:
|
| 346 |
+
return self.tile_description.tile_shape
|
| 347 |
+
|
| 348 |
+
opcode_class_main = self.tile_description.math_instruction.opcode_class
|
| 349 |
+
instruction_shape = self.tile_description.math_instruction.instruction_shape
|
| 350 |
+
tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape
|
| 351 |
+
if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]:
|
| 352 |
+
tile_shape_m = instruction_shape[0]
|
| 353 |
+
tile_shape_n = instruction_shape[1]
|
| 354 |
+
return (tile_shape_m, tile_shape_n, tile_shape_k)
|
| 355 |
+
|
| 356 |
+
# Generates the full kernel function name
|
| 357 |
+
def procedural_name(self):
|
| 358 |
+
return self._procedural_name
|
| 359 |
+
|
| 360 |
+
@functools.cached_property
|
| 361 |
+
def _procedural_name(self):
|
| 362 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 363 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 364 |
+
if self.arch >= 90:
|
| 365 |
+
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
|
| 366 |
+
tile_shape = self.get_collective_tile_shape()
|
| 367 |
+
return kernel_name_template.format(
|
| 368 |
+
p = self.prefix,
|
| 369 |
+
ar = self.arch,
|
| 370 |
+
op = opcode_class_name,
|
| 371 |
+
ex = self.extended_name_3x(),
|
| 372 |
+
ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
|
| 373 |
+
cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
|
| 374 |
+
l = self.tile_description.stages,
|
| 375 |
+
s = self.layout_name_3x(),
|
| 376 |
+
al = str(max(self.A.alignment, self.B.alignment)),
|
| 377 |
+
t = TileSchedulerSuffixes[self.tile_scheduler],
|
| 378 |
+
k = self.kernel_schedule_name_3x(),
|
| 379 |
+
e = self.epilogue_schedule_name_3x())
|
| 380 |
+
else:
|
| 381 |
+
threadblock = self.tile_description.procedural_name()
|
| 382 |
+
return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
|
| 383 |
+
p = self.prefix,
|
| 384 |
+
op = opcode_class_name,
|
| 385 |
+
ex = self.extended_name(),
|
| 386 |
+
tb = threadblock,
|
| 387 |
+
l = self.layout_name(),
|
| 388 |
+
a = str(max(self.A.alignment, self.B.alignment)))
|
| 389 |
+
|
| 390 |
+
#
|
| 391 |
+
def configuration_name(self):
|
| 392 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 393 |
+
return self.procedural_name()
|
| 394 |
+
|
| 395 |
+
def __hash__(self):
|
| 396 |
+
return hash(self.configuration_name())
|
| 397 |
+
|
| 398 |
+
def __eq__(self, other):
|
| 399 |
+
return self.configuration_name() == other.configuration_name()
|
| 400 |
+
|
| 401 |
+
###################################################################################################
|
| 402 |
+
#
|
| 403 |
+
# Data structure modeling a grouped GEMM operation
|
| 404 |
+
#
|
| 405 |
+
###################################################################################################
|
| 406 |
+
|
| 407 |
+
#
|
| 408 |
+
class GroupedGemmOperation(GemmOperation):
|
| 409 |
+
#
|
| 410 |
+
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
| 411 |
+
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
|
| 412 |
+
scheduler_mode = GroupScheduleMode.Device):
|
| 413 |
+
super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
| 414 |
+
epilogue_functor, swizzling_functor)
|
| 415 |
+
|
| 416 |
+
self.scheduler_mode = scheduler_mode
|
| 417 |
+
|
| 418 |
+
#
|
| 419 |
+
def procedural_name(self):
|
| 420 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 421 |
+
base = super().procedural_name()
|
| 422 |
+
return SubstituteTemplate(
|
| 423 |
+
base + "_schedule${schedule}",
|
| 424 |
+
{
|
| 425 |
+
'schedule': ShortGroupScheduleModeNames[self.scheduler_mode]
|
| 426 |
+
})
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
###################################################################################################
|
| 430 |
+
#
|
| 431 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 432 |
+
#
|
| 433 |
+
###################################################################################################
|
| 434 |
+
|
| 435 |
+
#
|
| 436 |
+
class EmitGemmInstance:
|
| 437 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 438 |
+
|
| 439 |
+
def __init__(self, operation_suffix = ''):
|
| 440 |
+
self.operation_suffix = operation_suffix
|
| 441 |
+
self.includes = []
|
| 442 |
+
self.gemm_template = """
|
| 443 |
+
// Gemm operator ${operation_name}
|
| 444 |
+
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
|
| 445 |
+
${element_a}, ${layout_a},
|
| 446 |
+
${element_b}, ${layout_b},
|
| 447 |
+
${element_c}, ${layout_c},
|
| 448 |
+
${element_accumulator},
|
| 449 |
+
${opcode_class},
|
| 450 |
+
${arch},
|
| 451 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 452 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 453 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 454 |
+
${epilogue_functor}<
|
| 455 |
+
${element_c},
|
| 456 |
+
${epilogue_vector_length},
|
| 457 |
+
${element_accumulator},
|
| 458 |
+
${element_epilogue}
|
| 459 |
+
>,
|
| 460 |
+
${swizzling_functor},
|
| 461 |
+
${stages},
|
| 462 |
+
${align_a},
|
| 463 |
+
${align_b},
|
| 464 |
+
false,
|
| 465 |
+
${math_operation}
|
| 466 |
+
${residual}
|
| 467 |
+
>;
|
| 468 |
+
"""
|
| 469 |
+
self.gemm_complex_template = """
|
| 470 |
+
// Gemm operator ${operation_name}
|
| 471 |
+
using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
|
| 472 |
+
${element_a}, ${layout_a},
|
| 473 |
+
${element_b}, ${layout_b},
|
| 474 |
+
${element_c}, ${layout_c},
|
| 475 |
+
${element_accumulator},
|
| 476 |
+
${opcode_class},
|
| 477 |
+
${arch},
|
| 478 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 479 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 480 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 481 |
+
${epilogue_functor}<
|
| 482 |
+
${element_c},
|
| 483 |
+
${epilogue_vector_length},
|
| 484 |
+
${element_accumulator},
|
| 485 |
+
${element_epilogue}
|
| 486 |
+
>,
|
| 487 |
+
${swizzling_functor},
|
| 488 |
+
${stages},
|
| 489 |
+
${transform_a},
|
| 490 |
+
${transform_b},
|
| 491 |
+
${math_operation}
|
| 492 |
+
${residual}
|
| 493 |
+
>;
|
| 494 |
+
"""
|
| 495 |
+
|
| 496 |
+
#
|
| 497 |
+
def instance_template(self):
|
| 498 |
+
return """
|
| 499 |
+
${compile_guard_start}
|
| 500 |
+
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
| 501 |
+
${compile_guard_end}
|
| 502 |
+
"""
|
| 503 |
+
|
| 504 |
+
#
|
| 505 |
+
def emit(self, operation):
|
| 506 |
+
|
| 507 |
+
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
| 508 |
+
|
| 509 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 510 |
+
|
| 511 |
+
residual = ''
|
| 512 |
+
|
| 513 |
+
values = {
|
| 514 |
+
'operation_name': operation.procedural_name(),
|
| 515 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 516 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 517 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 518 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 519 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 520 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 521 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 522 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 523 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 524 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 525 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 526 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 527 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 528 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 529 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 530 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 531 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 532 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 533 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 534 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 535 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 536 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 537 |
+
'stages': str(operation.tile_description.stages),
|
| 538 |
+
'align_a': str(operation.A.alignment),
|
| 539 |
+
'align_b': str(operation.B.alignment),
|
| 540 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 541 |
+
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
| 542 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 543 |
+
'residual': residual
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
|
| 547 |
+
|
| 548 |
+
return SubstituteTemplate(template, values)
|
| 549 |
+
|
| 550 |
+
###################################################################################################
|
| 551 |
+
|
| 552 |
+
class EmitSparseGemmInstance:
|
| 553 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 554 |
+
|
| 555 |
+
def __init__(self, operation_suffix = ''):
|
| 556 |
+
self.operation_suffix = operation_suffix
|
| 557 |
+
self.includes = []
|
| 558 |
+
self.gemm_template = """
|
| 559 |
+
// Gemm operator ${operation_name}
|
| 560 |
+
using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
|
| 561 |
+
${element_a}, ${layout_a},
|
| 562 |
+
${element_b}, ${layout_b},
|
| 563 |
+
${element_c}, ${layout_c},
|
| 564 |
+
${element_accumulator},
|
| 565 |
+
${opcode_class},
|
| 566 |
+
${arch},
|
| 567 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 568 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 569 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 570 |
+
${epilogue_functor}<
|
| 571 |
+
${element_c},
|
| 572 |
+
${epilogue_vector_length},
|
| 573 |
+
${element_accumulator},
|
| 574 |
+
${element_epilogue}
|
| 575 |
+
>,
|
| 576 |
+
${swizzling_functor},
|
| 577 |
+
${stages},
|
| 578 |
+
${align_a},
|
| 579 |
+
${align_b},
|
| 580 |
+
false,
|
| 581 |
+
${math_operation}
|
| 582 |
+
${residual}
|
| 583 |
+
>;
|
| 584 |
+
"""
|
| 585 |
+
|
| 586 |
+
#
|
| 587 |
+
def instance_template(self):
|
| 588 |
+
return """
|
| 589 |
+
${compile_guard_start}
|
| 590 |
+
manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
|
| 591 |
+
${compile_guard_end}
|
| 592 |
+
"""
|
| 593 |
+
|
| 594 |
+
#
|
| 595 |
+
def emit(self, operation):
|
| 596 |
+
|
| 597 |
+
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
| 598 |
+
|
| 599 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 600 |
+
|
| 601 |
+
residual = ''
|
| 602 |
+
|
| 603 |
+
values = {
|
| 604 |
+
'operation_name': operation.procedural_name(),
|
| 605 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 606 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 607 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 608 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 609 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 610 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 611 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 612 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 613 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 614 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 615 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 616 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 617 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 618 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 619 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 620 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 621 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 622 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 623 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 624 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 625 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 626 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 627 |
+
'stages': str(operation.tile_description.stages),
|
| 628 |
+
'align_a': str(operation.A.alignment),
|
| 629 |
+
'align_b': str(operation.B.alignment),
|
| 630 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 631 |
+
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
| 632 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 633 |
+
'residual': residual
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
template = self.gemm_template
|
| 637 |
+
|
| 638 |
+
return SubstituteTemplate(template, values)
|
| 639 |
+
|
| 640 |
+
###################################################################################################
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
#
|
| 644 |
+
class EmitGemmUniversalInstance:
|
| 645 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 646 |
+
|
| 647 |
+
def __init__(self, operation_suffix = ''):
|
| 648 |
+
self.operation_suffix = operation_suffix
|
| 649 |
+
self.includes = [
|
| 650 |
+
"cutlass/cutlass.h",
|
| 651 |
+
"cutlass/numeric_types.h",
|
| 652 |
+
"cutlass/arch/arch.h",
|
| 653 |
+
"cutlass/arch/mma.h",
|
| 654 |
+
"cutlass/layout/matrix.h",
|
| 655 |
+
"cutlass/gemm/device/gemm.h",
|
| 656 |
+
"cutlass/gemm/device/gemm_universal_adapter.h",
|
| 657 |
+
"cutlass/gemm/kernel/default_gemm_universal.h",
|
| 658 |
+
]
|
| 659 |
+
self.builtin_epilogue_functor_template = """
|
| 660 |
+
${epilogue_functor}<
|
| 661 |
+
${element_c},
|
| 662 |
+
${epilogue_vector_length},
|
| 663 |
+
${element_accumulator},
|
| 664 |
+
${element_epilogue}
|
| 665 |
+
>
|
| 666 |
+
"""
|
| 667 |
+
self.gemm_template = """
|
| 668 |
+
// Gemm operator ${operation_name}
|
| 669 |
+
using ${operation_name}_base =
|
| 670 |
+
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
| 671 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
|
| 672 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
|
| 673 |
+
${element_c}, ${layout_c},
|
| 674 |
+
${element_accumulator},
|
| 675 |
+
${opcode_class},
|
| 676 |
+
${arch},
|
| 677 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 678 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 679 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 680 |
+
${epilogue_functor},
|
| 681 |
+
${swizzling_functor},
|
| 682 |
+
${stages},
|
| 683 |
+
${math_operation}
|
| 684 |
+
>::GemmKernel;
|
| 685 |
+
|
| 686 |
+
// Define named type
|
| 687 |
+
struct ${operation_name}${operation_suffix} :
|
| 688 |
+
public ${operation_name}_base { };
|
| 689 |
+
"""
|
| 690 |
+
self.gemm_template_interleaved = """
|
| 691 |
+
// Gemm operator ${operation_name}
|
| 692 |
+
using ${operation_name}_base =
|
| 693 |
+
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
| 694 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 695 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 696 |
+
${element_c}, ${layout_c},
|
| 697 |
+
${element_accumulator},
|
| 698 |
+
${opcode_class},
|
| 699 |
+
${arch},
|
| 700 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 701 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 702 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 703 |
+
${epilogue_functor},
|
| 704 |
+
${swizzling_functor},
|
| 705 |
+
${stages},
|
| 706 |
+
${math_operation}
|
| 707 |
+
>::GemmKernel;
|
| 708 |
+
|
| 709 |
+
// Define named type
|
| 710 |
+
struct ${operation_name}${operation_suffix} :
|
| 711 |
+
public ${operation_name}_base { };
|
| 712 |
+
"""
|
| 713 |
+
|
| 714 |
+
#
|
| 715 |
+
def instance_template(self):
|
| 716 |
+
return """
|
| 717 |
+
${compile_guard_start}
|
| 718 |
+
manifest.append(new ${gemm_kind}<
|
| 719 |
+
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
| 720 |
+
>("${operation_name}"));
|
| 721 |
+
${compile_guard_end}
|
| 722 |
+
"""
|
| 723 |
+
|
| 724 |
+
#
|
| 725 |
+
def emit(self, operation):
|
| 726 |
+
|
| 727 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 728 |
+
warp_count = operation.tile_description.warp_count
|
| 729 |
+
|
| 730 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 731 |
+
|
| 732 |
+
transpose_layouts = {
|
| 733 |
+
LayoutType.ColumnMajor: LayoutType.RowMajor,
|
| 734 |
+
LayoutType.RowMajor: LayoutType.ColumnMajor
|
| 735 |
+
}
|
| 736 |
+
|
| 737 |
+
if operation.A.layout in transpose_layouts.keys() and \
|
| 738 |
+
operation.B.layout in transpose_layouts.keys() and \
|
| 739 |
+
operation.C.layout in transpose_layouts.keys():
|
| 740 |
+
|
| 741 |
+
instance_layout_A = transpose_layouts[operation.A.layout]
|
| 742 |
+
instance_layout_B = transpose_layouts[operation.B.layout]
|
| 743 |
+
instance_layout_C = transpose_layouts[operation.C.layout]
|
| 744 |
+
|
| 745 |
+
gemm_template = self.gemm_template
|
| 746 |
+
else:
|
| 747 |
+
instance_layout_A, instance_layout_B, instance_layout_C = \
|
| 748 |
+
(operation.A.layout, operation.B.layout, operation.C.layout)
|
| 749 |
+
|
| 750 |
+
gemm_template = self.gemm_template_interleaved
|
| 751 |
+
#
|
| 752 |
+
|
| 753 |
+
# Support built-in epilogue functors or user-defined functions
|
| 754 |
+
if isinstance(operation.epilogue_functor, enum.Enum):
|
| 755 |
+
|
| 756 |
+
epilogue_vector_length = \
|
| 757 |
+
min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
|
| 758 |
+
|
| 759 |
+
values = {
|
| 760 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 761 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 762 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 763 |
+
}
|
| 764 |
+
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
| 765 |
+
else:
|
| 766 |
+
epilogue_functor = self.epilogue_functor.emit_declaration()
|
| 767 |
+
#
|
| 768 |
+
|
| 769 |
+
values = {
|
| 770 |
+
'operation_name': operation.procedural_name(),
|
| 771 |
+
'operation_suffix': self.operation_suffix,
|
| 772 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 773 |
+
'layout_a': LayoutTag[instance_layout_A],
|
| 774 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 775 |
+
'layout_b': LayoutTag[instance_layout_B],
|
| 776 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 777 |
+
'layout_c': LayoutTag[instance_layout_C],
|
| 778 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 779 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 780 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 781 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 782 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 783 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 784 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 785 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 786 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 787 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 788 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 789 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 790 |
+
'epilogue_functor': epilogue_functor,
|
| 791 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 792 |
+
'stages': str(operation.tile_description.stages),
|
| 793 |
+
'align_a': str(operation.A.alignment),
|
| 794 |
+
'align_b': str(operation.B.alignment),
|
| 795 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 796 |
+
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
| 797 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
return SubstituteTemplate(gemm_template, values)
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
###################################################################################################
|
| 804 |
+
|
| 805 |
+
class EmitGemmUniversal3xInstance:
|
| 806 |
+
''' Responsible for emitting a CUTLASS 3.x template definition'''
|
| 807 |
+
|
| 808 |
+
def __init__(self, operation_suffix = ''):
|
| 809 |
+
self.operation_suffix = operation_suffix
|
| 810 |
+
self.includes = [
|
| 811 |
+
"cutlass/cutlass.h",
|
| 812 |
+
"cutlass/gemm/gemm.h",
|
| 813 |
+
"cutlass/numeric_types.h",
|
| 814 |
+
"cutlass/gemm/kernel/gemm_universal.hpp",
|
| 815 |
+
"cutlass/gemm/collective/collective_builder.hpp",
|
| 816 |
+
"cutlass/epilogue/collective/collective_builder.hpp",
|
| 817 |
+
"cutlass/detail/blockwise_scale_layout.hpp",
|
| 818 |
+
]
|
| 819 |
+
self.builtin_epilogue_functor_template = \
|
| 820 |
+
"""${epilogue_functor}<
|
| 821 |
+
${element_d},
|
| 822 |
+
${element_epilogue},
|
| 823 |
+
${element_c},
|
| 824 |
+
${element_epilogue}
|
| 825 |
+
>"""
|
| 826 |
+
|
| 827 |
+
self.gemm_template = """
|
| 828 |
+
|
| 829 |
+
using ${operation_name}_epilogue =
|
| 830 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 831 |
+
${arch}, ${opcode_class_epi},
|
| 832 |
+
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
| 833 |
+
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
| 834 |
+
${epi_tile_mn},
|
| 835 |
+
${element_accumulator}, ${element_epilogue},
|
| 836 |
+
${element_c}, ${layout_c}, ${align_c},
|
| 837 |
+
${element_d}, ${layout_d}, ${align_d},
|
| 838 |
+
${epilogue_schedule},
|
| 839 |
+
${epilogue_functor}
|
| 840 |
+
>::CollectiveOp;
|
| 841 |
+
|
| 842 |
+
${mixed_dtype_prepare_code}
|
| 843 |
+
${blockwise_prepare_code}
|
| 844 |
+
|
| 845 |
+
using ${operation_name}_mainloop =
|
| 846 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
| 847 |
+
${arch}, ${opcode_class_main},
|
| 848 |
+
${element_a}, ${layout_a}, ${align_a},
|
| 849 |
+
${element_b}, ${layout_b}, ${align_b},
|
| 850 |
+
${element_accumulator},
|
| 851 |
+
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
| 852 |
+
cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
|
| 853 |
+
${stages},
|
| 854 |
+
${kernel_schedule}
|
| 855 |
+
>::CollectiveOp;
|
| 856 |
+
|
| 857 |
+
// Gemm operator ${operation_name}
|
| 858 |
+
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
| 859 |
+
${problem_shape},
|
| 860 |
+
${operation_name}_mainloop,
|
| 861 |
+
${operation_name}_epilogue,
|
| 862 |
+
${tile_scheduler}>;
|
| 863 |
+
|
| 864 |
+
// Define named type
|
| 865 |
+
struct ${operation_name} :
|
| 866 |
+
public ${operation_name}_base { };
|
| 867 |
+
|
| 868 |
+
"""
|
| 869 |
+
#
|
| 870 |
+
def instance_template(self):
|
| 871 |
+
return """
|
| 872 |
+
${compile_guard_start}
|
| 873 |
+
{
|
| 874 |
+
using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
|
| 875 |
+
manifest.append(
|
| 876 |
+
new ${gemm_kind}<GemmKernel>("${operation_name}"));
|
| 877 |
+
}
|
| 878 |
+
${compile_guard_end}
|
| 879 |
+
"""
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
def emit_block_scale_epilogue_functor(self, operation):
|
| 883 |
+
block_scaled_template = """
|
| 884 |
+
${epilogue_functor}<
|
| 885 |
+
${epi_vs},
|
| 886 |
+
${element_d},
|
| 887 |
+
${element_accumulator},
|
| 888 |
+
${element_sfd},
|
| 889 |
+
${layout_sfd},
|
| 890 |
+
${element_c},
|
| 891 |
+
${element_scalar}
|
| 892 |
+
>
|
| 893 |
+
"""
|
| 894 |
+
block_scaled_values = {
|
| 895 |
+
'epi_vs' : str(operation.ScaleFactorVectorSize),
|
| 896 |
+
'element_d': str(DataTypeTag[operation.D.element]),
|
| 897 |
+
'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]),
|
| 898 |
+
'layout_sfd': LayoutTag[operation.ScaleFactorD.layout],
|
| 899 |
+
'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor],
|
| 900 |
+
'element_accumulator': str(DataTypeTag[operation.accumulator_type()]),
|
| 901 |
+
'element_scalar': str(DataTypeTag[operation.accumulator_type()]),
|
| 902 |
+
'element_c': str(DataTypeTag[operation.C.element]),
|
| 903 |
+
}
|
| 904 |
+
return SubstituteTemplate(block_scaled_template, block_scaled_values)
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
@staticmethod
|
| 908 |
+
def pointerize_if_grouped(operation, layout):
|
| 909 |
+
return layout if not is_grouped(operation.gemm_kind) else layout + "* "
|
| 910 |
+
|
| 911 |
+
@staticmethod
|
| 912 |
+
def transform_layout_A_if_blockwise(operation, layout):
|
| 913 |
+
layout_sfa = f"{operation.procedural_name()}_LayoutSFA"
|
| 914 |
+
layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* "
|
| 915 |
+
return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfa}>"
|
| 916 |
+
|
| 917 |
+
@staticmethod
|
| 918 |
+
def transform_layout_B_if_blockwise(operation, layout):
|
| 919 |
+
layout_sfb = f"{operation.procedural_name()}_LayoutSFB"
|
| 920 |
+
layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* "
|
| 921 |
+
return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>"
|
| 922 |
+
|
| 923 |
+
@staticmethod
|
| 924 |
+
def problem_shape(operation):
|
| 925 |
+
gemm_shape_type = "cute::Shape<int,int,int,int>"
|
| 926 |
+
grouped_gemm_shape_type = "cute::Shape<int,int,int>"
|
| 927 |
+
grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
|
| 928 |
+
|
| 929 |
+
return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type
|
| 930 |
+
|
| 931 |
+
def emit(self, operation):
|
| 932 |
+
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
|
| 933 |
+
_LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name())
|
| 934 |
+
_LOGGER.debug("*** tile_shape: " + str(operation.tile_description.tile_shape))
|
| 935 |
+
_LOGGER.debug("*** warp_count: " + str(operation.tile_description.warp_count))
|
| 936 |
+
|
| 937 |
+
opcode_class_main = operation.tile_description.math_instruction.opcode_class
|
| 938 |
+
opcode_class_epi = opcode_class_main
|
| 939 |
+
|
| 940 |
+
tile_shape = operation.tile_description.tile_shape
|
| 941 |
+
instruction_shape = operation.tile_description.math_instruction.instruction_shape
|
| 942 |
+
cluster_m = operation.tile_description.cluster_shape[0]
|
| 943 |
+
cluster_n = operation.tile_description.cluster_shape[1]
|
| 944 |
+
cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
|
| 945 |
+
tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
|
| 946 |
+
|
| 947 |
+
# stage count set to zero indicates builder automatic stage selection
|
| 948 |
+
if operation.tile_description.stages > 0:
|
| 949 |
+
stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
|
| 950 |
+
elif opcode_class_main == OpcodeClass.SparseTensorOp and operation.arch == 100:
|
| 951 |
+
stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveoutEpi<{str(operation.procedural_name())}_epilogue>"
|
| 952 |
+
else:
|
| 953 |
+
stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"
|
| 954 |
+
|
| 955 |
+
epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
|
| 956 |
+
|
| 957 |
+
instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \
|
| 958 |
+
(operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout)
|
| 959 |
+
|
| 960 |
+
# 3.0 profiler integration only supports trivial epilogues for now
|
| 961 |
+
epilogue_vector_length = 1
|
| 962 |
+
|
| 963 |
+
# Support built-in epilogue functors or user-defined functions
|
| 964 |
+
if isinstance(operation.epilogue_functor, enum.Enum):
|
| 965 |
+
values = {
|
| 966 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 967 |
+
'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
|
| 968 |
+
}
|
| 969 |
+
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
| 970 |
+
|
| 971 |
+
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
|
| 972 |
+
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
else:
|
| 976 |
+
epilogue_functor = self.epilogue_functor.emit_declaration()
|
| 977 |
+
|
| 978 |
+
if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
|
| 979 |
+
epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
|
| 980 |
+
|
| 981 |
+
#
|
| 982 |
+
# Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
|
| 983 |
+
element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
|
| 984 |
+
element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
|
| 985 |
+
epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
|
| 986 |
+
|
| 987 |
+
if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
|
| 988 |
+
grouped = is_grouped(operation.gemm_kind)
|
| 989 |
+
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
|
| 990 |
+
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
| 991 |
+
if is_tma_epilogue(operation.epilogue_schedule):
|
| 992 |
+
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
|
| 993 |
+
if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
|
| 994 |
+
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
| 995 |
+
if is_tma_epilogue(operation.epilogue_schedule):
|
| 996 |
+
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
| 997 |
+
# SM103 FP4 Ultra
|
| 998 |
+
is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped),
|
| 999 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped),
|
| 1000 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped),
|
| 1001 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped),
|
| 1002 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped),
|
| 1003 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped)
|
| 1004 |
+
]
|
| 1005 |
+
is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped),
|
| 1006 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped),
|
| 1007 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped),
|
| 1008 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped),
|
| 1009 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped),
|
| 1010 |
+
to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped)
|
| 1011 |
+
]
|
| 1012 |
+
if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule:
|
| 1013 |
+
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
| 1014 |
+
if is_tma_epilogue(operation.epilogue_schedule):
|
| 1015 |
+
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
|
| 1016 |
+
if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule:
|
| 1017 |
+
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
| 1018 |
+
if is_tma_epilogue(operation.epilogue_schedule):
|
| 1019 |
+
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
| 1020 |
+
|
| 1021 |
+
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
| 1022 |
+
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
| 1023 |
+
|
| 1024 |
+
alignment_c = get_tma_alignment(operation.C.element) \
|
| 1025 |
+
if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \
|
| 1026 |
+
else operation.C.alignment
|
| 1027 |
+
alignment_d = get_tma_alignment(operation.D.element) \
|
| 1028 |
+
if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \
|
| 1029 |
+
else operation.D.alignment
|
| 1030 |
+
|
| 1031 |
+
operation_name_str = operation.procedural_name()
|
| 1032 |
+
layout_a_str = LayoutTag[instance_layout_A]
|
| 1033 |
+
layout_b_str = LayoutTag[instance_layout_B]
|
| 1034 |
+
mixed_dtype_prepare_code = ""
|
| 1035 |
+
if operation.mixed_input_mode != None:
|
| 1036 |
+
A_dtype = operation.A.element
|
| 1037 |
+
B_dtype = operation.B.element
|
| 1038 |
+
A_dtype_bits = DataTypeSize[A_dtype]
|
| 1039 |
+
B_dtype_bits = DataTypeSize[B_dtype]
|
| 1040 |
+
is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
|
| 1041 |
+
if is_A_dtype_narrow:
|
| 1042 |
+
narrow_dtype, wide_dtype = (A_dtype, B_dtype)
|
| 1043 |
+
narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
|
| 1044 |
+
else:
|
| 1045 |
+
narrow_dtype, wide_dtype = (B_dtype, A_dtype)
|
| 1046 |
+
narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
|
| 1047 |
+
|
| 1048 |
+
narrow_tag = DataTypeTag[narrow_dtype]
|
| 1049 |
+
wide_tag = DataTypeTag[wide_dtype]
|
| 1050 |
+
scale_tag = DataTypeTag[wide_dtype]
|
| 1051 |
+
zero_tag = DataTypeTag[wide_dtype]
|
| 1052 |
+
|
| 1053 |
+
do_shuffle = False
|
| 1054 |
+
value_shuffle_str = ""
|
| 1055 |
+
if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
|
| 1056 |
+
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>"
|
| 1057 |
+
do_shuffle = True
|
| 1058 |
+
if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
|
| 1059 |
+
value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>"
|
| 1060 |
+
do_shuffle = True
|
| 1061 |
+
do_shuffle = operation.mixed_input_shuffle and do_shuffle
|
| 1062 |
+
|
| 1063 |
+
if do_shuffle:
|
| 1064 |
+
if is_A_dtype_narrow:
|
| 1065 |
+
stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
|
| 1066 |
+
layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
|
| 1067 |
+
else:
|
| 1068 |
+
stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
|
| 1069 |
+
layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
|
| 1070 |
+
# The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and
|
| 1071 |
+
# layout_{a, b}_str are to prevent errors in Windows platform unity build
|
| 1072 |
+
mixed_dtype_prepare_code = f"""
|
| 1073 |
+
using {operation_name_str}_StrideNarrow = {stride_narrow_str};
|
| 1074 |
+
using {operation_name_str}_ValueShuffle = {value_shuffle_str};
|
| 1075 |
+
static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
|
| 1076 |
+
using {operation_name_str}_MmaAtomShape = cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
|
| 1077 |
+
using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>());
|
| 1078 |
+
using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
|
| 1079 |
+
"""
|
| 1080 |
+
|
| 1081 |
+
mixed_input_modes_to_element = {
|
| 1082 |
+
MixedInputMode.ConvertOnly: narrow_tag,
|
| 1083 |
+
MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
|
| 1084 |
+
MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>"
|
| 1085 |
+
}
|
| 1086 |
+
narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag)
|
| 1087 |
+
|
| 1088 |
+
if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
|
| 1089 |
+
narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
|
| 1090 |
+
|
| 1091 |
+
if is_A_dtype_narrow:
|
| 1092 |
+
element_a = narrow_element
|
| 1093 |
+
else:
|
| 1094 |
+
element_b = narrow_element
|
| 1095 |
+
|
| 1096 |
+
blockwise_prepare_code = ""
|
| 1097 |
+
if is_blockwise(operation.gemm_kind):
|
| 1098 |
+
sfm_vec_size = operation.ScaleFactorMVecSize
|
| 1099 |
+
sfn_vec_size = operation.ScaleFactorNVecSize
|
| 1100 |
+
sfk_vec_size = operation.ScaleFactorKVecSize
|
| 1101 |
+
blockwise_prepare_code = f"""
|
| 1102 |
+
using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>;
|
| 1103 |
+
using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA());
|
| 1104 |
+
using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB());
|
| 1105 |
+
"""
|
| 1106 |
+
|
| 1107 |
+
values = {
|
| 1108 |
+
'operation_name': operation_name_str,
|
| 1109 |
+
'operation_suffix': self.operation_suffix,
|
| 1110 |
+
'problem_shape': self.problem_shape(operation),
|
| 1111 |
+
'element_a': element_a,
|
| 1112 |
+
'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)),
|
| 1113 |
+
'element_b': element_b,
|
| 1114 |
+
'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)),
|
| 1115 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 1116 |
+
'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
|
| 1117 |
+
'element_d': DataTypeTag[operation.D.element],
|
| 1118 |
+
'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]),
|
| 1119 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 1120 |
+
'opcode_class_main': OpcodeClassTag[opcode_class_main],
|
| 1121 |
+
'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
|
| 1122 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 1123 |
+
'tile_shape_m': str(tile_shape_m),
|
| 1124 |
+
'tile_shape_n': str(tile_shape_n),
|
| 1125 |
+
'tile_shape_k': str(tile_shape_k),
|
| 1126 |
+
'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int",
|
| 1127 |
+
'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int",
|
| 1128 |
+
'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int",
|
| 1129 |
+
'instruction_shape_m': str(instruction_shape[0]),
|
| 1130 |
+
'instruction_shape_n': str(instruction_shape[1]),
|
| 1131 |
+
'instruction_shape_k': str(instruction_shape[2]),
|
| 1132 |
+
'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]),
|
| 1133 |
+
'epilogue_schedule' : str(epilogue_schedule_type),
|
| 1134 |
+
'epi_tile_mn' : epi_tile_mn,
|
| 1135 |
+
'epilogue_functor': epilogue_functor,
|
| 1136 |
+
'stages': stage_count_string,
|
| 1137 |
+
'align_a': str(operation.A.alignment),
|
| 1138 |
+
'align_b': str(operation.B.alignment),
|
| 1139 |
+
'align_c': str(alignment_c),
|
| 1140 |
+
'align_d': str(alignment_d),
|
| 1141 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 1142 |
+
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
| 1143 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 1144 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 1145 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 1146 |
+
'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
|
| 1147 |
+
'mixed_dtype_prepare_code': mixed_dtype_prepare_code,
|
| 1148 |
+
'blockwise_prepare_code' : blockwise_prepare_code
|
| 1149 |
+
}
|
| 1150 |
+
|
| 1151 |
+
return SubstituteTemplate(self.gemm_template, values)
|
| 1152 |
+
|
| 1153 |
+
###################################################################################################
|
| 1154 |
+
|
| 1155 |
+
#
|
| 1156 |
+
class EmitGemmPlanarComplexInstance:
|
| 1157 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 1158 |
+
|
| 1159 |
+
def __init__(self, operation_suffix = ''):
|
| 1160 |
+
self.operation_suffix = operation_suffix
|
| 1161 |
+
self.includes = []
|
| 1162 |
+
self.template = """
|
| 1163 |
+
// Gemm operator ${operation_name}
|
| 1164 |
+
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
| 1165 |
+
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
|
| 1166 |
+
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
|
| 1167 |
+
${element_c}, cutlass::layout::RowMajor,
|
| 1168 |
+
${element_accumulator},
|
| 1169 |
+
${opcode_class},
|
| 1170 |
+
${arch},
|
| 1171 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1172 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1173 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1174 |
+
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
|
| 1175 |
+
${element_c},
|
| 1176 |
+
${alignment_c},
|
| 1177 |
+
${element_accumulator},
|
| 1178 |
+
${element_epilogue}
|
| 1179 |
+
>,
|
| 1180 |
+
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
| 1181 |
+
${stages},
|
| 1182 |
+
${math_operator}
|
| 1183 |
+
>::GemmKernel;
|
| 1184 |
+
|
| 1185 |
+
struct ${operation_name} :
|
| 1186 |
+
public Operation_${operation_name} { };
|
| 1187 |
+
"""
|
| 1188 |
+
|
| 1189 |
+
#
|
| 1190 |
+
def instance_template(self):
|
| 1191 |
+
return """
|
| 1192 |
+
${compile_guard_start}
|
| 1193 |
+
manifest.append(new ${gemm_kind}<
|
| 1194 |
+
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
| 1195 |
+
>("${operation_name}"));
|
| 1196 |
+
${compile_guard_end}
|
| 1197 |
+
"""
|
| 1198 |
+
|
| 1199 |
+
#
|
| 1200 |
+
def emit(self, operation):
|
| 1201 |
+
|
| 1202 |
+
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
| 1203 |
+
|
| 1204 |
+
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
|
| 1205 |
+
transposed_layout_A = TransposedLayout[operation.A.layout]
|
| 1206 |
+
transposed_layout_B = TransposedLayout[operation.B.layout]
|
| 1207 |
+
|
| 1208 |
+
values = {
|
| 1209 |
+
'operation_name': operation.procedural_name(),
|
| 1210 |
+
'element_a': DataTypeTag[operation.B.element],
|
| 1211 |
+
'layout_a': LayoutTag[transposed_layout_B],
|
| 1212 |
+
'transform_a': ComplexTransformTag[operation.B.complex_transform],
|
| 1213 |
+
'alignment_a': str(operation.B.alignment),
|
| 1214 |
+
'element_b': DataTypeTag[operation.A.element],
|
| 1215 |
+
'layout_b': LayoutTag[transposed_layout_A],
|
| 1216 |
+
'transform_b': ComplexTransformTag[operation.A.complex_transform],
|
| 1217 |
+
'alignment_b': str(operation.A.alignment),
|
| 1218 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 1219 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 1220 |
+
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
| 1221 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 1222 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 1223 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 1224 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 1225 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 1226 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 1227 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 1228 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 1229 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 1230 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 1231 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 1232 |
+
'alignment_c': str(operation.C.alignment),
|
| 1233 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 1234 |
+
'stages': str(operation.tile_description.stages),
|
| 1235 |
+
'math_operator': 'cutlass::arch::OpMultiplyAdd'
|
| 1236 |
+
}
|
| 1237 |
+
|
| 1238 |
+
return SubstituteTemplate(self.template, values)
|
| 1239 |
+
|
| 1240 |
+
###################################################################################################
|
| 1241 |
+
|
| 1242 |
+
#
|
| 1243 |
+
class EmitGemmPlanarComplexArrayInstance:
|
| 1244 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 1245 |
+
|
| 1246 |
+
def __init__(self, operation_suffix = ''):
|
| 1247 |
+
self.operation_suffix = operation_suffix
|
| 1248 |
+
self.includes = []
|
| 1249 |
+
self.template = """
|
| 1250 |
+
// Gemm operator ${operation_name}
|
| 1251 |
+
using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
|
| 1252 |
+
${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
|
| 1253 |
+
${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
|
| 1254 |
+
${element_c}, cutlass::layout::RowMajor,
|
| 1255 |
+
${element_accumulator},
|
| 1256 |
+
${opcode_class},
|
| 1257 |
+
${arch},
|
| 1258 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1259 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1260 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1261 |
+
cutlass::epilogue::thread::LinearCombinationPlanarComplex<
|
| 1262 |
+
${element_c},
|
| 1263 |
+
${alignment_c},
|
| 1264 |
+
${element_accumulator},
|
| 1265 |
+
${element_epilogue}
|
| 1266 |
+
>,
|
| 1267 |
+
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
| 1268 |
+
${stages},
|
| 1269 |
+
${math_operator}
|
| 1270 |
+
>::GemmArrayKernel;
|
| 1271 |
+
|
| 1272 |
+
struct ${operation_name} : public Operation_${operation_name} { };
|
| 1273 |
+
"""
|
| 1274 |
+
|
| 1275 |
+
#
|
| 1276 |
+
def instance_template(self):
|
| 1277 |
+
return """
|
| 1278 |
+
${compile_guard_start}
|
| 1279 |
+
manifest.append(new ${gemm_kind}<
|
| 1280 |
+
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
| 1281 |
+
>("${operation_name}"));
|
| 1282 |
+
${compile_guard_end}
|
| 1283 |
+
"""
|
| 1284 |
+
|
| 1285 |
+
#
|
| 1286 |
+
def emit(self, operation):
|
| 1287 |
+
|
| 1288 |
+
warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
|
| 1289 |
+
|
| 1290 |
+
# exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
|
| 1291 |
+
transposed_layout_A = TransposedLayout[operation.A.layout]
|
| 1292 |
+
transposed_layout_B = TransposedLayout[operation.B.layout]
|
| 1293 |
+
|
| 1294 |
+
values = {
|
| 1295 |
+
'operation_name': operation.procedural_name(),
|
| 1296 |
+
'element_a': DataTypeTag[operation.B.element],
|
| 1297 |
+
'layout_a': LayoutTag[transposed_layout_B],
|
| 1298 |
+
'transform_a': ComplexTransformTag[operation.B.complex_transform],
|
| 1299 |
+
'alignment_a': str(operation.B.alignment),
|
| 1300 |
+
'element_b': DataTypeTag[operation.A.element],
|
| 1301 |
+
'layout_b': LayoutTag[transposed_layout_A],
|
| 1302 |
+
'transform_b': ComplexTransformTag[operation.A.complex_transform],
|
| 1303 |
+
'alignment_b': str(operation.A.alignment),
|
| 1304 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 1305 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 1306 |
+
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
| 1307 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 1308 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 1309 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 1310 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 1311 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 1312 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 1313 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 1314 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 1315 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 1316 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 1317 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 1318 |
+
'alignment_c': str(operation.C.alignment),
|
| 1319 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 1320 |
+
'stages': str(operation.tile_description.stages),
|
| 1321 |
+
'math_operator': 'cutlass::arch::OpMultiplyAdd'
|
| 1322 |
+
}
|
| 1323 |
+
|
| 1324 |
+
return SubstituteTemplate(self.template, values)
|
| 1325 |
+
|
| 1326 |
+
###################################################################################################
|
| 1327 |
+
|
| 1328 |
+
#
|
| 1329 |
+
class EmitGemmGroupedInstance:
|
| 1330 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 1331 |
+
|
| 1332 |
+
def __init__(self, operation_suffix = ''):
|
| 1333 |
+
self.operation_suffix = operation_suffix
|
| 1334 |
+
self.includes = [
|
| 1335 |
+
"cutlass/cutlass.h",
|
| 1336 |
+
"cutlass/numeric_types.h",
|
| 1337 |
+
"cutlass/arch/arch.h",
|
| 1338 |
+
"cutlass/arch/mma.h",
|
| 1339 |
+
"cutlass/layout/matrix.h",
|
| 1340 |
+
"cutlass/gemm/device/gemm.h",
|
| 1341 |
+
"cutlass/gemm/kernel/gemm_grouped.h",
|
| 1342 |
+
"cutlass/gemm/kernel/default_gemm_grouped.h",
|
| 1343 |
+
"cutlass/gemm/device/gemm_grouped.h"
|
| 1344 |
+
]
|
| 1345 |
+
self.builtin_epilogue_functor_template = \
|
| 1346 |
+
"""${epilogue_functor}<
|
| 1347 |
+
${element_c},
|
| 1348 |
+
${epilogue_vector_length},
|
| 1349 |
+
${element_accumulator},
|
| 1350 |
+
${element_epilogue}
|
| 1351 |
+
>"""
|
| 1352 |
+
|
| 1353 |
+
self.gemm_template = """
|
| 1354 |
+
// Gemm operator ${operation_name}
|
| 1355 |
+
using ${operation_name}_base =
|
| 1356 |
+
typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
| 1357 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1358 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1359 |
+
${element_c}, ${layout_c},
|
| 1360 |
+
${element_accumulator},
|
| 1361 |
+
${opcode_class},
|
| 1362 |
+
${arch},
|
| 1363 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1364 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1365 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1366 |
+
${epilogue_functor},
|
| 1367 |
+
${swizzling_functor},
|
| 1368 |
+
${stages},
|
| 1369 |
+
${scheduler_mode},
|
| 1370 |
+
${math_operation}
|
| 1371 |
+
>::GemmKernel;
|
| 1372 |
+
|
| 1373 |
+
// Define named type
|
| 1374 |
+
struct ${operation_name}${operation_suffix} :
|
| 1375 |
+
public ${operation_name}_base { };
|
| 1376 |
+
"""
|
| 1377 |
+
|
| 1378 |
+
#
|
| 1379 |
+
def instance_template(self):
|
| 1380 |
+
return """
|
| 1381 |
+
${compile_guard_start}
|
| 1382 |
+
manifest.append(new ${gemm_kind}<
|
| 1383 |
+
cutlass::gemm::device::GemmGrouped<${operation_name}>
|
| 1384 |
+
>("${operation_name}"));
|
| 1385 |
+
${compile_guard_end}
|
| 1386 |
+
"""
|
| 1387 |
+
|
| 1388 |
+
#
|
| 1389 |
+
def emit(self, operation):
|
| 1390 |
+
|
| 1391 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 1392 |
+
warp_count = operation.tile_description.warp_count
|
| 1393 |
+
|
| 1394 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 1395 |
+
|
| 1396 |
+
transpose_layouts = {
|
| 1397 |
+
LayoutType.ColumnMajor: LayoutType.RowMajor,
|
| 1398 |
+
LayoutType.RowMajor: LayoutType.ColumnMajor
|
| 1399 |
+
}
|
| 1400 |
+
|
| 1401 |
+
instance_layout_A, instance_layout_B, instance_layout_C = \
|
| 1402 |
+
(operation.A.layout, operation.B.layout, operation.C.layout)
|
| 1403 |
+
#
|
| 1404 |
+
|
| 1405 |
+
# Support built-in epilogue functors or user-defined functions
|
| 1406 |
+
if isinstance(operation.epilogue_functor, enum.Enum):
|
| 1407 |
+
|
| 1408 |
+
epilogue_vector_length = \
|
| 1409 |
+
min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
|
| 1410 |
+
|
| 1411 |
+
values = {
|
| 1412 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 1413 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 1414 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 1415 |
+
}
|
| 1416 |
+
epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
|
| 1417 |
+
else:
|
| 1418 |
+
epilogue_functor = self.epilogue_functor.emit_declaration()
|
| 1419 |
+
#
|
| 1420 |
+
|
| 1421 |
+
values = {
|
| 1422 |
+
'operation_name': operation.procedural_name(),
|
| 1423 |
+
'operation_suffix': self.operation_suffix,
|
| 1424 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 1425 |
+
'layout_a': LayoutTag[instance_layout_A],
|
| 1426 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 1427 |
+
'layout_b': LayoutTag[instance_layout_B],
|
| 1428 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 1429 |
+
'layout_c': LayoutTag[instance_layout_C],
|
| 1430 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 1431 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 1432 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 1433 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 1434 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 1435 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 1436 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 1437 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 1438 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 1439 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 1440 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 1441 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 1442 |
+
'epilogue_functor': epilogue_functor,
|
| 1443 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 1444 |
+
'stages': str(operation.tile_description.stages),
|
| 1445 |
+
'align_a': str(operation.A.alignment),
|
| 1446 |
+
'align_b': str(operation.B.alignment),
|
| 1447 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 1448 |
+
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
| 1449 |
+
'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode],
|
| 1450 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
|
| 1451 |
+
}
|
| 1452 |
+
|
| 1453 |
+
return SubstituteTemplate(self.gemm_template, values)
|
| 1454 |
+
|
| 1455 |
+
###################################################################################################
|
| 1456 |
+
#
|
| 1457 |
+
# Emitters functions for all targets
|
| 1458 |
+
#
|
| 1459 |
+
###################################################################################################
|
| 1460 |
+
|
| 1461 |
+
class EmitGemmConfigurationLibrary:
|
| 1462 |
+
def __init__(self, operation_path, configuration_name):
|
| 1463 |
+
self.configuration_name = configuration_name
|
| 1464 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
|
| 1465 |
+
|
| 1466 |
+
self.instance_emitter = {
|
| 1467 |
+
GemmKind.Gemm: EmitGemmInstance,
|
| 1468 |
+
GemmKind.Sparse: EmitSparseGemmInstance,
|
| 1469 |
+
GemmKind.Universal: EmitGemmUniversalInstance,
|
| 1470 |
+
GemmKind.Universal3x: EmitGemmUniversal3xInstance,
|
| 1471 |
+
GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance,
|
| 1472 |
+
GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
| 1473 |
+
GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
|
| 1474 |
+
GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
|
| 1475 |
+
GemmKind.Grouped: EmitGemmGroupedInstance,
|
| 1476 |
+
GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
|
| 1477 |
+
GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
|
| 1478 |
+
GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance,
|
| 1479 |
+
GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance,
|
| 1480 |
+
}
|
| 1481 |
+
|
| 1482 |
+
self.gemm_kind_wrappers = {
|
| 1483 |
+
GemmKind.Gemm: 'GemmOperation',
|
| 1484 |
+
GemmKind.Sparse: 'GemmSparseOperation',
|
| 1485 |
+
GemmKind.Universal: 'GemmUniversalOperation',
|
| 1486 |
+
GemmKind.Universal3x: 'GemmUniversal3xOperation',
|
| 1487 |
+
GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation',
|
| 1488 |
+
GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
|
| 1489 |
+
GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
|
| 1490 |
+
GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
|
| 1491 |
+
GemmKind.Grouped: 'GemmGroupedOperation',
|
| 1492 |
+
GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
|
| 1493 |
+
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
|
| 1494 |
+
GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation',
|
| 1495 |
+
GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation',
|
| 1496 |
+
}
|
| 1497 |
+
|
| 1498 |
+
self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
|
| 1499 |
+
|
| 1500 |
+
self.separator = """
|
| 1501 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1502 |
+
|
| 1503 |
+
"""
|
| 1504 |
+
|
| 1505 |
+
self.header_template = """
|
| 1506 |
+
/*
|
| 1507 |
+
Generated by gemm_operation.py - Do not edit.
|
| 1508 |
+
*/
|
| 1509 |
+
"""
|
| 1510 |
+
|
| 1511 |
+
self.initialize_function_template = """
|
| 1512 |
+
|
| 1513 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1514 |
+
|
| 1515 |
+
namespace cutlass {
|
| 1516 |
+
namespace library {
|
| 1517 |
+
|
| 1518 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1519 |
+
|
| 1520 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 1521 |
+
|
| 1522 |
+
"""
|
| 1523 |
+
self.epilogue_template = """
|
| 1524 |
+
|
| 1525 |
+
}
|
| 1526 |
+
|
| 1527 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1528 |
+
|
| 1529 |
+
} // namespace library
|
| 1530 |
+
} // namespace cutlass
|
| 1531 |
+
|
| 1532 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1533 |
+
|
| 1534 |
+
"""
|
| 1535 |
+
|
| 1536 |
+
def __enter__(self):
|
| 1537 |
+
_LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__")
|
| 1538 |
+
_LOGGER.debug("*** configuration_path (file to write): " +
|
| 1539 |
+
str(self.configuration_path))
|
| 1540 |
+
|
| 1541 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 1542 |
+
self.configuration_file.write(self.header_template)
|
| 1543 |
+
self.configuration_file.write(self.separator)
|
| 1544 |
+
|
| 1545 |
+
self.includes = collections.OrderedDict([
|
| 1546 |
+
("cutlass/cutlass.h", None),
|
| 1547 |
+
("cutlass/library/library.h", None),
|
| 1548 |
+
("cutlass/library/manifest.h", None),
|
| 1549 |
+
("library_internal.h", None),
|
| 1550 |
+
("gemm_operation.h", None),
|
| 1551 |
+
("gemm_operation_3x.hpp", None),
|
| 1552 |
+
("grouped_gemm_operation_3x.hpp", None),
|
| 1553 |
+
("sparse_gemm_operation_3x.hpp", None),
|
| 1554 |
+
("block_scaled_gemm_operation_3x.hpp", None),
|
| 1555 |
+
("blockwise_gemm_operation_3x.hpp", None),
|
| 1556 |
+
("cutlass/arch/wmma.h", None),
|
| 1557 |
+
("cutlass/numeric_types.h", None)
|
| 1558 |
+
])
|
| 1559 |
+
self.instance_definitions = []
|
| 1560 |
+
self.instance_wrappers = []
|
| 1561 |
+
|
| 1562 |
+
self.operations = []
|
| 1563 |
+
return self
|
| 1564 |
+
|
| 1565 |
+
def emit(self, operation):
|
| 1566 |
+
_LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
|
| 1567 |
+
_LOGGER.debug("*** operation.gemm_kind: " + str(operation.gemm_kind))
|
| 1568 |
+
|
| 1569 |
+
emitter = self.instance_emitter[operation.gemm_kind]()
|
| 1570 |
+
|
| 1571 |
+
for incl in emitter.includes:
|
| 1572 |
+
self.includes[incl] = None
|
| 1573 |
+
|
| 1574 |
+
self.operations.append(operation)
|
| 1575 |
+
|
| 1576 |
+
self.instance_definitions.append(emitter.emit(operation))
|
| 1577 |
+
|
| 1578 |
+
self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), {
|
| 1579 |
+
'configuration_name': self.configuration_name,
|
| 1580 |
+
'operation_name': operation.procedural_name(),
|
| 1581 |
+
'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
|
| 1582 |
+
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
|
| 1583 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
|
| 1584 |
+
'compile_guard_end': "#endif" \
|
| 1585 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
| 1586 |
+
}))
|
| 1587 |
+
|
| 1588 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 1589 |
+
|
| 1590 |
+
# Write includes
|
| 1591 |
+
for incl, _ in self.includes.items():
|
| 1592 |
+
include_statement = "#include \"%s\"\n" % incl
|
| 1593 |
+
self.configuration_file.write(include_statement)
|
| 1594 |
+
|
| 1595 |
+
self.configuration_file.write(self.separator)
|
| 1596 |
+
|
| 1597 |
+
# Write instance definitions in top-level namespace
|
| 1598 |
+
for instance_definition in self.instance_definitions:
|
| 1599 |
+
self.configuration_file.write(instance_definition)
|
| 1600 |
+
|
| 1601 |
+
# Add wrapper objects within initialize() function
|
| 1602 |
+
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
| 1603 |
+
'configuration_name': self.configuration_name
|
| 1604 |
+
}))
|
| 1605 |
+
|
| 1606 |
+
for instance_wrapper in self.instance_wrappers:
|
| 1607 |
+
self.configuration_file.write(instance_wrapper)
|
| 1608 |
+
|
| 1609 |
+
self.configuration_file.write(self.epilogue_template)
|
| 1610 |
+
self.configuration_file.close()
|
| 1611 |
+
|
| 1612 |
+
###################################################################################################
|
| 1613 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for selecting CUTLASS library kernels based on problem description
|
| 35 |
+
"""
|
| 36 |
+
import json
|
| 37 |
+
import csv
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import builtins
|
| 41 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 42 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 43 |
+
from cutlass_library.library import *
|
| 44 |
+
from cutlass_library.generator import *
|
| 45 |
+
from cutlass_library.heuristics_provider import *
|
| 46 |
+
except ImportError:
|
| 47 |
+
from library import *
|
| 48 |
+
from generator import *
|
| 49 |
+
from heuristics_provider import *
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
from .sm90_utils import (
|
| 53 |
+
get_valid_schedules,
|
| 54 |
+
generate_data_types_from_math_instruction,
|
| 55 |
+
fix_alignments,
|
| 56 |
+
)
|
| 57 |
+
except ImportError:
|
| 58 |
+
from sm90_utils import (
|
| 59 |
+
get_valid_schedules,
|
| 60 |
+
generate_data_types_from_math_instruction,
|
| 61 |
+
fix_alignments,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
_LOGGER = logging.getLogger(__name__)
|
| 65 |
+
|
| 66 |
+
dtype_map = {v: k for k, v in DataTypeNames.items()}
|
| 67 |
+
|
| 68 |
+
def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
|
| 69 |
+
"""
|
| 70 |
+
Utilitiy function to write heuristics results to a json file for debug
|
| 71 |
+
|
| 72 |
+
args:
|
| 73 |
+
problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict
|
| 74 |
+
outfile_path: Outfile path
|
| 75 |
+
|
| 76 |
+
returns:
|
| 77 |
+
None
|
| 78 |
+
"""
|
| 79 |
+
pc_copy = problems_with_configs.copy()
|
| 80 |
+
for p in pc_copy:
|
| 81 |
+
for k, v in p.items():
|
| 82 |
+
if isinstance(v, DataType):
|
| 83 |
+
p[k] = DataTypeNames[v]
|
| 84 |
+
elif isinstance(v, LayoutType):
|
| 85 |
+
p[k] = ShortLayoutTypeNames[v]
|
| 86 |
+
configs = p['configs']
|
| 87 |
+
for c in configs:
|
| 88 |
+
for k, v in c.items():
|
| 89 |
+
if isinstance(v, DataType):
|
| 90 |
+
c[k] = DataTypeNames[v]
|
| 91 |
+
elif isinstance(v, LayoutType):
|
| 92 |
+
c[k] = ShortLayoutTypeNames[v]
|
| 93 |
+
with open(outfile_path, 'w') as f:
|
| 94 |
+
json.dump(pc_copy, f, indent=2)
|
| 95 |
+
|
| 96 |
+
def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
|
| 97 |
+
"""
|
| 98 |
+
Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.
|
| 99 |
+
|
| 100 |
+
args:
|
| 101 |
+
m, n, k: GEMM dimensions
|
| 102 |
+
batch_count: batch count
|
| 103 |
+
layouts: tuple of layouts of type LayoutType
|
| 104 |
+
use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
|
| 105 |
+
count: Number of configs to return
|
| 106 |
+
provider: Heuristics provider to use
|
| 107 |
+
|
| 108 |
+
returns:
|
| 109 |
+
A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
|
| 110 |
+
- 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size
|
| 111 |
+
- 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
|
| 112 |
+
- 'stages': kernel pipeline stage count
|
| 113 |
+
- 'cluster_m', 'cluster_n', 'cluster_k': cluster size
|
| 114 |
+
- 'layout_a', 'layout_b': input tensor layouts of type LayoutType
|
| 115 |
+
- 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
|
| 116 |
+
- 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
|
| 117 |
+
- 'swizzle_size' : suggested threadblock swizzle
|
| 118 |
+
- 'split_k_slices': number of partitions of the k dimension for splitK
|
| 119 |
+
- 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
|
| 120 |
+
"""
|
| 121 |
+
if provider is None:
|
| 122 |
+
provider = MatmulHeuristics()
|
| 123 |
+
return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
|
| 124 |
+
|
| 125 |
+
def get_gemm_configs(problems, provider=None, count=1):
|
| 126 |
+
"""
|
| 127 |
+
Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.
|
| 128 |
+
|
| 129 |
+
args:
|
| 130 |
+
problems: List of dictionaries describing GEMM problems with the following keys:
|
| 131 |
+
- 'm', 'n', 'k': Matrix dimensions (required)
|
| 132 |
+
- 'dtype_a': Data type of matrix A (required)
|
| 133 |
+
- 'dtype_b': Data type of matrix B (required)
|
| 134 |
+
- 'dtype_c': Data type of matrix C (default: None)
|
| 135 |
+
- 'dtype_d': Data type of matrix D (required)
|
| 136 |
+
- 'dtype_acc': Compute data type (default 'f32')
|
| 137 |
+
- 'layout': Operation layout (e.g. 'tnt')
|
| 138 |
+
- 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements)
|
| 139 |
+
- 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements)
|
| 140 |
+
- 'alpha': Scalar multiplier for A*B (default: 1.0)
|
| 141 |
+
- 'beta': Scalar multiplier for C (default: 0.0)
|
| 142 |
+
- 'batch_count': Number of GEMM operations in batch (default: 1)
|
| 143 |
+
- 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
|
| 144 |
+
provider: Heuristics provider to use
|
| 145 |
+
count: Number of configurations to return per problem (defualt: 1)
|
| 146 |
+
|
| 147 |
+
returns:
|
| 148 |
+
A copy of the input dictionary, with key `configs` added containing the selected gemm configs
|
| 149 |
+
"""
|
| 150 |
+
ret = []
|
| 151 |
+
|
| 152 |
+
for problem in problems:
|
| 153 |
+
problem = problem.copy()
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
m = problem['m']
|
| 157 |
+
n = problem['n']
|
| 158 |
+
k = problem['k']
|
| 159 |
+
dtype_a = problem['dtype_a']
|
| 160 |
+
dtype_b = problem['dtype_b']
|
| 161 |
+
dtype_d = problem['dtype_d']
|
| 162 |
+
layout = problem['layout']
|
| 163 |
+
except KeyError as e:
|
| 164 |
+
_LOGGER.error(f"Missing required parameter {e} for problem {problem}")
|
| 165 |
+
raise
|
| 166 |
+
|
| 167 |
+
operation = problem.get('operation', 'gemm')
|
| 168 |
+
batch_count = problem.get('batch_count', 1)
|
| 169 |
+
dtype_acc = problem.get('dtype_acc', 'f32')
|
| 170 |
+
dtype_c = problem.get('dtype_c', None)
|
| 171 |
+
alpha = problem.get('alpha', 1.0)
|
| 172 |
+
beta = problem.get('beta', 0.0)
|
| 173 |
+
use_fast_acc = problem.get('use_fast_acc', True)
|
| 174 |
+
|
| 175 |
+
if operation != OperationKindNames[OperationKind.Gemm]:
|
| 176 |
+
raise ValueError(f"Unsupported operation {operation}")
|
| 177 |
+
if not (len(layout) == 3 and all(c in "nt" for c in layout)):
|
| 178 |
+
raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
|
| 179 |
+
layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
|
| 183 |
+
dtypes = tuple(dtype_map[dt] for dt in dtype_list)
|
| 184 |
+
except KeyError as dt:
|
| 185 |
+
_LOGGER.error(f"Unsupported data type: {dt}")
|
| 186 |
+
raise
|
| 187 |
+
|
| 188 |
+
alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
|
| 189 |
+
alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])
|
| 190 |
+
|
| 191 |
+
configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider)
|
| 192 |
+
problem['configs'] = configs
|
| 193 |
+
|
| 194 |
+
ret.append(problem)
|
| 195 |
+
|
| 196 |
+
return ret
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
|
| 200 |
+
"""
|
| 201 |
+
Generate CUTLASS operations based on the list of configs provided by the heuristic provider
|
| 202 |
+
|
| 203 |
+
args:
|
| 204 |
+
manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
|
| 205 |
+
cuda_version: Cuda compiler version for generating cutlass operations
|
| 206 |
+
kernel_configs: list of configs generated by the heuristic
|
| 207 |
+
|
| 208 |
+
returns:
|
| 209 |
+
(configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
|
| 210 |
+
"""
|
| 211 |
+
min_cc = 100
|
| 212 |
+
max_cc = 101
|
| 213 |
+
if manifest is None:
|
| 214 |
+
# Use a dummy manifest so we can use existing CreateGemmOperator functions
|
| 215 |
+
manifest = Manifest()
|
| 216 |
+
|
| 217 |
+
configs = []
|
| 218 |
+
operations = []
|
| 219 |
+
for config in kernel_configs:
|
| 220 |
+
layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
|
| 221 |
+
element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
|
| 222 |
+
|
| 223 |
+
# nvMMH assumes 2sm instruction for !(cluster_m % 2)
|
| 224 |
+
is_2sm = config['cluster_m'] % 2 == 0
|
| 225 |
+
instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
|
| 226 |
+
math_instruction = MathInstruction(
|
| 227 |
+
instruction_shape,
|
| 228 |
+
element_a, element_b, element_accumulator,
|
| 229 |
+
OpcodeClass.TensorOp,
|
| 230 |
+
MathOperation.multiply_add
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
data_types = [
|
| 234 |
+
{
|
| 235 |
+
"a_type" : math_instruction.element_a,
|
| 236 |
+
"b_type" : math_instruction.element_b,
|
| 237 |
+
"c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator,
|
| 238 |
+
"d_type" : element_d,
|
| 239 |
+
"acc_type" : math_instruction.element_accumulator,
|
| 240 |
+
"epi_type" : math_instruction.element_accumulator,
|
| 241 |
+
}
|
| 242 |
+
]
|
| 243 |
+
|
| 244 |
+
tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
|
| 245 |
+
tile_description = TileDescription(
|
| 246 |
+
[instruction_shape[0] * tile_multiplier[0],
|
| 247 |
+
instruction_shape[1] * tile_multiplier[1],
|
| 248 |
+
instruction_shape[2] * 4 * tile_multiplier[2]],
|
| 249 |
+
0,
|
| 250 |
+
[4,1,1],
|
| 251 |
+
math_instruction,
|
| 252 |
+
min_cc,
|
| 253 |
+
max_cc,
|
| 254 |
+
cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
schedules = []
|
| 258 |
+
if is_2sm:
|
| 259 |
+
schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
|
| 260 |
+
else:
|
| 261 |
+
schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])
|
| 262 |
+
|
| 263 |
+
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
|
| 264 |
+
configs.append(config)
|
| 265 |
+
operations.append(o)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
return configs, operations
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
|
| 272 |
+
"""
|
| 273 |
+
Generate CUTLASS operations based on the list of configs provided by the heuristic provider
|
| 274 |
+
|
| 275 |
+
args:
|
| 276 |
+
manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
|
| 277 |
+
cuda_version: Cuda compiler version for generating cutlass operations
|
| 278 |
+
kernel_configs: list of configs generated by the heuristic
|
| 279 |
+
|
| 280 |
+
returns:
|
| 281 |
+
(configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
|
| 282 |
+
"""
|
| 283 |
+
min_cc, max_cc = 90, 90
|
| 284 |
+
|
| 285 |
+
if manifest is None:
|
| 286 |
+
# Use a dummy manifest so we can use existing CreateGemmOperator functions
|
| 287 |
+
manifest = Manifest()
|
| 288 |
+
|
| 289 |
+
configs = []
|
| 290 |
+
operations = []
|
| 291 |
+
for config in kernel_configs:
|
| 292 |
+
|
| 293 |
+
is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
|
| 294 |
+
layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
|
| 295 |
+
element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
|
| 296 |
+
|
| 297 |
+
# instr shape and warp config are unused for emitting 3x collective builder code
|
| 298 |
+
dummy_instr_shape = [0, 0, 0]
|
| 299 |
+
math_instruction = MathInstruction(
|
| 300 |
+
dummy_instr_shape,
|
| 301 |
+
element_a, element_b, element_accumulator,
|
| 302 |
+
OpcodeClass.TensorOp,
|
| 303 |
+
MathOperation.multiply_add
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
|
| 307 |
+
if is_aligned:
|
| 308 |
+
layout = fix_alignments(data_types, layout, alignment_bits=128)
|
| 309 |
+
|
| 310 |
+
# instr shape and warp config are unused for emitting 3x collective builder code
|
| 311 |
+
dummy_warp_count = [0, 0, 0]
|
| 312 |
+
tile_description = TileDescription(
|
| 313 |
+
[config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
|
| 314 |
+
0,
|
| 315 |
+
dummy_warp_count,
|
| 316 |
+
math_instruction,
|
| 317 |
+
min_cc,
|
| 318 |
+
max_cc,
|
| 319 |
+
cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
schedules, stream_k_schedules = get_valid_schedules(
|
| 323 |
+
tile_description=tile_description,
|
| 324 |
+
cuda_version=cuda_version,
|
| 325 |
+
is_aligned=is_aligned,
|
| 326 |
+
data_types=data_types,
|
| 327 |
+
instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic
|
| 328 |
+
layout=layout,
|
| 329 |
+
gemm_kind=GemmKind.Universal3x,
|
| 330 |
+
enable_fp8_fast_acc=config['use_fast_acc']
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
if len(schedules):
|
| 334 |
+
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
|
| 335 |
+
configs.append(config)
|
| 336 |
+
operations.append(o)
|
| 337 |
+
|
| 338 |
+
if len(stream_k_schedules):
|
| 339 |
+
for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
|
| 340 |
+
stream_k_schedules,
|
| 341 |
+
tile_schedulers=[TileSchedulerType.StreamK]):
|
| 342 |
+
configs.append(config)
|
| 343 |
+
operations.append(o)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
return configs, operations
|
| 347 |
+
|
| 348 |
+
def filter_manifest_and_write_heuristics_file(manifest, args):
|
| 349 |
+
"""
|
| 350 |
+
Prune a manifest according to heuristics suggestions from the problems file
|
| 351 |
+
|
| 352 |
+
args:
|
| 353 |
+
manifest: Cutlass manifest to prune
|
| 354 |
+
args: generator.py args, requires:
|
| 355 |
+
- args.heuristics_problems_file
|
| 356 |
+
- args.heuristics_gpu
|
| 357 |
+
- args.heuristics_testlist_file
|
| 358 |
+
|
| 359 |
+
returns:
|
| 360 |
+
A list of dictionaries, each of which has information about an operation and a problem from the input problems
|
| 361 |
+
"""
|
| 362 |
+
heuristics_problems = []
|
| 363 |
+
with open(args.heuristics_problems_file, 'r') as f:
|
| 364 |
+
heuristics_problems = json.load(f)
|
| 365 |
+
gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu
|
| 366 |
+
mmh = MatmulHeuristics(gpu=gpu)
|
| 367 |
+
if any(('100' in arch) for arch in args.architectures.split(';')):
|
| 368 |
+
mmh.set_cta_div_n(64)
|
| 369 |
+
problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)
|
| 370 |
+
|
| 371 |
+
all_configs_and_operations = []
|
| 372 |
+
operations = []
|
| 373 |
+
for problem in problems_with_configs:
|
| 374 |
+
if any('90' in arch for arch in args.architectures.split(';')):
|
| 375 |
+
problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
|
| 376 |
+
if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')):
|
| 377 |
+
problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
|
| 378 |
+
|
| 379 |
+
operations += problem_operations
|
| 380 |
+
problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
|
| 381 |
+
with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
|
| 382 |
+
all_configs_and_operations += with_problem_size
|
| 383 |
+
|
| 384 |
+
for operation in operations:
|
| 385 |
+
manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
|
| 386 |
+
if not all_configs_and_operations:
|
| 387 |
+
raise Exception("No valid configurations generated")
|
| 388 |
+
write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
|
| 389 |
+
return all_configs_and_operations
|
| 390 |
+
|
| 391 |
+
def write_profiler_testlist_to_csv(configs_list, outfile_path):
|
| 392 |
+
"""
|
| 393 |
+
Write a list of configs to a testlist to be consumed by cutlass_profiler
|
| 394 |
+
|
| 395 |
+
args:
|
| 396 |
+
configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries
|
| 397 |
+
outfile_path: Outfile path
|
| 398 |
+
|
| 399 |
+
returns:
|
| 400 |
+
None
|
| 401 |
+
"""
|
| 402 |
+
profiler_testlist = configs_list.copy()
|
| 403 |
+
for c in profiler_testlist:
|
| 404 |
+
for k, v in c.items():
|
| 405 |
+
if isinstance(v, DataType):
|
| 406 |
+
c[k] = DataTypeNames[v]
|
| 407 |
+
elif isinstance(v, LayoutType):
|
| 408 |
+
c[k] = ShortLayoutTypeNames[v]
|
| 409 |
+
|
| 410 |
+
with open(outfile_path, mode='w', newline='') as ofile:
|
| 411 |
+
k_names = profiler_testlist[0].keys()
|
| 412 |
+
|
| 413 |
+
writer = csv.DictWriter(ofile, fieldnames=k_names)
|
| 414 |
+
writer.writeheader()
|
| 415 |
+
writer.writerows(profiler_testlist)
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Providers for kernel selection heuristics
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import sys
|
| 38 |
+
import os
|
| 39 |
+
import glob
|
| 40 |
+
import logging
|
| 41 |
+
import ctypes
|
| 42 |
+
import functools
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
import builtins
|
| 47 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 48 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 49 |
+
from cutlass_library.library import DataType, LayoutType
|
| 50 |
+
except ImportError:
|
| 51 |
+
from library import DataType, LayoutType
|
| 52 |
+
|
| 53 |
+
class MatmulHeuristics:
|
| 54 |
+
|
| 55 |
+
def __init__(self, gpu = None):
|
| 56 |
+
import nvMatmulHeuristics
|
| 57 |
+
self.mmh_lib = nvMatmulHeuristics
|
| 58 |
+
self.gpu = gpu
|
| 59 |
+
|
| 60 |
+
if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
|
| 61 |
+
nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
|
| 62 |
+
else:
|
| 63 |
+
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
|
| 64 |
+
|
| 65 |
+
self.lh = nvmmhInterfaceEx(
|
| 66 |
+
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
|
| 67 |
+
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
|
| 68 |
+
load_discovery_implicitly=True,
|
| 69 |
+
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
|
| 70 |
+
)
|
| 71 |
+
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
|
| 72 |
+
|
| 73 |
+
def _layout_from_cutlass(self, layouts):
|
| 74 |
+
assert(len(layouts)==3)
|
| 75 |
+
full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
|
| 76 |
+
input_layouts = full_layout_str[:2].upper()
|
| 77 |
+
lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
|
| 78 |
+
return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]
|
| 79 |
+
|
| 80 |
+
def _precision_from_cutlass_dtypes(self, dtypes):
|
| 81 |
+
dtype_to_cublas = {
|
| 82 |
+
DataType.f64: 'D',
|
| 83 |
+
DataType.f32: 'S',
|
| 84 |
+
DataType.f16: 'H',
|
| 85 |
+
DataType.bf16: 'T',
|
| 86 |
+
DataType.e4m3: 'Q',
|
| 87 |
+
DataType.e5m2: 'R',
|
| 88 |
+
DataType.s32: 'I',
|
| 89 |
+
DataType.s8: 'B',
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes
|
| 93 |
+
|
| 94 |
+
a_c = dtype_to_cublas[dtype_a]
|
| 95 |
+
|
| 96 |
+
if a_c.lower() != 'q':
|
| 97 |
+
return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
|
| 98 |
+
else:
|
| 99 |
+
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
|
| 100 |
+
|
| 101 |
+
def set_cta_div_n(self, div_n):
|
| 102 |
+
cta_n_div_requirement = ctypes.c_int(div_n)
|
| 103 |
+
self.lh.setBackendValueProperty(
|
| 104 |
+
self.backend,
|
| 105 |
+
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
|
| 106 |
+
ctypes.byref(cta_n_div_requirement),
|
| 107 |
+
ctypes.sizeof(cta_n_div_requirement)
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def set_cta_div_m(self, div_m):
|
| 111 |
+
cta_m_div_requirement = ctypes.c_int(div_m)
|
| 112 |
+
self.lh.setBackendValueProperty(
|
| 113 |
+
self.backend,
|
| 114 |
+
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
|
| 115 |
+
ctypes.byref(cta_m_div_requirement),
|
| 116 |
+
ctypes.sizeof(cta_m_div_requirement)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
|
| 120 |
+
if use_fast_acc:
|
| 121 |
+
disable_fast_acc_for_fp8 = ctypes.c_int(0)
|
| 122 |
+
else:
|
| 123 |
+
disable_fast_acc_for_fp8 = ctypes.c_int(1)
|
| 124 |
+
self.lh.setBackendValueProperty(
|
| 125 |
+
self.backend,
|
| 126 |
+
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
|
| 127 |
+
ctypes.byref(disable_fast_acc_for_fp8),
|
| 128 |
+
ctypes.sizeof(disable_fast_acc_for_fp8)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
precision = self._precision_from_cutlass_dtypes(dtypes)
|
| 132 |
+
layout = self._layout_from_cutlass(layouts)
|
| 133 |
+
|
| 134 |
+
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
|
| 135 |
+
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
|
| 136 |
+
|
| 137 |
+
ret = []
|
| 138 |
+
for c in configs:
|
| 139 |
+
kernel = c['kernel']
|
| 140 |
+
problem = c['problem']
|
| 141 |
+
|
| 142 |
+
r = {}
|
| 143 |
+
r['estimated_runtime'] = c['runtime']
|
| 144 |
+
r['cta_tile_m'] = kernel.cta_tile_m
|
| 145 |
+
r['cta_tile_n'] = kernel.cta_tile_n
|
| 146 |
+
r['cta_tile_k'] = kernel.cta_tile_k
|
| 147 |
+
r['instr_tile_m'] = kernel.instr_tile_m
|
| 148 |
+
r['instr_tile_n'] = kernel.instr_tile_n
|
| 149 |
+
r['instr_tile_k'] = kernel.instr_tile_k
|
| 150 |
+
r['warp_tile_m'] = kernel.warp_tile_m
|
| 151 |
+
r['warp_tile_n'] = kernel.warp_tile_n
|
| 152 |
+
r['warp_tile_k'] = kernel.warp_tile_k
|
| 153 |
+
r['cluster_m'] = kernel.cluster_m
|
| 154 |
+
r['cluster_n'] = kernel.cluster_n
|
| 155 |
+
r['cluster_k'] = 1
|
| 156 |
+
r['layout_a'] = layouts[0]
|
| 157 |
+
r['layout_b'] = layouts[1]
|
| 158 |
+
r['layout_d'] = layouts[2]
|
| 159 |
+
r['dtype_a'] = dtypes[0]
|
| 160 |
+
r['dtype_b'] = dtypes[1]
|
| 161 |
+
r['dtype_acc'] = dtypes[2]
|
| 162 |
+
r['dtype_c'] = dtypes[3]
|
| 163 |
+
r['dtype_d'] = dtypes[4]
|
| 164 |
+
r['alignment_a'] = align_a
|
| 165 |
+
r['alignment_b'] = align_b
|
| 166 |
+
r['swizzle_size'] = kernel.swizzle_factor
|
| 167 |
+
r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
|
| 168 |
+
r['split_k_slices'] = kernel.split_k
|
| 169 |
+
r['use_fast_acc'] = use_fast_acc
|
| 170 |
+
r['voidC'] = voidC
|
| 171 |
+
|
| 172 |
+
ret.append(r)
|
| 173 |
+
|
| 174 |
+
return ret
|
| 175 |
+
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py
ADDED
|
@@ -0,0 +1,1531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Data types and tags used for emitting CUTLASS C++ kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import re
|
| 39 |
+
|
| 40 |
+
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
|
| 41 |
+
# as the default 3.5.2 on Ubuntu 16.04.
|
| 42 |
+
#
|
| 43 |
+
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
from enum import auto as enum_auto
|
| 47 |
+
except ImportError:
|
| 48 |
+
__cutlass_library_auto_enum = 0
|
| 49 |
+
def enum_auto() -> int:
|
| 50 |
+
global __cutlass_library_auto_enum
|
| 51 |
+
i = __cutlass_library_auto_enum
|
| 52 |
+
__cutlass_library_auto_enum += 1
|
| 53 |
+
return i
|
| 54 |
+
|
| 55 |
+
###################################################################################################
|
| 56 |
+
|
| 57 |
+
#
|
| 58 |
+
class GeneratorTarget(enum.Enum):
|
| 59 |
+
Library = enum_auto()
|
| 60 |
+
#
|
| 61 |
+
GeneratorTargetNames = {
|
| 62 |
+
GeneratorTarget.Library: 'library'
|
| 63 |
+
}
|
| 64 |
+
#
|
| 65 |
+
|
| 66 |
+
###################################################################################################
|
| 67 |
+
|
| 68 |
+
#
|
| 69 |
+
class DataType(enum.Enum):
|
| 70 |
+
void = enum_auto() # primarily used to disable C tensor for epilogues
|
| 71 |
+
b1 = enum_auto()
|
| 72 |
+
u2 = enum_auto()
|
| 73 |
+
u4 = enum_auto()
|
| 74 |
+
u8 = enum_auto()
|
| 75 |
+
u16 = enum_auto()
|
| 76 |
+
u32 = enum_auto()
|
| 77 |
+
u64 = enum_auto()
|
| 78 |
+
s2 = enum_auto()
|
| 79 |
+
s4 = enum_auto()
|
| 80 |
+
s8 = enum_auto()
|
| 81 |
+
s16 = enum_auto()
|
| 82 |
+
s32 = enum_auto()
|
| 83 |
+
s64 = enum_auto()
|
| 84 |
+
e4m3 = enum_auto()
|
| 85 |
+
e5m2 = enum_auto()
|
| 86 |
+
f8 = enum_auto()
|
| 87 |
+
f6 = enum_auto()
|
| 88 |
+
f4 = enum_auto()
|
| 89 |
+
e3m2 = enum_auto()
|
| 90 |
+
e2m3 = enum_auto()
|
| 91 |
+
e2m1 = enum_auto()
|
| 92 |
+
ue8m0 = enum_auto()
|
| 93 |
+
ue4m3 = enum_auto()
|
| 94 |
+
f16 = enum_auto()
|
| 95 |
+
bf16 = enum_auto()
|
| 96 |
+
f32 = enum_auto()
|
| 97 |
+
tf32 = enum_auto()
|
| 98 |
+
f64 = enum_auto()
|
| 99 |
+
cf16 = enum_auto()
|
| 100 |
+
cbf16 = enum_auto()
|
| 101 |
+
cf32 = enum_auto()
|
| 102 |
+
ctf32 = enum_auto()
|
| 103 |
+
cf64 = enum_auto()
|
| 104 |
+
cs2 = enum_auto()
|
| 105 |
+
cs4 = enum_auto()
|
| 106 |
+
cs8 = enum_auto()
|
| 107 |
+
cs16 = enum_auto()
|
| 108 |
+
cs32 = enum_auto()
|
| 109 |
+
cs64 = enum_auto()
|
| 110 |
+
cu2 = enum_auto()
|
| 111 |
+
cu4 = enum_auto()
|
| 112 |
+
cu8 = enum_auto()
|
| 113 |
+
cu16 = enum_auto()
|
| 114 |
+
cu32 = enum_auto()
|
| 115 |
+
cu64 = enum_auto()
|
| 116 |
+
invalid = enum_auto()
|
| 117 |
+
|
| 118 |
+
#
|
| 119 |
+
ShortDataTypeNames = {
|
| 120 |
+
DataType.s32: 'i',
|
| 121 |
+
DataType.e4m3: 'e4m3',
|
| 122 |
+
DataType.e5m2: 'e5m2',
|
| 123 |
+
DataType.f16: 'h',
|
| 124 |
+
DataType.f32: 's',
|
| 125 |
+
DataType.f64: 'd',
|
| 126 |
+
DataType.cf32: 'c',
|
| 127 |
+
DataType.cf64: 'z',
|
| 128 |
+
DataType.f8: 'f8',
|
| 129 |
+
DataType.f6: 'f6',
|
| 130 |
+
DataType.f4: 'f4',
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
#
|
| 134 |
+
DataTypeNames = {
|
| 135 |
+
DataType.void: "void",
|
| 136 |
+
DataType.b1: "b1",
|
| 137 |
+
DataType.u2: "u2",
|
| 138 |
+
DataType.u4: "u4",
|
| 139 |
+
DataType.u8: "u8",
|
| 140 |
+
DataType.u16: "u16",
|
| 141 |
+
DataType.u32: "u32",
|
| 142 |
+
DataType.u64: "u64",
|
| 143 |
+
DataType.s2: "s2",
|
| 144 |
+
DataType.s4: "s4",
|
| 145 |
+
DataType.s8: "s8",
|
| 146 |
+
DataType.s16: "s16",
|
| 147 |
+
DataType.s32: "s32",
|
| 148 |
+
DataType.s64: "s64",
|
| 149 |
+
DataType.e4m3: 'e4m3',
|
| 150 |
+
DataType.e5m2: 'e5m2',
|
| 151 |
+
DataType.f8: 'f8',
|
| 152 |
+
DataType.f6: 'f6',
|
| 153 |
+
DataType.f4: 'f4',
|
| 154 |
+
DataType.e2m3: 'e2m3',
|
| 155 |
+
DataType.e3m2: 'e3m2',
|
| 156 |
+
DataType.e2m1: 'e2m1',
|
| 157 |
+
DataType.ue8m0: 'ue8m0',
|
| 158 |
+
DataType.ue4m3: 'ue4m3',
|
| 159 |
+
DataType.f16: "f16",
|
| 160 |
+
DataType.bf16: "bf16",
|
| 161 |
+
DataType.f32: "f32",
|
| 162 |
+
DataType.tf32: "tf32",
|
| 163 |
+
DataType.f64: "f64",
|
| 164 |
+
DataType.cf16: "cf16",
|
| 165 |
+
DataType.cbf16: "cbf16",
|
| 166 |
+
DataType.cf32: "cf32",
|
| 167 |
+
DataType.ctf32: "ctf32",
|
| 168 |
+
DataType.cf64: "cf64",
|
| 169 |
+
DataType.cu2: "cu2",
|
| 170 |
+
DataType.cu4: "cu4",
|
| 171 |
+
DataType.cu8: "cu8",
|
| 172 |
+
DataType.cu16: "cu16",
|
| 173 |
+
DataType.cu32: "cu32",
|
| 174 |
+
DataType.cu64: "cu64",
|
| 175 |
+
DataType.cs2: "cs2",
|
| 176 |
+
DataType.cs4: "cs4",
|
| 177 |
+
DataType.cs8: "cs8",
|
| 178 |
+
DataType.cs16: "cs16",
|
| 179 |
+
DataType.cs32: "cs32",
|
| 180 |
+
DataType.cs64: "cs64",
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
DataTypeTag = {
|
| 184 |
+
DataType.void: "void",
|
| 185 |
+
DataType.b1: "cutlass::uint1b_t",
|
| 186 |
+
DataType.u2: "cutlass::uint2b_t",
|
| 187 |
+
DataType.u4: "cutlass::uint4b_t",
|
| 188 |
+
DataType.u8: "uint8_t",
|
| 189 |
+
DataType.u16: "uint16_t",
|
| 190 |
+
DataType.u32: "uint32_t",
|
| 191 |
+
DataType.u64: "uint64_t",
|
| 192 |
+
DataType.s2: "cutlass::int2b_t",
|
| 193 |
+
DataType.s4: "cutlass::int4b_t",
|
| 194 |
+
DataType.s8: "int8_t",
|
| 195 |
+
DataType.s16: "int16_t",
|
| 196 |
+
DataType.s32: "int32_t",
|
| 197 |
+
DataType.s64: "int64_t",
|
| 198 |
+
DataType.e4m3: 'cutlass::float_e4m3_t',
|
| 199 |
+
DataType.e5m2: 'cutlass::float_e5m2_t',
|
| 200 |
+
DataType.f8: 'cutlass::type_erased_dynamic_float8_t',
|
| 201 |
+
DataType.f6: 'cutlass::type_erased_dynamic_float6_t',
|
| 202 |
+
DataType.f4: 'cutlass::type_erased_dynamic_float4_t',
|
| 203 |
+
DataType.e2m3: 'cutlass::float_e2m3_t',
|
| 204 |
+
DataType.e3m2: 'cutlass::float_e3m2_t',
|
| 205 |
+
DataType.e2m1: 'cutlass::float_e2m1_t',
|
| 206 |
+
DataType.ue8m0: 'cutlass::float_ue8m0_t',
|
| 207 |
+
DataType.ue4m3: 'cutlass::float_ue4m3_t',
|
| 208 |
+
DataType.f16: "cutlass::half_t",
|
| 209 |
+
DataType.bf16: "cutlass::bfloat16_t",
|
| 210 |
+
DataType.f32: "float",
|
| 211 |
+
DataType.tf32: "cutlass::tfloat32_t",
|
| 212 |
+
DataType.f64: "double",
|
| 213 |
+
DataType.cf16: "cutlass::complex<cutlass::half_t>",
|
| 214 |
+
DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
|
| 215 |
+
DataType.cf32: "cutlass::complex<float>",
|
| 216 |
+
DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
|
| 217 |
+
DataType.cf64: "cutlass::complex<double>",
|
| 218 |
+
DataType.cu2: "cutlass::complex<cutlass::uint2b_t>",
|
| 219 |
+
DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
|
| 220 |
+
DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
|
| 221 |
+
DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
|
| 222 |
+
DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
|
| 223 |
+
DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
|
| 224 |
+
DataType.cs2: "cutlass::complex<cutlass::int2b_t>",
|
| 225 |
+
DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
|
| 226 |
+
DataType.cs8: "cutlass::complex<cutlass::int8_t>",
|
| 227 |
+
DataType.cs16: "cutlass::complex<cutlass::int16_t>",
|
| 228 |
+
DataType.cs32: "cutlass::complex<cutlass::int32_t>",
|
| 229 |
+
DataType.cs64: "cutlass::complex<cutlass::int64_t>",
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
DataTypeSize = {
|
| 233 |
+
DataType.void: 0,
|
| 234 |
+
DataType.b1: 1,
|
| 235 |
+
DataType.u2: 2,
|
| 236 |
+
DataType.u4: 4,
|
| 237 |
+
DataType.u8: 8,
|
| 238 |
+
DataType.u16: 16,
|
| 239 |
+
DataType.u32: 32,
|
| 240 |
+
DataType.u64: 64,
|
| 241 |
+
DataType.s2: 2,
|
| 242 |
+
DataType.s4: 4,
|
| 243 |
+
DataType.s8: 8,
|
| 244 |
+
DataType.s16: 16,
|
| 245 |
+
DataType.s32: 32,
|
| 246 |
+
DataType.s64: 64,
|
| 247 |
+
DataType.e4m3: 8,
|
| 248 |
+
DataType.e5m2: 8,
|
| 249 |
+
DataType.f8: 8,
|
| 250 |
+
DataType.f6: 6,
|
| 251 |
+
DataType.f4: 4,
|
| 252 |
+
DataType.e2m3: 6,
|
| 253 |
+
DataType.e3m2: 6,
|
| 254 |
+
DataType.e2m1: 4,
|
| 255 |
+
DataType.ue8m0: 8,
|
| 256 |
+
DataType.ue4m3: 8,
|
| 257 |
+
DataType.f16: 16,
|
| 258 |
+
DataType.bf16: 16,
|
| 259 |
+
DataType.f32: 32,
|
| 260 |
+
DataType.tf32: 32,
|
| 261 |
+
DataType.f64: 64,
|
| 262 |
+
DataType.cf16: 32,
|
| 263 |
+
DataType.cbf16: 32,
|
| 264 |
+
DataType.cf32: 64,
|
| 265 |
+
DataType.ctf32: 32,
|
| 266 |
+
DataType.cf64: 128,
|
| 267 |
+
DataType.cu2: 4,
|
| 268 |
+
DataType.cu4: 8,
|
| 269 |
+
DataType.cu8: 16,
|
| 270 |
+
DataType.cu16: 32,
|
| 271 |
+
DataType.cu32: 64,
|
| 272 |
+
DataType.cu64: 128,
|
| 273 |
+
DataType.cs2: 4,
|
| 274 |
+
DataType.cs4: 8,
|
| 275 |
+
DataType.cs8: 16,
|
| 276 |
+
DataType.cs16: 32,
|
| 277 |
+
DataType.cs32: 64,
|
| 278 |
+
DataType.cs64: 128,
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
###################################################################################################
|
| 282 |
+
#
|
| 283 |
+
class BlasMode(enum.Enum):
|
| 284 |
+
symmetric = enum_auto()
|
| 285 |
+
hermitian = enum_auto()
|
| 286 |
+
|
| 287 |
+
#
|
| 288 |
+
BlasModeTag = {
|
| 289 |
+
BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
|
| 290 |
+
BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
#
|
| 294 |
+
class ComplexTransform(enum.Enum):
|
| 295 |
+
none = enum_auto()
|
| 296 |
+
conj = enum_auto()
|
| 297 |
+
|
| 298 |
+
#
|
| 299 |
+
ComplexTransformTag = {
|
| 300 |
+
ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
|
| 301 |
+
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
# Used for cutlass3x complex kernel collective mainloop builder instantiation
|
| 305 |
+
ComplexTransformTag3x = {
|
| 306 |
+
ComplexTransform.none: 'cute::identity',
|
| 307 |
+
ComplexTransform.conj: 'cute::conjugate',
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
#
|
| 311 |
+
RealComplexBijection = [
|
| 312 |
+
(DataType.f16, DataType.cf16),
|
| 313 |
+
(DataType.f32, DataType.cf32),
|
| 314 |
+
(DataType.f64, DataType.cf64),
|
| 315 |
+
]
|
| 316 |
+
|
| 317 |
+
#
|
| 318 |
+
def is_complex(data_type):
|
| 319 |
+
for r, c in RealComplexBijection:
|
| 320 |
+
if data_type == c:
|
| 321 |
+
return True
|
| 322 |
+
return False
|
| 323 |
+
|
| 324 |
+
def is_block_scaled(gemm_kind):
|
| 325 |
+
return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
|
| 326 |
+
|
| 327 |
+
def is_blockwise(gemm_kind):
|
| 328 |
+
return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
|
| 329 |
+
|
| 330 |
+
def is_grouped(gemm_kind):
|
| 331 |
+
return gemm_kind in (GemmKind.GroupedUniversal3x,
|
| 332 |
+
GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
|
| 333 |
+
|
| 334 |
+
#
|
| 335 |
+
def get_complex_from_real(real_type):
|
| 336 |
+
for r, c in RealComplexBijection:
|
| 337 |
+
if real_type == r:
|
| 338 |
+
return c
|
| 339 |
+
return DataType.invalid
|
| 340 |
+
|
| 341 |
+
#
|
| 342 |
+
def get_real_from_complex(complex_type):
|
| 343 |
+
for r, c in RealComplexBijection:
|
| 344 |
+
if complex_type == c:
|
| 345 |
+
return r
|
| 346 |
+
return DataType.invalid
|
| 347 |
+
|
| 348 |
+
# TMA requires an alignment of 128 bits for all data types
|
| 349 |
+
def get_tma_alignment(data_type):
|
| 350 |
+
if data_type == DataType.void:
|
| 351 |
+
return 0
|
| 352 |
+
elif DataTypeSize[data_type] == 6:
|
| 353 |
+
return 128 # 96B alignment for 16U6 format
|
| 354 |
+
else:
|
| 355 |
+
return 128 // DataTypeSize[data_type]
|
| 356 |
+
|
| 357 |
+
#
|
| 358 |
+
class ComplexMultiplyOp(enum.Enum):
|
| 359 |
+
multiply_add = enum_auto()
|
| 360 |
+
gaussian = enum_auto()
|
| 361 |
+
|
| 362 |
+
###################################################################################################
|
| 363 |
+
|
| 364 |
+
#
|
| 365 |
+
class MathOperation(enum.Enum):
|
| 366 |
+
multiply_add = enum_auto()
|
| 367 |
+
multiply_add_saturate = enum_auto()
|
| 368 |
+
multiply_add_mixed_input_upcast = enum_auto()
|
| 369 |
+
xor_popc = enum_auto()
|
| 370 |
+
and_popc = enum_auto()
|
| 371 |
+
multiply_add_fast_bf16 = enum_auto()
|
| 372 |
+
multiply_add_fast_f16 = enum_auto()
|
| 373 |
+
multiply_add_fast_f32 = enum_auto()
|
| 374 |
+
multiply_add_complex_fast_f32 = enum_auto()
|
| 375 |
+
multiply_add_complex = enum_auto()
|
| 376 |
+
multiply_add_complex_gaussian = enum_auto()
|
| 377 |
+
multiply_add_fast_accum = enum_auto()
|
| 378 |
+
|
| 379 |
+
#
|
| 380 |
+
MathOperationTag = {
|
| 381 |
+
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
|
| 382 |
+
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
|
| 383 |
+
MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
|
| 384 |
+
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
|
| 385 |
+
MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
|
| 386 |
+
MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
|
| 387 |
+
MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
|
| 388 |
+
MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
|
| 389 |
+
MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
|
| 390 |
+
MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
|
| 391 |
+
MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
|
| 392 |
+
MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum',
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
###################################################################################################
|
| 396 |
+
|
| 397 |
+
#
|
| 398 |
+
class LayoutType(enum.Enum):
|
| 399 |
+
ColumnMajor = enum_auto()
|
| 400 |
+
RowMajor = enum_auto()
|
| 401 |
+
ColumnMajorInterleaved2 = enum_auto()
|
| 402 |
+
RowMajorInterleaved2 = enum_auto()
|
| 403 |
+
ColumnMajorInterleaved32 = enum_auto()
|
| 404 |
+
RowMajorInterleaved32 = enum_auto()
|
| 405 |
+
ColumnMajorInterleaved64 = enum_auto()
|
| 406 |
+
RowMajorInterleaved64 = enum_auto()
|
| 407 |
+
TensorNWC = enum_auto()
|
| 408 |
+
TensorNHWC = enum_auto()
|
| 409 |
+
TensorNDHWC = enum_auto()
|
| 410 |
+
TensorNCHW = enum_auto()
|
| 411 |
+
TensorNGHWC = enum_auto()
|
| 412 |
+
TensorNC32HW32 = enum_auto()
|
| 413 |
+
TensorNC64HW64 = enum_auto()
|
| 414 |
+
TensorC32RSK32 = enum_auto()
|
| 415 |
+
TensorC64RSK64 = enum_auto()
|
| 416 |
+
TensorKCS = enum_auto()
|
| 417 |
+
TensorKCSR = enum_auto()
|
| 418 |
+
TensorKCSRT = enum_auto()
|
| 419 |
+
|
| 420 |
+
#
|
| 421 |
+
LayoutTag = {
|
| 422 |
+
LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
|
| 423 |
+
LayoutType.RowMajor: 'cutlass::layout::RowMajor',
|
| 424 |
+
LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
|
| 425 |
+
LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
|
| 426 |
+
LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
|
| 427 |
+
LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
|
| 428 |
+
LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
|
| 429 |
+
LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
|
| 430 |
+
LayoutType.TensorNWC: 'cutlass::layout::TensorNWC',
|
| 431 |
+
LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
|
| 432 |
+
LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
|
| 433 |
+
LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
|
| 434 |
+
LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
|
| 435 |
+
LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
|
| 436 |
+
LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
|
| 437 |
+
LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
|
| 438 |
+
LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
|
| 439 |
+
LayoutType.TensorKCS: 'cutlass::layout::TensorKCS',
|
| 440 |
+
LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR',
|
| 441 |
+
LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT'
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
#
|
| 445 |
+
TransposedLayout = {
|
| 446 |
+
LayoutType.ColumnMajor: LayoutType.RowMajor,
|
| 447 |
+
LayoutType.RowMajor: LayoutType.ColumnMajor,
|
| 448 |
+
LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
|
| 449 |
+
LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
|
| 450 |
+
LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
|
| 451 |
+
LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
|
| 452 |
+
LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
|
| 453 |
+
LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
|
| 454 |
+
LayoutType.TensorNHWC: LayoutType.TensorNHWC
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
#
|
| 458 |
+
ShortLayoutTypeNames = {
|
| 459 |
+
LayoutType.ColumnMajor: 'n',
|
| 460 |
+
LayoutType.ColumnMajorInterleaved2: 'n2',
|
| 461 |
+
LayoutType.ColumnMajorInterleaved32: 'n32',
|
| 462 |
+
LayoutType.ColumnMajorInterleaved64: 'n64',
|
| 463 |
+
LayoutType.RowMajor: 't',
|
| 464 |
+
LayoutType.RowMajorInterleaved2: 't2',
|
| 465 |
+
LayoutType.RowMajorInterleaved32: 't32',
|
| 466 |
+
LayoutType.RowMajorInterleaved64: 't64',
|
| 467 |
+
LayoutType.TensorNWC: 'nwc',
|
| 468 |
+
LayoutType.TensorNHWC: 'nhwc',
|
| 469 |
+
LayoutType.TensorNDHWC: 'ndhwc',
|
| 470 |
+
LayoutType.TensorNCHW: 'nchw',
|
| 471 |
+
LayoutType.TensorNGHWC: 'nghwc',
|
| 472 |
+
LayoutType.TensorNC32HW32: 'nc32hw32',
|
| 473 |
+
LayoutType.TensorNC64HW64: 'nc64hw64',
|
| 474 |
+
LayoutType.TensorC32RSK32: 'c32rsk32',
|
| 475 |
+
LayoutType.TensorC64RSK64: 'c64rsk64',
|
| 476 |
+
LayoutType.TensorKCS: 'kcs',
|
| 477 |
+
LayoutType.TensorKCSR: 'kcsr',
|
| 478 |
+
LayoutType.TensorKCSRT: 'kcsrt'
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
#
|
| 482 |
+
ShortComplexLayoutNames = {
|
| 483 |
+
(LayoutType.ColumnMajor, ComplexTransform.none): 'n',
|
| 484 |
+
(LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
|
| 485 |
+
(LayoutType.RowMajor, ComplexTransform.none): 't',
|
| 486 |
+
(LayoutType.RowMajor, ComplexTransform.conj): 'h'
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
###################################################################################################
|
| 490 |
+
class KernelScheduleType(enum.Enum):
|
| 491 |
+
ScheduleAuto = enum_auto()
|
| 492 |
+
Multistage = enum_auto()
|
| 493 |
+
CpAsyncWarpSpecialized = enum_auto()
|
| 494 |
+
CpAsyncWarpSpecializedPingpong = enum_auto()
|
| 495 |
+
CpAsyncWarpSpecializedCooperative = enum_auto()
|
| 496 |
+
Tma = enum_auto()
|
| 497 |
+
TmaWarpSpecialized = enum_auto()
|
| 498 |
+
TmaWarpSpecializedPingpong = enum_auto()
|
| 499 |
+
TmaWarpSpecializedCooperative = enum_auto()
|
| 500 |
+
TmaWarpSpecializedFP8FastAccum = enum_auto()
|
| 501 |
+
TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
| 502 |
+
TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
| 503 |
+
ImplicitTmaWarpSpecializedSm90 = enum_auto()
|
| 504 |
+
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
| 505 |
+
PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
|
| 506 |
+
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
| 507 |
+
PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
|
| 508 |
+
|
| 509 |
+
BlockwiseTmaWarpSpecializedCooperative = enum_auto()
|
| 510 |
+
PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto()
|
| 511 |
+
BlockwiseTmaWarpSpecializedPingpong = enum_auto()
|
| 512 |
+
PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto()
|
| 513 |
+
|
| 514 |
+
TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 515 |
+
TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 516 |
+
ImplicitTmaWarpSpecialized1SmSm100 = enum_auto()
|
| 517 |
+
ImplicitTmaWarpSpecialized2SmSm100 = enum_auto()
|
| 518 |
+
|
| 519 |
+
PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto()
|
| 520 |
+
PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto()
|
| 521 |
+
|
| 522 |
+
PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto()
|
| 523 |
+
PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto()
|
| 524 |
+
PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 525 |
+
PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 526 |
+
PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 527 |
+
PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 528 |
+
PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 529 |
+
PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 530 |
+
|
| 531 |
+
SparseTmaWarpSpecialized1SmSm100 = enum_auto()
|
| 532 |
+
SparseTmaWarpSpecialized2SmSm100 = enum_auto()
|
| 533 |
+
|
| 534 |
+
BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
|
| 535 |
+
BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
|
| 536 |
+
Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 537 |
+
Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 538 |
+
|
| 539 |
+
BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
|
| 540 |
+
BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
|
| 541 |
+
|
| 542 |
+
PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
|
| 543 |
+
PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 547 |
+
Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 548 |
+
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
| 549 |
+
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
| 550 |
+
|
| 551 |
+
# FP4 Ultra
|
| 552 |
+
MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
|
| 553 |
+
MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
|
| 554 |
+
MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
|
| 555 |
+
MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
|
| 556 |
+
|
| 557 |
+
MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
|
| 558 |
+
MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
|
| 559 |
+
MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
|
| 560 |
+
MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
|
| 561 |
+
|
| 562 |
+
MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
|
| 563 |
+
MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
|
| 564 |
+
MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
|
| 565 |
+
MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
|
| 566 |
+
|
| 567 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
|
| 568 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
|
| 569 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
|
| 570 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
|
| 571 |
+
|
| 572 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
|
| 573 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
|
| 574 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
|
| 575 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
|
| 576 |
+
|
| 577 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
|
| 578 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
|
| 579 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
|
| 580 |
+
PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
|
| 581 |
+
|
| 582 |
+
Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
| 583 |
+
Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
|
| 584 |
+
Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
| 585 |
+
Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto()
|
| 586 |
+
Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
| 587 |
+
Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto()
|
| 588 |
+
|
| 589 |
+
F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
| 590 |
+
|
| 591 |
+
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
| 592 |
+
BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
|
| 593 |
+
|
| 594 |
+
KernelScheduleTag = {
|
| 595 |
+
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
|
| 596 |
+
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
|
| 597 |
+
KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized',
|
| 598 |
+
KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong',
|
| 599 |
+
KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative',
|
| 600 |
+
KernelScheduleType.Tma: 'cutlass::gemm::KernelTma',
|
| 601 |
+
KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized',
|
| 602 |
+
KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong',
|
| 603 |
+
KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative',
|
| 604 |
+
KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum',
|
| 605 |
+
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
|
| 606 |
+
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
|
| 607 |
+
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
|
| 608 |
+
|
| 609 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise',
|
| 610 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise',
|
| 611 |
+
|
| 612 |
+
KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
|
| 613 |
+
KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
|
| 614 |
+
|
| 615 |
+
KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100',
|
| 616 |
+
KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100',
|
| 617 |
+
|
| 618 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100',
|
| 619 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100',
|
| 620 |
+
|
| 621 |
+
KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100',
|
| 622 |
+
KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100',
|
| 623 |
+
|
| 624 |
+
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100',
|
| 625 |
+
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
|
| 626 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
|
| 627 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
|
| 628 |
+
|
| 629 |
+
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100',
|
| 630 |
+
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100',
|
| 631 |
+
|
| 632 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100',
|
| 633 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100',
|
| 634 |
+
|
| 635 |
+
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
|
| 636 |
+
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
|
| 637 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
| 638 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
| 639 |
+
|
| 640 |
+
# FP4 Ultra
|
| 641 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
|
| 642 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
|
| 643 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
|
| 644 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
|
| 645 |
+
|
| 646 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
| 647 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
| 648 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
| 649 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
| 650 |
+
|
| 651 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
| 652 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
| 653 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
| 654 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
| 655 |
+
|
| 656 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
| 657 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
| 658 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
| 659 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
|
| 660 |
+
|
| 661 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise',
|
| 662 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise',
|
| 663 |
+
|
| 664 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
|
| 665 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
|
| 666 |
+
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
|
| 667 |
+
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100",
|
| 668 |
+
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100",
|
| 669 |
+
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100",
|
| 670 |
+
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100",
|
| 671 |
+
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100",
|
| 672 |
+
|
| 673 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
|
| 674 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
|
| 675 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
|
| 676 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
|
| 677 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
| 678 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
| 679 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
| 680 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
| 681 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
| 682 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
| 683 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
| 684 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
| 685 |
+
|
| 686 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120',
|
| 687 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120',
|
| 688 |
+
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120',
|
| 689 |
+
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120',
|
| 690 |
+
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120',
|
| 691 |
+
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120',
|
| 692 |
+
|
| 693 |
+
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120',
|
| 694 |
+
|
| 695 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
|
| 696 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
#
|
| 700 |
+
KernelScheduleSuffixes = {
|
| 701 |
+
KernelScheduleType.ScheduleAuto: '',
|
| 702 |
+
KernelScheduleType.Multistage: '_cpasync',
|
| 703 |
+
KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized',
|
| 704 |
+
KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong',
|
| 705 |
+
KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative',
|
| 706 |
+
KernelScheduleType.Tma: '_unspecialized',
|
| 707 |
+
KernelScheduleType.TmaWarpSpecialized: '_warpspecialized',
|
| 708 |
+
KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
| 709 |
+
KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
| 710 |
+
KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum',
|
| 711 |
+
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
| 712 |
+
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
| 713 |
+
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
|
| 714 |
+
|
| 715 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
| 716 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
| 717 |
+
|
| 718 |
+
KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
|
| 719 |
+
KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
|
| 720 |
+
|
| 721 |
+
KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: '_1sm',
|
| 722 |
+
KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: '_2sm',
|
| 723 |
+
|
| 724 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm',
|
| 725 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm',
|
| 726 |
+
|
| 727 |
+
KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: '_1sm',
|
| 728 |
+
KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: '_2sm',
|
| 729 |
+
|
| 730 |
+
KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm',
|
| 731 |
+
KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
|
| 732 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
|
| 733 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
|
| 734 |
+
|
| 735 |
+
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
|
| 736 |
+
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
|
| 737 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
|
| 738 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
|
| 739 |
+
|
| 740 |
+
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
| 741 |
+
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
| 742 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
| 743 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
| 744 |
+
|
| 745 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm',
|
| 746 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm',
|
| 747 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm',
|
| 748 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm',
|
| 749 |
+
|
| 750 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf',
|
| 751 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf',
|
| 752 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf',
|
| 753 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf',
|
| 754 |
+
|
| 755 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf',
|
| 756 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf',
|
| 757 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf',
|
| 758 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf',
|
| 759 |
+
|
| 760 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
| 761 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
| 762 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
| 763 |
+
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
|
| 764 |
+
|
| 765 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
| 766 |
+
KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
| 767 |
+
|
| 768 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
|
| 769 |
+
KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
|
| 770 |
+
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
| 771 |
+
KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
| 772 |
+
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
| 773 |
+
KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
| 774 |
+
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
|
| 775 |
+
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
|
| 776 |
+
|
| 777 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm',
|
| 778 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm',
|
| 779 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm',
|
| 780 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm',
|
| 781 |
+
|
| 782 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf',
|
| 783 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf',
|
| 784 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf',
|
| 785 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf',
|
| 786 |
+
|
| 787 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf',
|
| 788 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf',
|
| 789 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf',
|
| 790 |
+
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf',
|
| 791 |
+
|
| 792 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q',
|
| 793 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q',
|
| 794 |
+
KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16',
|
| 795 |
+
KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs16',
|
| 796 |
+
KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32',
|
| 797 |
+
KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32',
|
| 798 |
+
|
| 799 |
+
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q',
|
| 800 |
+
|
| 801 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q',
|
| 802 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q'
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
class EpilogueScheduleType(enum.Enum):
|
| 806 |
+
ScheduleAuto = enum_auto()
|
| 807 |
+
EpilogueTransposed = enum_auto()
|
| 808 |
+
NoSmemWarpSpecialized = enum_auto()
|
| 809 |
+
PtrArrayNoSmemWarpSpecialized = enum_auto()
|
| 810 |
+
NoSmemWarpSpecialized1Sm = enum_auto()
|
| 811 |
+
NoSmemWarpSpecialized2Sm = enum_auto()
|
| 812 |
+
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
| 813 |
+
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
| 814 |
+
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
|
| 815 |
+
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
|
| 816 |
+
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
| 817 |
+
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
| 818 |
+
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
| 819 |
+
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
| 820 |
+
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
|
| 821 |
+
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
|
| 822 |
+
TmaWarpSpecialized = enum_auto()
|
| 823 |
+
TmaWarpSpecializedCooperative = enum_auto()
|
| 824 |
+
TmaWarpSpecialized1Sm = enum_auto()
|
| 825 |
+
TmaWarpSpecialized2Sm = enum_auto()
|
| 826 |
+
PtrArrayTmaWarpSpecialized1Sm = enum_auto()
|
| 827 |
+
PtrArrayTmaWarpSpecialized2Sm = enum_auto()
|
| 828 |
+
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
| 829 |
+
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
| 830 |
+
|
| 831 |
+
#
|
| 832 |
+
EpilogueScheduleTag = {
|
| 833 |
+
EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
|
| 834 |
+
EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
|
| 835 |
+
EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
|
| 836 |
+
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
|
| 837 |
+
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
|
| 838 |
+
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
|
| 839 |
+
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
|
| 840 |
+
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
|
| 841 |
+
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
|
| 842 |
+
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
|
| 843 |
+
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
|
| 844 |
+
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
|
| 845 |
+
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
|
| 846 |
+
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
|
| 847 |
+
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
|
| 848 |
+
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
|
| 849 |
+
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
| 850 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
| 851 |
+
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
| 852 |
+
EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm',
|
| 853 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm',
|
| 854 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm',
|
| 855 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative',
|
| 856 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong',
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
#
|
| 860 |
+
EpilogueScheduleSuffixes = {
|
| 861 |
+
EpilogueScheduleType.ScheduleAuto: '',
|
| 862 |
+
EpilogueScheduleType.EpilogueTransposed: '',
|
| 863 |
+
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
|
| 864 |
+
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
|
| 865 |
+
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
| 866 |
+
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
| 867 |
+
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
| 868 |
+
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
| 869 |
+
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
| 870 |
+
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
| 871 |
+
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
| 872 |
+
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
| 873 |
+
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
| 874 |
+
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
| 875 |
+
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
| 876 |
+
EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
| 877 |
+
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
| 878 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
| 879 |
+
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
| 880 |
+
EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
|
| 881 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '',
|
| 882 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma',
|
| 883 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
|
| 884 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
class EpilogueFunctor3x(enum.Enum):
|
| 888 |
+
LinearCombination = enum_auto()
|
| 889 |
+
LinearCombinationBlockScaleFactor = enum_auto()
|
| 890 |
+
|
| 891 |
+
#
|
| 892 |
+
EpilogueFunctor3xTag = {
|
| 893 |
+
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
|
| 894 |
+
EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
|
| 895 |
+
}
|
| 896 |
+
|
| 897 |
+
# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type)
|
| 898 |
+
def is_tma_epilogue(epilogue_schedule_type):
|
| 899 |
+
return epilogue_schedule_type in [
|
| 900 |
+
EpilogueScheduleType.ScheduleAuto,
|
| 901 |
+
EpilogueScheduleType.TmaWarpSpecialized,
|
| 902 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative,
|
| 903 |
+
EpilogueScheduleType.TmaWarpSpecialized1Sm,
|
| 904 |
+
EpilogueScheduleType.TmaWarpSpecialized2Sm,
|
| 905 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
|
| 906 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
| 907 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
| 908 |
+
EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
| 909 |
+
]
|
| 910 |
+
|
| 911 |
+
def to_grouped_schedule(schedule, grouped):
|
| 912 |
+
if not grouped:
|
| 913 |
+
return schedule
|
| 914 |
+
|
| 915 |
+
group_schedule_map = {
|
| 916 |
+
# SM90
|
| 917 |
+
KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
| 918 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative,
|
| 919 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong,
|
| 920 |
+
KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
| 921 |
+
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
|
| 922 |
+
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
|
| 923 |
+
EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
|
| 924 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
|
| 925 |
+
EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
|
| 926 |
+
# SM100
|
| 927 |
+
KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100,
|
| 928 |
+
KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100,
|
| 929 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
|
| 930 |
+
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
|
| 931 |
+
KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
|
| 932 |
+
KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
|
| 933 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
|
| 934 |
+
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
|
| 935 |
+
KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
|
| 936 |
+
KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
|
| 937 |
+
EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
|
| 938 |
+
EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
|
| 939 |
+
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
|
| 940 |
+
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
|
| 941 |
+
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
|
| 942 |
+
EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
|
| 943 |
+
# SM103
|
| 944 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
|
| 945 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,
|
| 946 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103,
|
| 947 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103,
|
| 948 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch,
|
| 949 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch,
|
| 950 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch,
|
| 951 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch,
|
| 952 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch,
|
| 953 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
|
| 954 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
|
| 955 |
+
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
|
| 956 |
+
}
|
| 957 |
+
|
| 958 |
+
return group_schedule_map[schedule]
|
| 959 |
+
|
| 960 |
+
class TileSchedulerType(enum.Enum):
|
| 961 |
+
Default = enum_auto()
|
| 962 |
+
Persistent = enum_auto()
|
| 963 |
+
StreamK = enum_auto()
|
| 964 |
+
#
|
| 965 |
+
TileSchedulerTag = {
|
| 966 |
+
TileSchedulerType.Default: 'void',
|
| 967 |
+
TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
|
| 968 |
+
TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
|
| 969 |
+
}
|
| 970 |
+
|
| 971 |
+
#
|
| 972 |
+
TileSchedulerSuffixes = {
|
| 973 |
+
TileSchedulerType.Default: '',
|
| 974 |
+
TileSchedulerType.Persistent: '',
|
| 975 |
+
TileSchedulerType.StreamK: '_stream_k',
|
| 976 |
+
}
|
| 977 |
+
|
| 978 |
+
###################################################################################################
|
| 979 |
+
|
| 980 |
+
#
|
| 981 |
+
class SideMode(enum.Enum):
|
| 982 |
+
Left = enum_auto()
|
| 983 |
+
Right = enum_auto()
|
| 984 |
+
|
| 985 |
+
#
|
| 986 |
+
SideModeTag = {
|
| 987 |
+
SideMode.Left: 'cutlass::SideMode::kLeft',
|
| 988 |
+
SideMode.Right: 'cutlass::SideMode::kRight'
|
| 989 |
+
}
|
| 990 |
+
|
| 991 |
+
#
|
| 992 |
+
ShortSideModeNames = {
|
| 993 |
+
SideMode.Left: 'ls',
|
| 994 |
+
SideMode.Right: 'rs'
|
| 995 |
+
}
|
| 996 |
+
|
| 997 |
+
###################################################################################################
|
| 998 |
+
|
| 999 |
+
#
|
| 1000 |
+
class FillMode(enum.Enum):
|
| 1001 |
+
Lower = enum_auto()
|
| 1002 |
+
Upper = enum_auto()
|
| 1003 |
+
|
| 1004 |
+
#
|
| 1005 |
+
FillModeTag = {
|
| 1006 |
+
FillMode.Lower: 'cutlass::FillMode::kLower',
|
| 1007 |
+
FillMode.Upper: 'cutlass::FillMode::kUpper'
|
| 1008 |
+
}
|
| 1009 |
+
|
| 1010 |
+
#
|
| 1011 |
+
ShortFillModeNames = {
|
| 1012 |
+
FillMode.Lower: 'l',
|
| 1013 |
+
FillMode.Upper: 'u'
|
| 1014 |
+
}
|
| 1015 |
+
|
| 1016 |
+
###################################################################################################
|
| 1017 |
+
|
| 1018 |
+
#
|
| 1019 |
+
class DiagType(enum.Enum):
|
| 1020 |
+
NonUnit = enum_auto()
|
| 1021 |
+
Unit = enum_auto()
|
| 1022 |
+
|
| 1023 |
+
#
|
| 1024 |
+
DiagTypeTag = {
|
| 1025 |
+
DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
|
| 1026 |
+
DiagType.Unit: 'cutlass::DiagType::kUnit'
|
| 1027 |
+
}
|
| 1028 |
+
|
| 1029 |
+
#
|
| 1030 |
+
ShortDiagTypeNames = {
|
| 1031 |
+
DiagType.NonUnit: 'nu',
|
| 1032 |
+
DiagType.Unit: 'un'
|
| 1033 |
+
}
|
| 1034 |
+
|
| 1035 |
+
###################################################################################################
|
| 1036 |
+
|
| 1037 |
+
#
|
| 1038 |
+
class OpcodeClass(enum.Enum):
|
| 1039 |
+
Simt = enum_auto()
|
| 1040 |
+
TensorOp = enum_auto()
|
| 1041 |
+
WmmaTensorOp = enum_auto()
|
| 1042 |
+
SparseTensorOp = enum_auto()
|
| 1043 |
+
BlockScaledTensorOp = enum_auto()
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
OpcodeClassNames = {
|
| 1047 |
+
OpcodeClass.Simt: 'simt',
|
| 1048 |
+
OpcodeClass.TensorOp: 'tensorop',
|
| 1049 |
+
OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
|
| 1050 |
+
OpcodeClass.SparseTensorOp: 'sptensorop',
|
| 1051 |
+
OpcodeClass.BlockScaledTensorOp: 'bstensorop'
|
| 1052 |
+
}
|
| 1053 |
+
|
| 1054 |
+
OpcodeClassTag = {
|
| 1055 |
+
OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
|
| 1056 |
+
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
|
| 1057 |
+
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
|
| 1058 |
+
OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
|
| 1059 |
+
OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
|
| 1060 |
+
}
|
| 1061 |
+
|
| 1062 |
+
###################################################################################################
|
| 1063 |
+
|
| 1064 |
+
#
|
| 1065 |
+
class OperationKind(enum.Enum):
|
| 1066 |
+
Gemm = enum_auto()
|
| 1067 |
+
RankK = enum_auto()
|
| 1068 |
+
Rank2K = enum_auto()
|
| 1069 |
+
Trmm = enum_auto()
|
| 1070 |
+
Symm = enum_auto()
|
| 1071 |
+
Conv2d = enum_auto()
|
| 1072 |
+
Conv3d = enum_auto()
|
| 1073 |
+
|
| 1074 |
+
#
|
| 1075 |
+
OperationKindNames = {
|
| 1076 |
+
OperationKind.Gemm: 'gemm'
|
| 1077 |
+
, OperationKind.RankK: 'rank_k'
|
| 1078 |
+
, OperationKind.Rank2K: 'rank_2k'
|
| 1079 |
+
, OperationKind.Trmm: 'trmm'
|
| 1080 |
+
, OperationKind.Symm: 'symm'
|
| 1081 |
+
, OperationKind.Conv2d: 'conv2d'
|
| 1082 |
+
, OperationKind.Conv3d: 'conv3d'
|
| 1083 |
+
}
|
| 1084 |
+
|
| 1085 |
+
#
|
| 1086 |
+
class Target(enum.Enum):
|
| 1087 |
+
library = enum_auto()
|
| 1088 |
+
#
|
| 1089 |
+
ArchitectureNames = {
|
| 1090 |
+
50: 'maxwell',
|
| 1091 |
+
60: 'pascal',
|
| 1092 |
+
61: 'pascal',
|
| 1093 |
+
70: 'volta',
|
| 1094 |
+
75: 'turing',
|
| 1095 |
+
80: 'ampere',
|
| 1096 |
+
89: 'ada',
|
| 1097 |
+
90: 'hopper'
|
| 1098 |
+
}
|
| 1099 |
+
|
| 1100 |
+
#
|
| 1101 |
+
SharedMemPerCC = {
|
| 1102 |
+
70: 96, # 96KB of SMEM
|
| 1103 |
+
72: 96, # 96KB of SMEM
|
| 1104 |
+
75: 64, # 64KB of SMEM
|
| 1105 |
+
80: 163, # 163KB of SMEM - 1KB reserved for the driver
|
| 1106 |
+
86: 99, # 99KB of SMEM - 1KB reserved for the driver
|
| 1107 |
+
87: 163, # 163KB of SMEM - 1KB reserved for the driver
|
| 1108 |
+
89: 99, # 99KB of SMEM - 1KB reserved for the driver
|
| 1109 |
+
90: 227, # 227KB of SMEM - 1KB reserved for the driver
|
| 1110 |
+
100: 227, # 227KB of SMEM - 1KB reserved for the driver
|
| 1111 |
+
}
|
| 1112 |
+
|
| 1113 |
+
###################################################################################################
|
| 1114 |
+
|
| 1115 |
+
#
|
| 1116 |
+
def SubstituteTemplate(template, values):
|
| 1117 |
+
text = template
|
| 1118 |
+
changed = True
|
| 1119 |
+
while changed:
|
| 1120 |
+
changed = False
|
| 1121 |
+
for key, value in values.items():
|
| 1122 |
+
regex = "\\$\\{%s\\}" % key
|
| 1123 |
+
newtext = re.sub(regex, value, text)
|
| 1124 |
+
if newtext != text:
|
| 1125 |
+
changed = True
|
| 1126 |
+
text = newtext
|
| 1127 |
+
return text
|
| 1128 |
+
|
| 1129 |
+
###################################################################################################
|
| 1130 |
+
|
| 1131 |
+
#
|
| 1132 |
+
class GemmKind(enum.Enum):
|
| 1133 |
+
Gemm = enum_auto()
|
| 1134 |
+
Sparse = enum_auto()
|
| 1135 |
+
Universal = enum_auto()
|
| 1136 |
+
Universal3x = enum_auto()
|
| 1137 |
+
SparseUniversal3x = enum_auto()
|
| 1138 |
+
PlanarComplex = enum_auto()
|
| 1139 |
+
PlanarComplexArray = enum_auto()
|
| 1140 |
+
Grouped = enum_auto()
|
| 1141 |
+
BlockScaledUniversal3x = enum_auto()
|
| 1142 |
+
GroupedUniversal3x = enum_auto()
|
| 1143 |
+
GroupedBlockScaledUniversal3x = enum_auto()
|
| 1144 |
+
BlockwiseUniversal3x = enum_auto()
|
| 1145 |
+
GroupedBlockwiseUniversal3x = enum_auto()
|
| 1146 |
+
|
| 1147 |
+
#
|
| 1148 |
+
GemmKindNames = {
|
| 1149 |
+
GemmKind.Gemm: "gemm",
|
| 1150 |
+
GemmKind.Sparse: "spgemm",
|
| 1151 |
+
GemmKind.Universal: "gemm",
|
| 1152 |
+
GemmKind.Universal3x: "gemm",
|
| 1153 |
+
GemmKind.SparseUniversal3x: "spgemm",
|
| 1154 |
+
GemmKind.PlanarComplex: "gemm_planar_complex",
|
| 1155 |
+
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
|
| 1156 |
+
GemmKind.Grouped: "gemm_grouped",
|
| 1157 |
+
GemmKind.BlockScaledUniversal3x: "gemm",
|
| 1158 |
+
GemmKind.GroupedUniversal3x: "gemm_grouped",
|
| 1159 |
+
GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
|
| 1160 |
+
GemmKind.BlockwiseUniversal3x: "gemm",
|
| 1161 |
+
GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
#
|
| 1165 |
+
class RankKKind(enum.Enum):
|
| 1166 |
+
Universal = enum_auto()
|
| 1167 |
+
|
| 1168 |
+
#
|
| 1169 |
+
RankKKindNames = {
|
| 1170 |
+
RankKKind.Universal: "rank_k"
|
| 1171 |
+
}
|
| 1172 |
+
|
| 1173 |
+
#
|
| 1174 |
+
class TrmmKind(enum.Enum):
|
| 1175 |
+
Universal = enum_auto()
|
| 1176 |
+
|
| 1177 |
+
#
|
| 1178 |
+
TrmmKindNames = {
|
| 1179 |
+
TrmmKind.Universal: "trmm"
|
| 1180 |
+
}
|
| 1181 |
+
|
| 1182 |
+
#
|
| 1183 |
+
class SymmKind(enum.Enum):
|
| 1184 |
+
Universal = enum_auto()
|
| 1185 |
+
|
| 1186 |
+
#
|
| 1187 |
+
SymmKindNames = {
|
| 1188 |
+
SymmKind.Universal: "symm"
|
| 1189 |
+
}
|
| 1190 |
+
|
| 1191 |
+
#
|
| 1192 |
+
class EpilogueFunctor(enum.Enum):
|
| 1193 |
+
LinearCombination = enum_auto()
|
| 1194 |
+
LinearCombinationClamp = enum_auto()
|
| 1195 |
+
|
| 1196 |
+
#
|
| 1197 |
+
EpilogueFunctorTag = {
|
| 1198 |
+
EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
|
| 1199 |
+
EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
|
| 1200 |
+
}
|
| 1201 |
+
|
| 1202 |
+
#
|
| 1203 |
+
class MixedInputMode(enum.Enum):
|
| 1204 |
+
ConvertOnly = enum_auto()
|
| 1205 |
+
ScaleOnly = enum_auto()
|
| 1206 |
+
ScaleWithZeroPoint = enum_auto()
|
| 1207 |
+
|
| 1208 |
+
#
|
| 1209 |
+
class SwizzlingFunctor(enum.Enum):
|
| 1210 |
+
Identity1 = enum_auto()
|
| 1211 |
+
Identity2 = enum_auto()
|
| 1212 |
+
Identity4 = enum_auto()
|
| 1213 |
+
Identity8 = enum_auto()
|
| 1214 |
+
Horizontal = enum_auto()
|
| 1215 |
+
StridedDgradIdentity1 = enum_auto()
|
| 1216 |
+
StridedDgradIdentity4 = enum_auto()
|
| 1217 |
+
StridedDgradHorizontal = enum_auto()
|
| 1218 |
+
StreamK = enum_auto()
|
| 1219 |
+
|
| 1220 |
+
#
|
| 1221 |
+
SwizzlingFunctorTag = {
|
| 1222 |
+
SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
|
| 1223 |
+
SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
|
| 1224 |
+
SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
|
| 1225 |
+
SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
|
| 1226 |
+
SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
|
| 1227 |
+
SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
|
| 1228 |
+
SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
|
| 1229 |
+
SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
|
| 1230 |
+
SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
|
| 1231 |
+
}
|
| 1232 |
+
|
| 1233 |
+
#
|
| 1234 |
+
class GroupScheduleMode(enum.Enum):
|
| 1235 |
+
Device = enum_auto(),
|
| 1236 |
+
Host = enum_auto()
|
| 1237 |
+
|
| 1238 |
+
#
|
| 1239 |
+
GroupScheduleModeTag = {
|
| 1240 |
+
GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
|
| 1241 |
+
GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
|
| 1242 |
+
}
|
| 1243 |
+
|
| 1244 |
+
#
|
| 1245 |
+
ShortGroupScheduleModeNames = {
|
| 1246 |
+
GroupScheduleMode.Device: 'Device',
|
| 1247 |
+
GroupScheduleMode.Host: 'Host'
|
| 1248 |
+
}
|
| 1249 |
+
|
| 1250 |
+
###################################################################################################
|
| 1251 |
+
|
| 1252 |
+
#
|
| 1253 |
+
class ConvKind(enum.IntEnum):
|
| 1254 |
+
Fprop = 0
|
| 1255 |
+
Dgrad = 1
|
| 1256 |
+
Wgrad = 2
|
| 1257 |
+
|
| 1258 |
+
#
|
| 1259 |
+
ConvKindTag = {
|
| 1260 |
+
ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
|
| 1261 |
+
ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
|
| 1262 |
+
ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
|
| 1263 |
+
}
|
| 1264 |
+
|
| 1265 |
+
ConvKindNames = {
|
| 1266 |
+
ConvKind.Fprop: 'fprop',
|
| 1267 |
+
ConvKind.Dgrad: 'dgrad',
|
| 1268 |
+
ConvKind.Wgrad: 'wgrad',
|
| 1269 |
+
}
|
| 1270 |
+
|
| 1271 |
+
class ConvMode(enum.IntEnum):
|
| 1272 |
+
CrossCorrelation = 0
|
| 1273 |
+
Convolution = 1
|
| 1274 |
+
|
| 1275 |
+
#
|
| 1276 |
+
class IteratorAlgorithm(enum.Enum):
|
| 1277 |
+
Analytic = 0
|
| 1278 |
+
Optimized = 1
|
| 1279 |
+
FixedChannels = 2
|
| 1280 |
+
FewChannels = 3
|
| 1281 |
+
FixedStrideDilation = 4
|
| 1282 |
+
|
| 1283 |
+
#
|
| 1284 |
+
IteratorAlgorithmTag = {
|
| 1285 |
+
IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
|
| 1286 |
+
IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
|
| 1287 |
+
IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
|
| 1288 |
+
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
|
| 1289 |
+
IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
IteratorAlgorithmNames = {
|
| 1293 |
+
IteratorAlgorithm.Analytic: 'analytic',
|
| 1294 |
+
IteratorAlgorithm.Optimized: 'optimized',
|
| 1295 |
+
IteratorAlgorithm.FixedChannels: 'fixed_channels',
|
| 1296 |
+
IteratorAlgorithm.FewChannels: 'few_channels',
|
| 1297 |
+
IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
|
| 1298 |
+
}
|
| 1299 |
+
|
| 1300 |
+
#
|
| 1301 |
+
class StrideSupport(enum.Enum):
|
| 1302 |
+
Strided = 0
|
| 1303 |
+
Unity = 1
|
| 1304 |
+
Fixed = 2
|
| 1305 |
+
|
| 1306 |
+
#
|
| 1307 |
+
StrideSupportTag = {
|
| 1308 |
+
StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
|
| 1309 |
+
StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
|
| 1310 |
+
StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
|
| 1311 |
+
}
|
| 1312 |
+
|
| 1313 |
+
StrideSupportNames = {
|
| 1314 |
+
StrideSupport.Strided: '',
|
| 1315 |
+
StrideSupport.Unity: 'unity_stride',
|
| 1316 |
+
StrideSupport.Fixed: 'fixed_stride'
|
| 1317 |
+
}
|
| 1318 |
+
|
| 1319 |
+
#
|
| 1320 |
+
class GroupMode(enum.Enum):
|
| 1321 |
+
NoneGroup = enum_auto() # dense conv (G=1)
|
| 1322 |
+
SingleGroup = enum_auto() # grouped convolution (single group per CTA)
|
| 1323 |
+
MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA)
|
| 1324 |
+
Depthwise = enum_auto() # Depthwise convolution ( C=K=G )
|
| 1325 |
+
|
| 1326 |
+
#
|
| 1327 |
+
GroupModeTag = {
|
| 1328 |
+
GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
|
| 1329 |
+
GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
|
| 1330 |
+
GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
|
| 1331 |
+
GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
|
| 1332 |
+
}
|
| 1333 |
+
|
| 1334 |
+
GroupModeNames = {
|
| 1335 |
+
GroupMode.NoneGroup: '',
|
| 1336 |
+
GroupMode.SingleGroup: 'single_group',
|
| 1337 |
+
GroupMode.MultipleGroup: 'multiple_group',
|
| 1338 |
+
GroupMode.Depthwise: 'depthwise',
|
| 1339 |
+
}
|
| 1340 |
+
|
| 1341 |
+
DynamicClusterShape = [0, 0, 1]
|
| 1342 |
+
|
| 1343 |
+
###################################################################################################
|
| 1344 |
+
|
| 1345 |
+
#
|
| 1346 |
+
class MathInstruction:
|
| 1347 |
+
def __init__(self,
|
| 1348 |
+
instruction_shape, \
|
| 1349 |
+
element_a, element_b, element_accumulator, \
|
| 1350 |
+
opcode_class, math_operation = MathOperation.multiply_add \
|
| 1351 |
+
, element_scale_factor = None
|
| 1352 |
+
):
|
| 1353 |
+
|
| 1354 |
+
self.instruction_shape = instruction_shape
|
| 1355 |
+
self.element_a = element_a
|
| 1356 |
+
self.element_b = element_b
|
| 1357 |
+
self.element_accumulator = element_accumulator
|
| 1358 |
+
self.opcode_class = opcode_class
|
| 1359 |
+
self.math_operation = math_operation
|
| 1360 |
+
self.element_scale_factor = element_scale_factor
|
| 1361 |
+
|
| 1362 |
+
#
|
| 1363 |
+
class TileDescription:
|
| 1364 |
+
|
| 1365 |
+
def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None):
|
| 1366 |
+
self.threadblock_shape = threadblock_shape
|
| 1367 |
+
self.tile_shape = threadblock_shape
|
| 1368 |
+
self.stages = stages
|
| 1369 |
+
self.warp_count = warp_count
|
| 1370 |
+
self.math_instruction = math_instruction
|
| 1371 |
+
self.minimum_compute_capability = min_compute
|
| 1372 |
+
self.maximum_compute_capability = max_compute
|
| 1373 |
+
self.cluster_shape = cluster_shape
|
| 1374 |
+
self.explicit_vector_sizes = explicit_vector_sizes
|
| 1375 |
+
|
| 1376 |
+
def procedural_name(self):
|
| 1377 |
+
if self.minimum_compute_capability >= 90:
|
| 1378 |
+
return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format(
|
| 1379 |
+
tbm = self.threadblock_shape[0],
|
| 1380 |
+
tbn = self.threadblock_shape[1],
|
| 1381 |
+
tbk = self.threadblock_shape[2],
|
| 1382 |
+
cm = self.cluster_shape[0],
|
| 1383 |
+
cn = self.cluster_shape[1],
|
| 1384 |
+
ck = self.cluster_shape[2],
|
| 1385 |
+
s = self.stages)
|
| 1386 |
+
else:
|
| 1387 |
+
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
|
| 1388 |
+
|
| 1389 |
+
#
|
| 1390 |
+
class Direct2dConvFixedStrideDilationTileDescription:
|
| 1391 |
+
def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
|
| 1392 |
+
self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
|
| 1393 |
+
self.threadblock_output_shape = threadblock_output_shape
|
| 1394 |
+
self.filter_shape = filter_shape
|
| 1395 |
+
self.stages = stages
|
| 1396 |
+
self.warp_count = warp_count
|
| 1397 |
+
self.stride = stride
|
| 1398 |
+
self.dilation = dilation
|
| 1399 |
+
self.math_instruction = math_instruction
|
| 1400 |
+
self.minimum_compute_capability = min_compute
|
| 1401 |
+
self.maximum_compute_capability = max_compute
|
| 1402 |
+
|
| 1403 |
+
def procedural_name(self):
|
| 1404 |
+
str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
|
| 1405 |
+
self.threadblock_shape[1],
|
| 1406 |
+
self.threadblock_shape[2],
|
| 1407 |
+
self.threadblock_output_shape[0],
|
| 1408 |
+
self.threadblock_output_shape[1],
|
| 1409 |
+
self.threadblock_output_shape[2],
|
| 1410 |
+
self.threadblock_output_shape[3],
|
| 1411 |
+
self.stages,
|
| 1412 |
+
self.filter_shape[0],
|
| 1413 |
+
self.filter_shape[1])
|
| 1414 |
+
# Fixed Strided and dilation
|
| 1415 |
+
if self.stride != [-1, -1] and self.dilation != [-1, -1]:
|
| 1416 |
+
str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
|
| 1417 |
+
self.stride[1],
|
| 1418 |
+
self.dilation[0],
|
| 1419 |
+
self.dilation[1])
|
| 1420 |
+
return str_name
|
| 1421 |
+
|
| 1422 |
+
#
|
| 1423 |
+
class Direct2dConvFixedStrideDilationTileDescription:
|
| 1424 |
+
def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
|
| 1425 |
+
self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
|
| 1426 |
+
self.threadblock_output_shape = threadblock_output_shape
|
| 1427 |
+
self.filter_shape = filter_shape
|
| 1428 |
+
self.stages = stages
|
| 1429 |
+
self.warp_count = warp_count
|
| 1430 |
+
self.stride = stride
|
| 1431 |
+
self.dilation = dilation
|
| 1432 |
+
self.math_instruction = math_instruction
|
| 1433 |
+
self.minimum_compute_capability = min_compute
|
| 1434 |
+
self.maximum_compute_capability = max_compute
|
| 1435 |
+
|
| 1436 |
+
def procedural_name(self):
|
| 1437 |
+
str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
|
| 1438 |
+
self.threadblock_shape[1],
|
| 1439 |
+
self.threadblock_shape[2],
|
| 1440 |
+
self.threadblock_output_shape[0],
|
| 1441 |
+
self.threadblock_output_shape[1],
|
| 1442 |
+
self.threadblock_output_shape[2],
|
| 1443 |
+
self.threadblock_output_shape[3],
|
| 1444 |
+
self.stages,
|
| 1445 |
+
self.filter_shape[0],
|
| 1446 |
+
self.filter_shape[1])
|
| 1447 |
+
# Fixed Strided and dilation
|
| 1448 |
+
if self.stride != [-1, -1] and self.dilation != [-1, -1]:
|
| 1449 |
+
str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
|
| 1450 |
+
self.stride[1],
|
| 1451 |
+
self.dilation[0],
|
| 1452 |
+
self.dilation[1])
|
| 1453 |
+
return str_name
|
| 1454 |
+
|
| 1455 |
+
#
|
| 1456 |
+
class TensorDescription:
|
| 1457 |
+
def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
|
| 1458 |
+
self.element = element
|
| 1459 |
+
self.layout = layout
|
| 1460 |
+
self.alignment = alignment
|
| 1461 |
+
self.complex_transform = complex_transform
|
| 1462 |
+
|
| 1463 |
+
#
|
| 1464 |
+
class SymmetricTensorDescription:
|
| 1465 |
+
def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
|
| 1466 |
+
self.element = element
|
| 1467 |
+
self.layout = layout
|
| 1468 |
+
self.fill_mode = fill_mode
|
| 1469 |
+
self.alignment = alignment
|
| 1470 |
+
self.complex_transform = complex_transform
|
| 1471 |
+
self.side_mode = side_mode
|
| 1472 |
+
|
| 1473 |
+
#
|
| 1474 |
+
class TriangularTensorDescription:
|
| 1475 |
+
def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
|
| 1476 |
+
self.element = element
|
| 1477 |
+
self.layout = layout
|
| 1478 |
+
self.side_mode = side_mode
|
| 1479 |
+
self.fill_mode = fill_mode
|
| 1480 |
+
self.diag_type = diag_type
|
| 1481 |
+
self.alignment = alignment
|
| 1482 |
+
self.complex_transform = complex_transform
|
| 1483 |
+
|
| 1484 |
+
#
|
| 1485 |
+
def CalculateSmemUsage(operation):
|
| 1486 |
+
cta_shape = operation.tile_description.threadblock_shape
|
| 1487 |
+
stages = operation.tile_description.stages
|
| 1488 |
+
|
| 1489 |
+
if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse:
|
| 1490 |
+
# Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
|
| 1491 |
+
if DataTypeSize[operation.A.element] == 32:
|
| 1492 |
+
elements_per_8b_md = 2
|
| 1493 |
+
elif DataTypeSize[operation.A.element] == 4:
|
| 1494 |
+
elements_per_8b_md = 8
|
| 1495 |
+
else:
|
| 1496 |
+
elements_per_8b_md = 4
|
| 1497 |
+
|
| 1498 |
+
smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \
|
| 1499 |
+
DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \
|
| 1500 |
+
cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
|
| 1501 |
+
else:
|
| 1502 |
+
# Few BLAS3 operations only have A tensor
|
| 1503 |
+
data_type_size_a = DataTypeSize[operation.A.element]
|
| 1504 |
+
data_type_size_b = DataTypeSize[operation.A.element]
|
| 1505 |
+
if operation.is_mixed_input():
|
| 1506 |
+
data_type_size_b = DataTypeSize[operation.B.element]
|
| 1507 |
+
|
| 1508 |
+
smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
|
| 1509 |
+
data_type_size_b * cta_shape[1] * cta_shape[2] // 8
|
| 1510 |
+
|
| 1511 |
+
smem_usage = smem_per_stage * stages
|
| 1512 |
+
return (smem_usage >> 10)
|
| 1513 |
+
|
| 1514 |
+
|
| 1515 |
+
class GemmUniversalMode(enum.IntEnum):
|
| 1516 |
+
"""
|
| 1517 |
+
Types corresponding to GemmUniversalMode
|
| 1518 |
+
"""
|
| 1519 |
+
Gemm = 0
|
| 1520 |
+
GemmSplitKParallel = 1
|
| 1521 |
+
Batched = 2
|
| 1522 |
+
Array = 3
|
| 1523 |
+
|
| 1524 |
+
|
| 1525 |
+
class SplitKMode(enum.IntEnum):
|
| 1526 |
+
"""
|
| 1527 |
+
Types corresponding to SplitKMode
|
| 1528 |
+
"""
|
| 1529 |
+
NoneSplitK = 0
|
| 1530 |
+
Serial = 1
|
| 1531 |
+
Parallel = 2
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py
ADDED
|
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for filtering CUTLASS library kernels and emitting library intitialization
|
| 35 |
+
and building code
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import enum
|
| 39 |
+
import logging
|
| 40 |
+
import os.path
|
| 41 |
+
import shutil
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
from cutlass_library.gemm_operation import *
|
| 49 |
+
from cutlass_library.rank_k_operation import *
|
| 50 |
+
from cutlass_library.rank_2k_operation import *
|
| 51 |
+
from cutlass_library.trmm_operation import *
|
| 52 |
+
from cutlass_library.symm_operation import *
|
| 53 |
+
from cutlass_library.conv2d_operation import *
|
| 54 |
+
from cutlass_library.conv3d_operation import *
|
| 55 |
+
except ImportError:
|
| 56 |
+
from library import *
|
| 57 |
+
from gemm_operation import *
|
| 58 |
+
from rank_k_operation import *
|
| 59 |
+
from rank_2k_operation import *
|
| 60 |
+
from trmm_operation import *
|
| 61 |
+
from symm_operation import *
|
| 62 |
+
from conv2d_operation import *
|
| 63 |
+
from conv3d_operation import *
|
| 64 |
+
|
| 65 |
+
###################################################################################################
|
| 66 |
+
_LOGGER = logging.getLogger(__name__)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class EmitOperationKindAll:
|
| 70 |
+
"""
|
| 71 |
+
Emit the OperationKind-level CUTLASS library initialization code.
|
| 72 |
+
The code is generated in the {generated_path}/{operation_kind} directory
|
| 73 |
+
(e.g., tools/library/generated/gemm in the build directory,
|
| 74 |
+
for OperationKind=Gemm), in the all_{operation_kind}_operations.cu file
|
| 75 |
+
(e.g., all_gemm_operations.cu for OperationKind=Gemm).
|
| 76 |
+
That file declares several functions in namespace cutlass::library.
|
| 77 |
+
The functions all have this form,
|
| 78 |
+
|
| 79 |
+
void initialize_{configuration_name}(Manifest& manifest);
|
| 80 |
+
|
| 81 |
+
The file also _defines_ the following function in that namespace.
|
| 82 |
+
|
| 83 |
+
void initialize_all_{operation_kind}_operations(Manifest& manifest);
|
| 84 |
+
|
| 85 |
+
That function calls all of the functions declared in this file.
|
| 86 |
+
Those functions are defined in subdirectories
|
| 87 |
+
(which this class does not create).
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, generated_path, kind, args):
|
| 91 |
+
self.generated_path = generated_path
|
| 92 |
+
self.kind = kind
|
| 93 |
+
self.args = args
|
| 94 |
+
|
| 95 |
+
self.header_template ="""
|
| 96 |
+
/*
|
| 97 |
+
Generated by manifest.py - Do not edit.
|
| 98 |
+
*/
|
| 99 |
+
|
| 100 |
+
#include "cutlass/cutlass.h"
|
| 101 |
+
#include "cutlass/library/library.h"
|
| 102 |
+
#include "cutlass/library/manifest.h"
|
| 103 |
+
|
| 104 |
+
namespace cutlass {
|
| 105 |
+
namespace library {
|
| 106 |
+
|
| 107 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 108 |
+
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
self.entry_template = """
|
| 112 |
+
|
| 113 |
+
//
|
| 114 |
+
// Entry point to construct operations
|
| 115 |
+
//
|
| 116 |
+
void initialize_all_${operation_name}_operations(Manifest &manifest) {
|
| 117 |
+
"""
|
| 118 |
+
self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
|
| 119 |
+
self.configuration_template =" initialize_${configuration_name}(manifest);\n"
|
| 120 |
+
|
| 121 |
+
self.epilogue_template ="""}
|
| 122 |
+
|
| 123 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 124 |
+
|
| 125 |
+
} // namespace library
|
| 126 |
+
} // namespace cutlass
|
| 127 |
+
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
#
|
| 131 |
+
def __enter__(self):
|
| 132 |
+
_LOGGER.debug("*** EmitOperationKindAll::__enter__")
|
| 133 |
+
|
| 134 |
+
self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind])
|
| 135 |
+
_LOGGER.debug('*** operation_path (directory to create): ' +
|
| 136 |
+
str(self.operation_path));
|
| 137 |
+
os.makedirs(self.operation_path, exist_ok=True)
|
| 138 |
+
|
| 139 |
+
self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu")
|
| 140 |
+
_LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}")
|
| 141 |
+
|
| 142 |
+
self.top_level_file = open(self.top_level_path, "w")
|
| 143 |
+
self.top_level_file.write(self.header_template)
|
| 144 |
+
|
| 145 |
+
self.source_files = [self.top_level_path,]
|
| 146 |
+
|
| 147 |
+
self.configurations = []
|
| 148 |
+
|
| 149 |
+
return self
|
| 150 |
+
|
| 151 |
+
#
|
| 152 |
+
def emit(self, operations):
|
| 153 |
+
_LOGGER.debug('*** EmitOperationKindAll::emit')
|
| 154 |
+
_LOGGER.debug(f"*** len(operations): {len(operations)}")
|
| 155 |
+
_LOGGER.debug(f"*** min_cc list: {sorted(min_cc for min_cc, _ in operations.items())}")
|
| 156 |
+
|
| 157 |
+
for min_cc, configurations in sorted(operations.items()):
|
| 158 |
+
_LOGGER.debug(f"*** min_cc={min_cc}")
|
| 159 |
+
|
| 160 |
+
for configuration_name, _ in configurations.items():
|
| 161 |
+
_LOGGER.debug(f"*** configuration_name={configuration_name}")
|
| 162 |
+
self.configurations.append(configuration_name)
|
| 163 |
+
self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} ))
|
| 164 |
+
|
| 165 |
+
#
|
| 166 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 167 |
+
_LOGGER.debug("*** EmitOperationKindAll::__exit__")
|
| 168 |
+
|
| 169 |
+
self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]}))
|
| 170 |
+
|
| 171 |
+
for configuration_name in self.configurations:
|
| 172 |
+
self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name}))
|
| 173 |
+
|
| 174 |
+
self.top_level_file.write(self.epilogue_template)
|
| 175 |
+
self.top_level_file.close()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class EmitOperationKindLibrary:
|
| 179 |
+
"""
|
| 180 |
+
Emit the CUTLASS library initialization code for each OperationKind.
|
| 181 |
+
The code is generated in the directory
|
| 182 |
+
{generated_path}/{operation_kind}/{min_cc}
|
| 183 |
+
(e.g., tools/library/generated/gemm/90 in the build directory,
|
| 184 |
+
for min_cc=90 and OperationKind=Gemm), in the file
|
| 185 |
+
all_sm{min_cc}_{operation_kind}_operations.cu
|
| 186 |
+
(e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm).
|
| 187 |
+
The min_cc variable here indicates the minimum GPU architecture version
|
| 188 |
+
that the things to be initialized require.
|
| 189 |
+
For example, min_cc=90 indicates sm90.
|
| 190 |
+
|
| 191 |
+
That file declares several functions in namespace cutlass::library.
|
| 192 |
+
The functions all have this form,
|
| 193 |
+
|
| 194 |
+
void initialize_all_sm{min_cc}_{subclass_name}_{extended_name}_operations(Manifest& manifest);
|
| 195 |
+
|
| 196 |
+
where extended_name is operation.extended_name() for all the operations
|
| 197 |
+
given to the emit method (which see below). (All operations for a given
|
| 198 |
+
configuration_name are guaranteed to have the same extended_name().)
|
| 199 |
+
|
| 200 |
+
The file also _defines_ the following function in that namespace.
|
| 201 |
+
|
| 202 |
+
void initialize_all_sm{min_cc}__{operation_kind}_operations(Manifest& manifest);
|
| 203 |
+
|
| 204 |
+
That function calls all of the functions declared in this file.
|
| 205 |
+
Those functions are defined in subdirectories.
|
| 206 |
+
The mapping from OperationKind to emitter handles the details
|
| 207 |
+
of what happens in each of those subdirectories.
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
def __init__(self, generated_path, min_cc, kind, args):
|
| 211 |
+
self.generated_path = generated_path
|
| 212 |
+
self.min_cc = min_cc
|
| 213 |
+
self.kind = kind
|
| 214 |
+
self.args = args
|
| 215 |
+
self.emitters = {
|
| 216 |
+
OperationKind.Gemm: EmitGemmConfigurationLibrary,
|
| 217 |
+
OperationKind.Conv2d: EmitConv2dConfigurationLibrary,
|
| 218 |
+
OperationKind.Conv3d: EmitConv3dConfigurationLibrary,
|
| 219 |
+
OperationKind.RankK: EmitRankKConfigurationLibrary,
|
| 220 |
+
OperationKind.Rank2K: EmitRank2KConfigurationLibrary,
|
| 221 |
+
OperationKind.Trmm: EmitTrmmConfigurationLibrary,
|
| 222 |
+
OperationKind.Symm: EmitSymmConfigurationLibrary
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
self.header_template ="""
|
| 226 |
+
/*
|
| 227 |
+
Generated by manifest.py - Do not edit.
|
| 228 |
+
*/
|
| 229 |
+
|
| 230 |
+
#include "cutlass/cutlass.h"
|
| 231 |
+
#include "cutlass/library/library.h"
|
| 232 |
+
#include "cutlass/library/manifest.h"
|
| 233 |
+
|
| 234 |
+
namespace cutlass {
|
| 235 |
+
namespace library {
|
| 236 |
+
|
| 237 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 238 |
+
|
| 239 |
+
"""
|
| 240 |
+
self.entry_template = """
|
| 241 |
+
|
| 242 |
+
//
|
| 243 |
+
// Entry point to construct operations
|
| 244 |
+
//
|
| 245 |
+
void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) {
|
| 246 |
+
"""
|
| 247 |
+
self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
|
| 248 |
+
self.configuration_template = " initialize_${configuration_name}(manifest);\n"
|
| 249 |
+
self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n"
|
| 250 |
+
self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n"
|
| 251 |
+
self.epilogue_template ="""}
|
| 252 |
+
|
| 253 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 254 |
+
|
| 255 |
+
} // namespace library
|
| 256 |
+
} // namespace cutlass
|
| 257 |
+
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
#
|
| 261 |
+
def __enter__(self):
|
| 262 |
+
_LOGGER.debug("*** EmitOperationKindLibrary::__enter__")
|
| 263 |
+
_LOGGER.debug(f"*** generated_path: {str(self.generated_path)}")
|
| 264 |
+
_LOGGER.debug(f"*** OperationKindNames[kind]: {OperationKindNames[self.kind]}")
|
| 265 |
+
_LOGGER.debug(f"*** min_cc: {self.min_cc}")
|
| 266 |
+
|
| 267 |
+
self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc))
|
| 268 |
+
_LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}")
|
| 269 |
+
os.makedirs(self.operation_path)
|
| 270 |
+
|
| 271 |
+
self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu")
|
| 272 |
+
_LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}")
|
| 273 |
+
|
| 274 |
+
self.top_level_file = open(self.top_level_path, "w")
|
| 275 |
+
self.top_level_file.write(self.header_template)
|
| 276 |
+
|
| 277 |
+
self.source_files = {}
|
| 278 |
+
|
| 279 |
+
# Each {operation_kind x cc} combination is further decomposed by the instruction
|
| 280 |
+
# types used. This dictionary used to track the file handles for the top-level
|
| 281 |
+
# files of each subclass
|
| 282 |
+
self.subclass_files = {}
|
| 283 |
+
|
| 284 |
+
# Configurations in each sub class
|
| 285 |
+
self.subclass_configurations = {}
|
| 286 |
+
|
| 287 |
+
return self
|
| 288 |
+
|
| 289 |
+
#
|
| 290 |
+
def emit(self, configuration_name, operations):
|
| 291 |
+
_LOGGER.debug("*** EmitOperationKindLibrary::emit")
|
| 292 |
+
_LOGGER.debug(f"*** configuration_name: {configuration_name}")
|
| 293 |
+
|
| 294 |
+
assert len(operations) > 0
|
| 295 |
+
|
| 296 |
+
# The extended name for all operations of a given configuration_name is guaranteed
|
| 297 |
+
# to be the same because extended_name() is used in defining configuration_name. Thus,
|
| 298 |
+
# we can safely use the extended_name() of the first operation.
|
| 299 |
+
extended_name = operations[0].extended_name()
|
| 300 |
+
_LOGGER.debug('*** extended_name (for all ops): ' + extended_name)
|
| 301 |
+
|
| 302 |
+
# Create a directory for operations with this subclass if it does not exist
|
| 303 |
+
if extended_name not in self.subclass_files:
|
| 304 |
+
subclass_path = os.path.join(self.operation_path, extended_name)
|
| 305 |
+
_LOGGER.debug(f"*** subclass_path: {str(subclass_path)}")
|
| 306 |
+
os.mkdir(subclass_path)
|
| 307 |
+
|
| 308 |
+
self.subclass_configurations[extended_name] = []
|
| 309 |
+
|
| 310 |
+
# Open a new top-level file for this sub class
|
| 311 |
+
subclass_top_level_path = os.path.join(
|
| 312 |
+
subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu")
|
| 313 |
+
_LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' +
|
| 314 |
+
'OperationKind): ' + str(subclass_top_level_path))
|
| 315 |
+
|
| 316 |
+
self.subclass_files[extended_name] = open(subclass_top_level_path, "w")
|
| 317 |
+
self.subclass_files[extended_name].write(self.header_template)
|
| 318 |
+
|
| 319 |
+
self.source_files[extended_name] = [subclass_top_level_path]
|
| 320 |
+
|
| 321 |
+
subclass_dir = os.path.dirname(self.subclass_files[extended_name].name)
|
| 322 |
+
_LOGGER.debug('*** subclass_dir: ' + str(subclass_dir))
|
| 323 |
+
|
| 324 |
+
with self.emitters[self.kind](subclass_dir, configuration_name) as configuration_emitter:
|
| 325 |
+
for operation in operations:
|
| 326 |
+
configuration_emitter.emit(operation)
|
| 327 |
+
|
| 328 |
+
_LOGGER.debug('*** configuration_emitter.configuration_path: ' +
|
| 329 |
+
str(configuration_emitter.configuration_path))
|
| 330 |
+
self.source_files[extended_name].append(configuration_emitter.configuration_path)
|
| 331 |
+
|
| 332 |
+
self.subclass_configurations[extended_name].append(configuration_name)
|
| 333 |
+
self.subclass_files[extended_name].write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} ))
|
| 334 |
+
|
| 335 |
+
#
|
| 336 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 337 |
+
_LOGGER.debug("*** EmitOperationKindLibrary::__exit__")
|
| 338 |
+
for subclass_name, subclass_file in sorted(self.subclass_files.items()):
|
| 339 |
+
subclass_cfg = {
|
| 340 |
+
'min_cc': str(self.min_cc),
|
| 341 |
+
'subclass_name': subclass_name,
|
| 342 |
+
'operation_name': OperationKindNames[self.kind]
|
| 343 |
+
}
|
| 344 |
+
self.top_level_file.write(SubstituteTemplate(self.subclass_prototype_template, subclass_cfg))
|
| 345 |
+
|
| 346 |
+
self.top_level_file.write(
|
| 347 |
+
SubstituteTemplate(self.entry_template, {
|
| 348 |
+
'min_cc': str(self.min_cc),
|
| 349 |
+
'subclass_name': '',
|
| 350 |
+
'operation_name': OperationKindNames[self.kind]
|
| 351 |
+
}))
|
| 352 |
+
|
| 353 |
+
# Finish and close all subclass files
|
| 354 |
+
for subclass_name, subclass_file in sorted(self.subclass_files.items()):
|
| 355 |
+
subclass_cfg = {
|
| 356 |
+
'min_cc': str(self.min_cc),
|
| 357 |
+
'subclass_name': subclass_name,
|
| 358 |
+
'operation_name': OperationKindNames[self.kind]
|
| 359 |
+
}
|
| 360 |
+
subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg))
|
| 361 |
+
|
| 362 |
+
for configuration in self.subclass_configurations[subclass_name]:
|
| 363 |
+
subclass_file.write(
|
| 364 |
+
SubstituteTemplate(self.configuration_template, {
|
| 365 |
+
'configuration_name': configuration
|
| 366 |
+
}))
|
| 367 |
+
|
| 368 |
+
subclass_file.write(self.epilogue_template)
|
| 369 |
+
subclass_file.close()
|
| 370 |
+
|
| 371 |
+
# Write the call to initialize_all for this subclass to the top-level file
|
| 372 |
+
self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg))
|
| 373 |
+
|
| 374 |
+
self.top_level_file.write(self.epilogue_template)
|
| 375 |
+
self.top_level_file.close()
|
| 376 |
+
|
| 377 |
+
class EmitInterfaceLibrary:
|
| 378 |
+
"""
|
| 379 |
+
Emit the topmost-level CUTLASS library initialization code.
|
| 380 |
+
The code is generated in the generated_path directory
|
| 381 |
+
(e.g., tools/library/generated in the build directory),
|
| 382 |
+
in the initialize_all.cpp file.
|
| 383 |
+
That file declares several functions in namespace cutlass::library.
|
| 384 |
+
The functions all have this form,
|
| 385 |
+
|
| 386 |
+
void initialize_all_{operation_kind}_operations(Manifest& manifest);
|
| 387 |
+
|
| 388 |
+
where {operation_kind} abbreviates the "kind" of operation
|
| 389 |
+
(e.g., gemm for matrix-matrix multiply, conv2d for 2-d convolution,
|
| 390 |
+
or trmm for triangular solve with multiple right-hand sides).
|
| 391 |
+
The definitions of these functions live in subdirectories.
|
| 392 |
+
|
| 393 |
+
The file also _defines_ the following function in that namespace.
|
| 394 |
+
|
| 395 |
+
void initialize_all(Manifest& manifest);
|
| 396 |
+
|
| 397 |
+
That function first prepares the manifest, and then
|
| 398 |
+
calls all of the functions declared in this file.
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
def __init__(self, generated_path, operation_count, args):
|
| 402 |
+
self.generated_path = generated_path
|
| 403 |
+
self.args = args
|
| 404 |
+
|
| 405 |
+
self.prototypes = []
|
| 406 |
+
self.fn_calls = []
|
| 407 |
+
self.operation_count = str(operation_count)
|
| 408 |
+
|
| 409 |
+
self.top_level_hdr_template = '''
|
| 410 |
+
/*
|
| 411 |
+
Generated by manifest.py - Do not edit.
|
| 412 |
+
*/
|
| 413 |
+
'''
|
| 414 |
+
self.top_level_prologue = '''
|
| 415 |
+
|
| 416 |
+
#include "cutlass/library/library.h"
|
| 417 |
+
#include "cutlass/library/manifest.h"
|
| 418 |
+
|
| 419 |
+
namespace cutlass {
|
| 420 |
+
\tnamespace library {
|
| 421 |
+
|
| 422 |
+
${prototypes}
|
| 423 |
+
'''
|
| 424 |
+
|
| 425 |
+
self.top_level_initialize_kind = '''
|
| 426 |
+
\t\tvoid initialize_all_${kind}_operations(Manifest &manifest) {
|
| 427 |
+
${fn_calls}
|
| 428 |
+
\t\t}
|
| 429 |
+
'''
|
| 430 |
+
|
| 431 |
+
self.top_level_initialize = '''
|
| 432 |
+
\t\tvoid initialize_all(Manifest &manifest) {
|
| 433 |
+
\t\t\tmanifest.reserve(${operation_count});\n
|
| 434 |
+
${fn_calls}
|
| 435 |
+
\t\t}
|
| 436 |
+
'''
|
| 437 |
+
|
| 438 |
+
self.top_level_suffix = '''
|
| 439 |
+
\t} // namespace library
|
| 440 |
+
} // namespace cutlass
|
| 441 |
+
|
| 442 |
+
'''
|
| 443 |
+
|
| 444 |
+
#
|
| 445 |
+
def __enter__(self):
|
| 446 |
+
_LOGGER.debug("*** EmitInterfaceLibrary::__enter__")
|
| 447 |
+
|
| 448 |
+
self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp')
|
| 449 |
+
_LOGGER.debug("*** top_level_path: " + str(self.top_level_path))
|
| 450 |
+
|
| 451 |
+
self.top_level_file = open(self.top_level_path, "w")
|
| 452 |
+
self.top_level_file.write(self.top_level_hdr_template)
|
| 453 |
+
|
| 454 |
+
self.source_files = [self.top_level_path,]
|
| 455 |
+
|
| 456 |
+
return self
|
| 457 |
+
|
| 458 |
+
#
|
| 459 |
+
def emit(self, operation_name):
|
| 460 |
+
_LOGGER.debug("*** EmitInterfaceLibrary::emit")
|
| 461 |
+
_LOGGER.debug("*** operation_name: " + operation_name)
|
| 462 |
+
|
| 463 |
+
self.prototypes.append(SubstituteTemplate(
|
| 464 |
+
"\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);",
|
| 465 |
+
{'operation_kind': operation_name}))
|
| 466 |
+
|
| 467 |
+
self.fn_calls.append(SubstituteTemplate(
|
| 468 |
+
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
|
| 469 |
+
{'operation_kind': operation_name}))
|
| 470 |
+
|
| 471 |
+
#
|
| 472 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 473 |
+
_LOGGER.debug("*** EmitInterfaceLibrary::__exit__")
|
| 474 |
+
|
| 475 |
+
self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)}))
|
| 476 |
+
|
| 477 |
+
# Write out initialize_all method
|
| 478 |
+
self.top_level_file.write(SubstituteTemplate(self.top_level_initialize,
|
| 479 |
+
{'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)}))
|
| 480 |
+
|
| 481 |
+
self.top_level_file.write(self.top_level_suffix)
|
| 482 |
+
self.top_level_file.close()
|
| 483 |
+
|
| 484 |
+
###################################################################################################
|
| 485 |
+
###################################################################################################
|
| 486 |
+
|
| 487 |
+
class Options:
|
| 488 |
+
def __init__(self):
|
| 489 |
+
pass
|
| 490 |
+
|
| 491 |
+
###################################################################################################
|
| 492 |
+
|
| 493 |
+
#
|
| 494 |
+
class Manifest:
|
| 495 |
+
|
| 496 |
+
#
|
| 497 |
+
def __init__(self, args = None):
|
| 498 |
+
self.operations = {}
|
| 499 |
+
self.args = args
|
| 500 |
+
self.operation_count = 0
|
| 501 |
+
self.operations_by_name = {}
|
| 502 |
+
|
| 503 |
+
self.kernel_filter = ''
|
| 504 |
+
self.kernel_filter_list = []
|
| 505 |
+
self.kernel_names = []
|
| 506 |
+
self.operations_enabled = []
|
| 507 |
+
self.selected_kernels = []
|
| 508 |
+
self.ignore_kernel_names = []
|
| 509 |
+
self.exclude_kernel_names = []
|
| 510 |
+
self.compute_capabilities_baseline = [50,]
|
| 511 |
+
self.compute_capabilities_feature_set = ['50',]
|
| 512 |
+
self.curr_build_dir = '.'
|
| 513 |
+
self.filter_by_cc = True
|
| 514 |
+
|
| 515 |
+
if self.args:
|
| 516 |
+
self.kernel_filter = self.args.kernels
|
| 517 |
+
self.curr_build_dir = args.curr_build_dir
|
| 518 |
+
|
| 519 |
+
# A common user error is to use commas instead of semicolons.
|
| 520 |
+
if ',' in args.architectures:
|
| 521 |
+
raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures)
|
| 522 |
+
|
| 523 |
+
self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',]
|
| 524 |
+
self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set))
|
| 525 |
+
|
| 526 |
+
if args.filter_by_cc in ['false', 'False', '0']:
|
| 527 |
+
self.filter_by_cc = False
|
| 528 |
+
|
| 529 |
+
if args.operations == 'all':
|
| 530 |
+
self.operations_enabled = []
|
| 531 |
+
else:
|
| 532 |
+
operations_list = [
|
| 533 |
+
OperationKind.Gemm
|
| 534 |
+
, OperationKind.Conv2d
|
| 535 |
+
, OperationKind.Conv3d
|
| 536 |
+
, OperationKind.RankK
|
| 537 |
+
, OperationKind.Trmm
|
| 538 |
+
, OperationKind.Symm
|
| 539 |
+
]
|
| 540 |
+
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
| 541 |
+
|
| 542 |
+
if args.kernels == 'all':
|
| 543 |
+
self.kernel_names = []
|
| 544 |
+
else:
|
| 545 |
+
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
| 546 |
+
|
| 547 |
+
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
| 548 |
+
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
|
| 549 |
+
|
| 550 |
+
if args.kernel_filter_file is None:
|
| 551 |
+
self.kernel_filter_list = []
|
| 552 |
+
else:
|
| 553 |
+
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
| 554 |
+
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
| 555 |
+
filter_count = len(self.kernel_filter_list),
|
| 556 |
+
filter_file = args.kernel_filter_file))
|
| 557 |
+
|
| 558 |
+
self.operation_count = 0
|
| 559 |
+
self.operations_by_name = {}
|
| 560 |
+
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
| 561 |
+
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
|
| 562 |
+
self.instantiation_level = 0
|
| 563 |
+
try:
|
| 564 |
+
self.instantiation_level = int(args.instantiation_level)
|
| 565 |
+
except ValueError:
|
| 566 |
+
self.instantiation_level = 0
|
| 567 |
+
|
| 568 |
+
def add_kernel_filter(self, filter_str):
|
| 569 |
+
filter_re = re.compile(filter_str)
|
| 570 |
+
|
| 571 |
+
self.kernel_filter_list.append(filter_re)
|
| 572 |
+
|
| 573 |
+
def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992):
|
| 574 |
+
# Non-negative integer which determines how many kernels are instantiated.
|
| 575 |
+
# 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations.
|
| 576 |
+
# increasing first digit reduces schedule / mixed type pruning,
|
| 577 |
+
# increasing second digit generates more cluster sizes,
|
| 578 |
+
# increasing third digit generates more MMA multipliers,
|
| 579 |
+
# increasing fourth digit generates more instruction shapes.
|
| 580 |
+
|
| 581 |
+
if self.instantiation_level > 0:
|
| 582 |
+
return self.instantiation_level
|
| 583 |
+
|
| 584 |
+
elif self.is_kernel_filter_set_to_all:
|
| 585 |
+
return exhaustive_level
|
| 586 |
+
|
| 587 |
+
elif self.kernel_filter == '':
|
| 588 |
+
return pruned_level
|
| 589 |
+
|
| 590 |
+
else:
|
| 591 |
+
return default_level
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def get_kernel_filters(self, kernelListFile):
|
| 595 |
+
if os.path.isfile(kernelListFile):
|
| 596 |
+
with open(kernelListFile, 'r') as fileReader:
|
| 597 |
+
lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
|
| 598 |
+
|
| 599 |
+
lines = [re.compile(line) for line in lines if line]
|
| 600 |
+
return lines
|
| 601 |
+
else:
|
| 602 |
+
return []
|
| 603 |
+
|
| 604 |
+
#
|
| 605 |
+
def filter_out_kernels(self, kernel_name, kernel_filter_list):
|
| 606 |
+
|
| 607 |
+
for kernel_filter_re in kernel_filter_list:
|
| 608 |
+
if kernel_filter_re.search(kernel_name) is not None:
|
| 609 |
+
return True
|
| 610 |
+
|
| 611 |
+
return False
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
#
|
| 615 |
+
def _filter_string_matches(self, filter_string, haystack):
|
| 616 |
+
''' Returns true if all substrings appear in the haystack in order'''
|
| 617 |
+
substrings = filter_string.split('*')
|
| 618 |
+
for sub in substrings:
|
| 619 |
+
idx = haystack.find(sub)
|
| 620 |
+
if idx < 0:
|
| 621 |
+
return False
|
| 622 |
+
haystack = haystack[idx + len(sub):]
|
| 623 |
+
return True
|
| 624 |
+
|
| 625 |
+
#
|
| 626 |
+
def filter(self, operation):
|
| 627 |
+
''' Filtering operations based on various criteria'''
|
| 628 |
+
|
| 629 |
+
# filter based on compute capability
|
| 630 |
+
enabled = not (self.filter_by_cc)
|
| 631 |
+
|
| 632 |
+
for cc in self.compute_capabilities_baseline:
|
| 633 |
+
|
| 634 |
+
if cc >= operation.tile_description.minimum_compute_capability and \
|
| 635 |
+
cc <= operation.tile_description.maximum_compute_capability and \
|
| 636 |
+
(cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)):
|
| 637 |
+
|
| 638 |
+
enabled = True
|
| 639 |
+
break
|
| 640 |
+
|
| 641 |
+
if not enabled:
|
| 642 |
+
return False
|
| 643 |
+
|
| 644 |
+
if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
|
| 645 |
+
return False
|
| 646 |
+
|
| 647 |
+
name = operation.procedural_name()
|
| 648 |
+
|
| 649 |
+
# eliminate duplicates
|
| 650 |
+
if name in self.operations_by_name.keys():
|
| 651 |
+
return False
|
| 652 |
+
|
| 653 |
+
# Filter based on list of valid substrings
|
| 654 |
+
if len(self.kernel_names):
|
| 655 |
+
enabled = False
|
| 656 |
+
|
| 657 |
+
# compare against the include list
|
| 658 |
+
for name_substr in self.kernel_names:
|
| 659 |
+
if self._filter_string_matches(name_substr, name):
|
| 660 |
+
_LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.")
|
| 661 |
+
enabled = True
|
| 662 |
+
break
|
| 663 |
+
else:
|
| 664 |
+
_LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.")
|
| 665 |
+
|
| 666 |
+
# compare against the exclude list
|
| 667 |
+
for name_substr in self.ignore_kernel_names:
|
| 668 |
+
if self._filter_string_matches(name_substr, name):
|
| 669 |
+
_LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.")
|
| 670 |
+
enabled = False
|
| 671 |
+
break
|
| 672 |
+
else:
|
| 673 |
+
_LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.")
|
| 674 |
+
|
| 675 |
+
if len(self.kernel_filter_list) > 0:
|
| 676 |
+
if self.filter_out_kernels(name, self.kernel_filter_list):
|
| 677 |
+
_LOGGER.debug(f"Kernel {name} matched via kernel filter file.")
|
| 678 |
+
enabled = True
|
| 679 |
+
else:
|
| 680 |
+
_LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.")
|
| 681 |
+
enabled = False
|
| 682 |
+
|
| 683 |
+
# CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect
|
| 684 |
+
# if CUTLASS_LIBRARY_KERNELS was specified.
|
| 685 |
+
# Changing that would break backwards compatibility.
|
| 686 |
+
# Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS,
|
| 687 |
+
# that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified.
|
| 688 |
+
for name_substr in self.exclude_kernel_names:
|
| 689 |
+
if self._filter_string_matches(name_substr, name):
|
| 690 |
+
_LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.")
|
| 691 |
+
enabled = False
|
| 692 |
+
break
|
| 693 |
+
else:
|
| 694 |
+
_LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.")
|
| 695 |
+
|
| 696 |
+
# TODO: filter based on compute data type
|
| 697 |
+
return enabled
|
| 698 |
+
#
|
| 699 |
+
|
| 700 |
+
#
|
| 701 |
+
def append(self, operation):
|
| 702 |
+
'''
|
| 703 |
+
Inserts the operation.
|
| 704 |
+
|
| 705 |
+
operation_kind -> configuration_name -> []
|
| 706 |
+
'''
|
| 707 |
+
|
| 708 |
+
if self.filter(operation):
|
| 709 |
+
|
| 710 |
+
self.selected_kernels.append(operation.procedural_name())
|
| 711 |
+
|
| 712 |
+
self.operations_by_name[operation.procedural_name()] = operation
|
| 713 |
+
|
| 714 |
+
# add the configuration
|
| 715 |
+
configuration_name = operation.configuration_name()
|
| 716 |
+
|
| 717 |
+
# Split operations by minimum CC
|
| 718 |
+
min_cc = operation.arch
|
| 719 |
+
|
| 720 |
+
if operation.operation_kind not in self.operations.keys():
|
| 721 |
+
self.operations[operation.operation_kind] = {}
|
| 722 |
+
|
| 723 |
+
if min_cc not in self.operations[operation.operation_kind]:
|
| 724 |
+
self.operations[operation.operation_kind][min_cc] = {}
|
| 725 |
+
|
| 726 |
+
if configuration_name not in self.operations[operation.operation_kind][min_cc].keys():
|
| 727 |
+
self.operations[operation.operation_kind][min_cc][configuration_name] = []
|
| 728 |
+
|
| 729 |
+
self.operations[operation.operation_kind][min_cc][configuration_name].append(operation)
|
| 730 |
+
self.operation_count += 1
|
| 731 |
+
else:
|
| 732 |
+
_LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
|
| 733 |
+
#
|
| 734 |
+
|
| 735 |
+
def emit_manifest_cmake(self, manifest_path, top_level_path, source_files):
|
| 736 |
+
with open(manifest_path, "w") as manifest_file:
|
| 737 |
+
|
| 738 |
+
target_text = SubstituteTemplate("""cutlass_target_sources(cutlass_library_objs PRIVATE
|
| 739 |
+
""", { })
|
| 740 |
+
manifest_file.write(target_text + '\n\n')
|
| 741 |
+
manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/')))
|
| 742 |
+
generated_path = os.path.join(self.curr_build_dir, 'generated')
|
| 743 |
+
for kind in self.operations.keys():
|
| 744 |
+
kind_str = OperationKindNames[kind]
|
| 745 |
+
all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu").replace('\\', '/')
|
| 746 |
+
manifest_file.write(f" {all_kind_file}\n")
|
| 747 |
+
manifest_file.write(')\n\n')
|
| 748 |
+
|
| 749 |
+
for kind in self.operations.keys():
|
| 750 |
+
for min_cc in sorted(self.operations[kind].keys()):
|
| 751 |
+
for subclass in sorted(source_files[kind][min_cc].keys()):
|
| 752 |
+
target_text = SubstituteTemplate("""cutlass_add_cutlass_library(
|
| 753 |
+
SUFFIX ${kind}_sm${min_cc}_${subclass}
|
| 754 |
+
""", { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass })
|
| 755 |
+
manifest_file.write(target_text + '\n\n')
|
| 756 |
+
|
| 757 |
+
for source_file in source_files[kind][min_cc][subclass]:
|
| 758 |
+
manifest_file.write(" %s\n" % str(source_file.replace('\\', '/')))
|
| 759 |
+
|
| 760 |
+
manifest_file.write(")\n")
|
| 761 |
+
|
| 762 |
+
if self.disable_full_archs_compilation:
|
| 763 |
+
self.emit_disable_full_archs_compilation(manifest_file, source_files)
|
| 764 |
+
|
| 765 |
+
def emit_disable_full_archs_compilation(manifest_file, source_files):
|
| 766 |
+
def for_hopper(name):
|
| 767 |
+
pass
|
| 768 |
+
|
| 769 |
+
def for_ampere(name):
|
| 770 |
+
return "16816" in name or \
|
| 771 |
+
"16832" in name or \
|
| 772 |
+
"16864" in name or \
|
| 773 |
+
("1688" in name and "tf32" in name)
|
| 774 |
+
|
| 775 |
+
def for_turing(name):
|
| 776 |
+
return ("1688" in name and "tf32" not in name) or \
|
| 777 |
+
"8816" in name
|
| 778 |
+
|
| 779 |
+
def for_volta(name):
|
| 780 |
+
return "884" in name
|
| 781 |
+
|
| 782 |
+
def is_cpp(name):
|
| 783 |
+
return name.endswith(".cpp")
|
| 784 |
+
|
| 785 |
+
def get_src_archs_str_given_requested_cuda_archs(archs, source_file):
|
| 786 |
+
intersected_archs = archs & set(self.compute_capabilities_baseline)
|
| 787 |
+
if intersected_archs == set():
|
| 788 |
+
raise RuntimeError(
|
| 789 |
+
"""
|
| 790 |
+
Empty archs set for file {} after taking
|
| 791 |
+
the intersection of {} (global requested archs) and
|
| 792 |
+
{} (per file requested archs)
|
| 793 |
+
""".format(source_file, set(self.compute_capabilities_baseline), archs))
|
| 794 |
+
else:
|
| 795 |
+
return " ".join(map(str, intersected_archs))
|
| 796 |
+
|
| 797 |
+
for min_cc in sorted(source_files.keys()):
|
| 798 |
+
for source_file in source_files[min_cc]:
|
| 799 |
+
if is_cpp(source_file):
|
| 800 |
+
continue # skip because source is cpp
|
| 801 |
+
elif for_ampere(source_file):
|
| 802 |
+
archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file)
|
| 803 |
+
elif for_turing(source_file):
|
| 804 |
+
archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file)
|
| 805 |
+
elif for_volta(source_file):
|
| 806 |
+
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
|
| 807 |
+
else:
|
| 808 |
+
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
|
| 809 |
+
|
| 810 |
+
manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
|
| 811 |
+
|
| 812 |
+
#
|
| 813 |
+
def emit(self, target = GeneratorTarget.Library):
|
| 814 |
+
|
| 815 |
+
operation_emitters = {
|
| 816 |
+
GeneratorTarget.Library: EmitOperationKindLibrary
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
# Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d)
|
| 820 |
+
kind_emitters = {
|
| 821 |
+
GeneratorTarget.Library: EmitOperationKindAll
|
| 822 |
+
}
|
| 823 |
+
|
| 824 |
+
interface_emitters = {
|
| 825 |
+
GeneratorTarget.Library: EmitInterfaceLibrary
|
| 826 |
+
}
|
| 827 |
+
|
| 828 |
+
generated_path = os.path.join(self.curr_build_dir, 'generated')
|
| 829 |
+
|
| 830 |
+
# create generated/
|
| 831 |
+
if os.path.exists(generated_path):
|
| 832 |
+
shutil.rmtree(generated_path)
|
| 833 |
+
|
| 834 |
+
os.mkdir(generated_path)
|
| 835 |
+
|
| 836 |
+
with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter:
|
| 837 |
+
top_level_path = iface_emitter.top_level_path
|
| 838 |
+
for operation_kind in self.operations.keys():
|
| 839 |
+
iface_emitter.emit(OperationKindNames[operation_kind])
|
| 840 |
+
|
| 841 |
+
source_files = {}
|
| 842 |
+
for kind in self.operations.keys():
|
| 843 |
+
source_files[kind] = {}
|
| 844 |
+
for min_cc in self.operations[kind].keys():
|
| 845 |
+
source_files[kind][min_cc] = {}
|
| 846 |
+
|
| 847 |
+
for operation_kind, ops in self.operations.items():
|
| 848 |
+
for min_cc, configurations in sorted(ops.items()):
|
| 849 |
+
with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter:
|
| 850 |
+
for configuration_name, operations in configurations.items():
|
| 851 |
+
_LOGGER.info(f"Emitting {configuration_name} with {len(operations)} operation{'' if len(operations) == 1 else 's'}.")
|
| 852 |
+
operation_kind_emitter.emit(configuration_name, operations)
|
| 853 |
+
|
| 854 |
+
for subclass, files in operation_kind_emitter.source_files.items():
|
| 855 |
+
if subclass not in source_files[operation_kind][min_cc]:
|
| 856 |
+
source_files[operation_kind][min_cc][subclass] = []
|
| 857 |
+
source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass])
|
| 858 |
+
|
| 859 |
+
# Emit top level all_{gemm, conv2d, ...}_operations.cu files
|
| 860 |
+
with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter:
|
| 861 |
+
operation_kind_emitter.emit(ops)
|
| 862 |
+
|
| 863 |
+
# write the manifest.cmake file containing paths from all targets
|
| 864 |
+
manifest_path = os.path.join(generated_path, "manifest.cmake")
|
| 865 |
+
|
| 866 |
+
self.emit_manifest_cmake(manifest_path, top_level_path, source_files)
|
| 867 |
+
|
| 868 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting Rank2K kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import functools
|
| 39 |
+
import operator
|
| 40 |
+
import os.path
|
| 41 |
+
import shutil
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
except ImportError:
|
| 49 |
+
from library import *
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
###################################################################################################
|
| 53 |
+
#
|
| 54 |
+
# Data structure modeling a Rank K update operation
|
| 55 |
+
#
|
| 56 |
+
###################################################################################################
|
| 57 |
+
|
| 58 |
+
#
|
| 59 |
+
class Rank2KOperation:
|
| 60 |
+
#
|
| 61 |
+
def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
|
| 62 |
+
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
|
| 63 |
+
blas_mode = BlasMode.symmetric):
|
| 64 |
+
|
| 65 |
+
self.blas_mode = blas_mode
|
| 66 |
+
self.operation_kind = OperationKind.Rank2K
|
| 67 |
+
self.arch = arch
|
| 68 |
+
self.tile_description = tile_description
|
| 69 |
+
self.rank_k_kind = rank_k_kind
|
| 70 |
+
# tensor A and B have same data type and layout
|
| 71 |
+
self.A = A
|
| 72 |
+
self.B = A
|
| 73 |
+
self.C = C
|
| 74 |
+
self.element_epilogue = element_epilogue
|
| 75 |
+
self.epilogue_functor = epilogue_functor
|
| 76 |
+
self.swizzling_functor = swizzling_functor
|
| 77 |
+
|
| 78 |
+
#
|
| 79 |
+
def is_complex(self):
|
| 80 |
+
complex_operators = [
|
| 81 |
+
MathOperation.multiply_add_complex,
|
| 82 |
+
MathOperation.multiply_add_complex_gaussian,
|
| 83 |
+
MathOperation.multiply_add_complex_fast_f32
|
| 84 |
+
]
|
| 85 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 86 |
+
return False
|
| 87 |
+
|
| 88 |
+
#
|
| 89 |
+
def is_mixed_input(self):
|
| 90 |
+
return self.A.element != self.B.element
|
| 91 |
+
|
| 92 |
+
#
|
| 93 |
+
def is_planar_complex(self):
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
#
|
| 97 |
+
def accumulator_type(self):
|
| 98 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 99 |
+
|
| 100 |
+
if self.is_complex():
|
| 101 |
+
return get_complex_from_real(accum)
|
| 102 |
+
|
| 103 |
+
return accum
|
| 104 |
+
|
| 105 |
+
#
|
| 106 |
+
def short_math_name(self):
|
| 107 |
+
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 108 |
+
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 109 |
+
return ShortDataTypeNames[self.accumulator_type()]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
#
|
| 113 |
+
def core_name(self):
|
| 114 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 115 |
+
|
| 116 |
+
inst_shape = ''
|
| 117 |
+
inst_operation = ''
|
| 118 |
+
intermediate_type = ''
|
| 119 |
+
|
| 120 |
+
math_operations_map = {
|
| 121 |
+
MathOperation.xor_popc: 'xor',
|
| 122 |
+
MathOperation.and_popc: 'and'
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
|
| 126 |
+
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
|
| 127 |
+
|
| 128 |
+
math_op = self.tile_description.math_instruction.math_operation
|
| 129 |
+
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
| 130 |
+
|
| 131 |
+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 132 |
+
inst_shape += math_op_string
|
| 133 |
+
|
| 134 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 135 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 136 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 137 |
+
|
| 138 |
+
operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k'
|
| 139 |
+
|
| 140 |
+
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
|
| 141 |
+
|
| 142 |
+
#
|
| 143 |
+
def extended_name(self):
|
| 144 |
+
''' Append data types if they differ from compute type. '''
|
| 145 |
+
if self.is_complex():
|
| 146 |
+
extended_name = "${core_name}"
|
| 147 |
+
else:
|
| 148 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 149 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 150 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 151 |
+
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 152 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 153 |
+
extended_name = "${core_name}_${element_a}"
|
| 154 |
+
else:
|
| 155 |
+
extended_name = "${core_name}"
|
| 156 |
+
|
| 157 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 158 |
+
'element_a': DataTypeNames[self.A.element],
|
| 159 |
+
'element_c': DataTypeNames[self.C.element],
|
| 160 |
+
'core_name': self.core_name()
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
return extended_name
|
| 164 |
+
|
| 165 |
+
#
|
| 166 |
+
def layout_name(self):
|
| 167 |
+
if self.is_complex() or self.is_planar_complex():
|
| 168 |
+
return "%s" % (
|
| 169 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
|
| 170 |
+
)
|
| 171 |
+
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
| 172 |
+
|
| 173 |
+
#
|
| 174 |
+
def fill_mode_name(self):
|
| 175 |
+
return "%s" % (ShortFillModeNames[self.C.fill_mode])
|
| 176 |
+
|
| 177 |
+
#
|
| 178 |
+
def procedural_name(self):
|
| 179 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 180 |
+
threadblock = self.tile_description.procedural_name()
|
| 181 |
+
|
| 182 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 183 |
+
|
| 184 |
+
alignment = max([self.A.alignment, self.C.alignment])
|
| 185 |
+
|
| 186 |
+
return SubstituteTemplate(
|
| 187 |
+
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
|
| 188 |
+
{
|
| 189 |
+
'opcode_class': opcode_class_name,
|
| 190 |
+
'extended_name': self.extended_name(),
|
| 191 |
+
'threadblock': threadblock,
|
| 192 |
+
'layout': self.layout_name(),
|
| 193 |
+
'fill_mode': self.fill_mode_name(),
|
| 194 |
+
'alignment': "%d" % self.A.alignment,
|
| 195 |
+
}
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
#
|
| 199 |
+
def configuration_name(self):
|
| 200 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 201 |
+
return self.procedural_name()
|
| 202 |
+
|
| 203 |
+
###################################################################################################
|
| 204 |
+
#
|
| 205 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 206 |
+
#
|
| 207 |
+
###################################################################################################
|
| 208 |
+
|
| 209 |
+
#
|
| 210 |
+
class EmitRank2KUniversalInstance:
|
| 211 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 212 |
+
|
| 213 |
+
def __init__(self):
|
| 214 |
+
self.rank_k_template = """
|
| 215 |
+
// Rank K operator ${operation_name}
|
| 216 |
+
using Operation_${operation_name} =
|
| 217 |
+
typename cutlass::gemm::device::Rank2K<
|
| 218 |
+
${element_a}, ${layout_a},
|
| 219 |
+
${element_b}, ${layout_b},
|
| 220 |
+
${element_c}, ${layout_c}, ${fill_mode},
|
| 221 |
+
${element_accumulator},
|
| 222 |
+
${opcode_class},
|
| 223 |
+
${arch},
|
| 224 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 225 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 226 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 227 |
+
${epilogue_functor}<
|
| 228 |
+
${element_c},
|
| 229 |
+
${epilogue_vector_length},
|
| 230 |
+
${element_accumulator},
|
| 231 |
+
${element_epilogue}
|
| 232 |
+
>,
|
| 233 |
+
${swizzling_functor},
|
| 234 |
+
${stages},
|
| 235 |
+
${align_a},
|
| 236 |
+
${align_b},
|
| 237 |
+
${split_k_serial},
|
| 238 |
+
${math_operation}
|
| 239 |
+
>;
|
| 240 |
+
"""
|
| 241 |
+
self.rank_k_complex_template = """
|
| 242 |
+
// Rank K operator ${operation_name}
|
| 243 |
+
using Operation_${operation_name} =
|
| 244 |
+
typename cutlass::gemm::device::Rank2K<
|
| 245 |
+
${element_a}, ${layout_a},
|
| 246 |
+
${element_b}, ${layout_b},
|
| 247 |
+
${element_c}, ${layout_c}, ${fill_mode},
|
| 248 |
+
${element_accumulator},
|
| 249 |
+
${opcode_class},
|
| 250 |
+
${arch},
|
| 251 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 252 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 253 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 254 |
+
${epilogue_functor}<
|
| 255 |
+
${element_c},
|
| 256 |
+
${epilogue_vector_length},
|
| 257 |
+
${element_accumulator},
|
| 258 |
+
${element_epilogue}
|
| 259 |
+
>,
|
| 260 |
+
${swizzling_functor},
|
| 261 |
+
${stages},
|
| 262 |
+
${align_a},
|
| 263 |
+
${align_b},
|
| 264 |
+
${split_k_serial},
|
| 265 |
+
${math_operation},
|
| 266 |
+
${transform_a},
|
| 267 |
+
${transform_b},
|
| 268 |
+
${blas_mode}
|
| 269 |
+
>;
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
def emit(self, operation):
|
| 273 |
+
|
| 274 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 275 |
+
|
| 276 |
+
warp_count = operation.tile_description.warp_count
|
| 277 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 278 |
+
|
| 279 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 280 |
+
|
| 281 |
+
values = {
|
| 282 |
+
'operation_name': operation.procedural_name(),
|
| 283 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 284 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 285 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 286 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 287 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 288 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 289 |
+
'fill_mode': FillModeTag[operation.C.fill_mode],
|
| 290 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 291 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 292 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 293 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 294 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 295 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 296 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 297 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 298 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 299 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 300 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 301 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 302 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 303 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 304 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 305 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 306 |
+
'stages': str(operation.tile_description.stages),
|
| 307 |
+
'align_a': str(operation.A.alignment),
|
| 308 |
+
'align_b': str(operation.B.alignment),
|
| 309 |
+
'split_k_serial': 'false',
|
| 310 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 311 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 312 |
+
'transform_b': ComplexTransformTag[operation.B.complex_transform],
|
| 313 |
+
'blas_mode': BlasModeTag[operation.blas_mode]
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template
|
| 317 |
+
|
| 318 |
+
return SubstituteTemplate(rank_k_template, values)
|
| 319 |
+
|
| 320 |
+
###################################################################################################
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
###################################################################################################
|
| 324 |
+
#
|
| 325 |
+
# Emitters functions for all targets
|
| 326 |
+
#
|
| 327 |
+
###################################################################################################
|
| 328 |
+
|
| 329 |
+
class EmitRank2KConfigurationLibrary:
|
| 330 |
+
def __init__(self, operation_path, configuration_name):
|
| 331 |
+
self.configuration_name = configuration_name
|
| 332 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
|
| 333 |
+
|
| 334 |
+
self.instance_emitter = {
|
| 335 |
+
RankKKind.Universal: EmitRank2KUniversalInstance,
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
self.rank_k_kind_wrappers = {
|
| 339 |
+
RankKKind.Universal: 'Rank2KOperation',
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
self.instance_template = {
|
| 343 |
+
RankKKind.Universal: """
|
| 344 |
+
${compile_guard_start}
|
| 345 |
+
manifest.append(new ${rank_k_kind}<
|
| 346 |
+
Operation_${operation_name}
|
| 347 |
+
>("${operation_name}"));
|
| 348 |
+
${compile_guard_end}
|
| 349 |
+
"""
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
self.header_template = """
|
| 353 |
+
/*
|
| 354 |
+
Generated by rank_2k_operation.py - Do not edit.
|
| 355 |
+
*/
|
| 356 |
+
|
| 357 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 358 |
+
#include "cutlass/cutlass.h"
|
| 359 |
+
#include "cutlass/library/library.h"
|
| 360 |
+
#include "cutlass/library/manifest.h"
|
| 361 |
+
|
| 362 |
+
#include "library_internal.h"
|
| 363 |
+
#include "rank_2k_operation.h"
|
| 364 |
+
|
| 365 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 366 |
+
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
self.initialize_function_template = """
|
| 370 |
+
|
| 371 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 372 |
+
|
| 373 |
+
namespace cutlass {
|
| 374 |
+
namespace library {
|
| 375 |
+
|
| 376 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 377 |
+
|
| 378 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 379 |
+
|
| 380 |
+
"""
|
| 381 |
+
self.epilogue_template = """
|
| 382 |
+
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 386 |
+
|
| 387 |
+
} // namespace library
|
| 388 |
+
} // namespace cutlass
|
| 389 |
+
|
| 390 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 391 |
+
|
| 392 |
+
"""
|
| 393 |
+
|
| 394 |
+
def __enter__(self):
|
| 395 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 396 |
+
self.configuration_file.write(self.header_template)
|
| 397 |
+
|
| 398 |
+
self.instance_definitions = []
|
| 399 |
+
self.instance_wrappers = []
|
| 400 |
+
|
| 401 |
+
self.operations = []
|
| 402 |
+
return self
|
| 403 |
+
|
| 404 |
+
def emit(self, operation):
|
| 405 |
+
emitter = self.instance_emitter[operation.rank_k_kind]()
|
| 406 |
+
|
| 407 |
+
self.operations.append(operation)
|
| 408 |
+
|
| 409 |
+
self.instance_definitions.append(emitter.emit(operation))
|
| 410 |
+
|
| 411 |
+
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
|
| 412 |
+
'configuration_name': self.configuration_name,
|
| 413 |
+
'operation_name': operation.procedural_name(),
|
| 414 |
+
'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
|
| 415 |
+
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
|
| 416 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
|
| 417 |
+
'compile_guard_end': "#endif" \
|
| 418 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
| 419 |
+
}))
|
| 420 |
+
|
| 421 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 422 |
+
|
| 423 |
+
# Write instance definitions in top-level namespace
|
| 424 |
+
for instance_definition in self.instance_definitions:
|
| 425 |
+
self.configuration_file.write(instance_definition)
|
| 426 |
+
|
| 427 |
+
# Add wrapper objects within initialize() function
|
| 428 |
+
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
| 429 |
+
'configuration_name': self.configuration_name
|
| 430 |
+
}))
|
| 431 |
+
|
| 432 |
+
for instance_wrapper in self.instance_wrappers:
|
| 433 |
+
self.configuration_file.write(instance_wrapper)
|
| 434 |
+
|
| 435 |
+
self.configuration_file.write(self.epilogue_template)
|
| 436 |
+
self.configuration_file.close()
|
| 437 |
+
|
| 438 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting RankK kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import functools
|
| 39 |
+
import operator
|
| 40 |
+
import os.path
|
| 41 |
+
import shutil
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
except ImportError:
|
| 49 |
+
from library import *
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
###################################################################################################
|
| 53 |
+
#
|
| 54 |
+
# Data structure modeling a Rank K update operation
|
| 55 |
+
#
|
| 56 |
+
###################################################################################################
|
| 57 |
+
|
| 58 |
+
#
|
| 59 |
+
class RankKOperation:
|
| 60 |
+
#
|
| 61 |
+
def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
|
| 62 |
+
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
|
| 63 |
+
blas_mode = BlasMode.symmetric):
|
| 64 |
+
|
| 65 |
+
self.blas_mode = blas_mode
|
| 66 |
+
self.operation_kind = OperationKind.RankK
|
| 67 |
+
self.arch = arch
|
| 68 |
+
self.tile_description = tile_description
|
| 69 |
+
self.rank_k_kind = rank_k_kind
|
| 70 |
+
self.A = A
|
| 71 |
+
self.C = C
|
| 72 |
+
self.element_epilogue = element_epilogue
|
| 73 |
+
self.epilogue_functor = epilogue_functor
|
| 74 |
+
self.swizzling_functor = swizzling_functor
|
| 75 |
+
|
| 76 |
+
#
|
| 77 |
+
def is_complex(self):
|
| 78 |
+
complex_operators = [
|
| 79 |
+
MathOperation.multiply_add_complex,
|
| 80 |
+
MathOperation.multiply_add_complex_gaussian,
|
| 81 |
+
MathOperation.multiply_add_complex_fast_f32
|
| 82 |
+
]
|
| 83 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
#
|
| 87 |
+
def is_mixed_input(self):
|
| 88 |
+
return False
|
| 89 |
+
|
| 90 |
+
#
|
| 91 |
+
def is_planar_complex(self):
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
#
|
| 95 |
+
def accumulator_type(self):
|
| 96 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 97 |
+
|
| 98 |
+
if self.is_complex():
|
| 99 |
+
return get_complex_from_real(accum)
|
| 100 |
+
|
| 101 |
+
return accum
|
| 102 |
+
|
| 103 |
+
#
|
| 104 |
+
def short_math_name(self):
|
| 105 |
+
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 106 |
+
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 107 |
+
return ShortDataTypeNames[self.accumulator_type()]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
#
|
| 111 |
+
def core_name(self):
|
| 112 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 113 |
+
|
| 114 |
+
inst_shape = ''
|
| 115 |
+
inst_operation = ''
|
| 116 |
+
intermediate_type = ''
|
| 117 |
+
|
| 118 |
+
math_operations_map = {
|
| 119 |
+
MathOperation.xor_popc: 'xor',
|
| 120 |
+
MathOperation.and_popc: 'and'
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
|
| 124 |
+
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
|
| 125 |
+
|
| 126 |
+
math_op = self.tile_description.math_instruction.math_operation
|
| 127 |
+
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
| 128 |
+
|
| 129 |
+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 130 |
+
inst_shape += math_op_string
|
| 131 |
+
|
| 132 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 133 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 134 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 135 |
+
|
| 136 |
+
operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk'
|
| 137 |
+
|
| 138 |
+
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
|
| 139 |
+
|
| 140 |
+
#
|
| 141 |
+
def extended_name(self):
|
| 142 |
+
''' Append data types if they differ from compute type. '''
|
| 143 |
+
if self.is_complex():
|
| 144 |
+
extended_name = "${core_name}"
|
| 145 |
+
else:
|
| 146 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 147 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 148 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 149 |
+
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 150 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 151 |
+
extended_name = "${core_name}_${element_a}"
|
| 152 |
+
else:
|
| 153 |
+
extended_name = "${core_name}"
|
| 154 |
+
|
| 155 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 156 |
+
'element_a': DataTypeNames[self.A.element],
|
| 157 |
+
'element_c': DataTypeNames[self.C.element],
|
| 158 |
+
'core_name': self.core_name()
|
| 159 |
+
})
|
| 160 |
+
|
| 161 |
+
return extended_name
|
| 162 |
+
|
| 163 |
+
#
|
| 164 |
+
def layout_name(self):
|
| 165 |
+
if self.is_complex() or self.is_planar_complex():
|
| 166 |
+
return "%s" % (
|
| 167 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
|
| 168 |
+
)
|
| 169 |
+
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
| 170 |
+
|
| 171 |
+
#
|
| 172 |
+
def fill_mode_name(self):
|
| 173 |
+
return "%s" % (ShortFillModeNames[self.C.fill_mode])
|
| 174 |
+
|
| 175 |
+
#
|
| 176 |
+
def procedural_name(self):
|
| 177 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 178 |
+
threadblock = self.tile_description.procedural_name()
|
| 179 |
+
|
| 180 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 181 |
+
|
| 182 |
+
alignment = max([self.A.alignment, self.C.alignment])
|
| 183 |
+
|
| 184 |
+
return SubstituteTemplate(
|
| 185 |
+
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
|
| 186 |
+
{
|
| 187 |
+
'opcode_class': opcode_class_name,
|
| 188 |
+
'extended_name': self.extended_name(),
|
| 189 |
+
'threadblock': threadblock,
|
| 190 |
+
'layout': self.layout_name(),
|
| 191 |
+
'fill_mode': self.fill_mode_name(),
|
| 192 |
+
'alignment': "%d" % self.A.alignment,
|
| 193 |
+
}
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
#
|
| 197 |
+
def configuration_name(self):
|
| 198 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 199 |
+
return self.procedural_name()
|
| 200 |
+
|
| 201 |
+
###################################################################################################
|
| 202 |
+
#
|
| 203 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 204 |
+
#
|
| 205 |
+
###################################################################################################
|
| 206 |
+
|
| 207 |
+
#
|
| 208 |
+
class EmitRankKUniversalInstance:
|
| 209 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 210 |
+
|
| 211 |
+
def __init__(self):
|
| 212 |
+
self.rank_k_template = """
|
| 213 |
+
// Rank K operator ${operation_name}
|
| 214 |
+
using Operation_${operation_name} =
|
| 215 |
+
typename cutlass::gemm::device::RankK<
|
| 216 |
+
${element_a}, ${layout_a},
|
| 217 |
+
${element_c}, ${layout_c}, ${fill_mode},
|
| 218 |
+
${element_accumulator},
|
| 219 |
+
${opcode_class},
|
| 220 |
+
${arch},
|
| 221 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 222 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 223 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 224 |
+
${epilogue_functor}<
|
| 225 |
+
${element_c},
|
| 226 |
+
${epilogue_vector_length},
|
| 227 |
+
${element_accumulator},
|
| 228 |
+
${element_epilogue}
|
| 229 |
+
>,
|
| 230 |
+
${swizzling_functor},
|
| 231 |
+
${stages},
|
| 232 |
+
${align_a},
|
| 233 |
+
${split_k_serial},
|
| 234 |
+
${math_operation}
|
| 235 |
+
>;
|
| 236 |
+
"""
|
| 237 |
+
self.rank_k_complex_template = """
|
| 238 |
+
// Rank K operator ${operation_name}
|
| 239 |
+
using Operation_${operation_name} =
|
| 240 |
+
typename cutlass::gemm::device::RankK<
|
| 241 |
+
${element_a}, ${layout_a},
|
| 242 |
+
${element_c}, ${layout_c}, ${fill_mode},
|
| 243 |
+
${element_accumulator},
|
| 244 |
+
${opcode_class},
|
| 245 |
+
${arch},
|
| 246 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 247 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 248 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 249 |
+
${epilogue_functor}<
|
| 250 |
+
${element_c},
|
| 251 |
+
${epilogue_vector_length},
|
| 252 |
+
${element_accumulator},
|
| 253 |
+
${element_epilogue}
|
| 254 |
+
>,
|
| 255 |
+
${swizzling_functor},
|
| 256 |
+
${stages},
|
| 257 |
+
${align_a},
|
| 258 |
+
${split_k_serial},
|
| 259 |
+
${math_operation},
|
| 260 |
+
${transform_a},
|
| 261 |
+
${blas_mode}
|
| 262 |
+
>;
|
| 263 |
+
"""
|
| 264 |
+
|
| 265 |
+
def emit(self, operation):
|
| 266 |
+
|
| 267 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 268 |
+
|
| 269 |
+
warp_count = operation.tile_description.warp_count
|
| 270 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 271 |
+
|
| 272 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 273 |
+
|
| 274 |
+
values = {
|
| 275 |
+
'operation_name': operation.procedural_name(),
|
| 276 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 277 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 278 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 279 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 280 |
+
'fill_mode': FillModeTag[operation.C.fill_mode],
|
| 281 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 282 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 283 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 284 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 285 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 286 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 287 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 288 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 289 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 290 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 291 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 292 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 293 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 294 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 295 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 296 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 297 |
+
'stages': str(operation.tile_description.stages),
|
| 298 |
+
'align_a': str(operation.A.alignment),
|
| 299 |
+
'split_k_serial': 'false',
|
| 300 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 301 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform],
|
| 302 |
+
'blas_mode': BlasModeTag[operation.blas_mode]
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template
|
| 306 |
+
|
| 307 |
+
return SubstituteTemplate(rank_k_template, values)
|
| 308 |
+
|
| 309 |
+
###################################################################################################
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
###################################################################################################
|
| 313 |
+
#
|
| 314 |
+
# Emitters functions for all targets
|
| 315 |
+
#
|
| 316 |
+
###################################################################################################
|
| 317 |
+
|
| 318 |
+
class EmitRankKConfigurationLibrary:
|
| 319 |
+
def __init__(self, operation_path, configuration_name):
|
| 320 |
+
self.configuration_name = configuration_name
|
| 321 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
|
| 322 |
+
|
| 323 |
+
self.instance_emitter = {
|
| 324 |
+
RankKKind.Universal: EmitRankKUniversalInstance,
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
self.rank_k_kind_wrappers = {
|
| 328 |
+
RankKKind.Universal: 'RankKOperation',
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
self.instance_template = {
|
| 332 |
+
RankKKind.Universal: """
|
| 333 |
+
${compile_guard_start}
|
| 334 |
+
manifest.append(new ${rank_k_kind}<
|
| 335 |
+
Operation_${operation_name}
|
| 336 |
+
>("${operation_name}"));
|
| 337 |
+
${compile_guard_end}
|
| 338 |
+
"""
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
self.header_template = """
|
| 342 |
+
/*
|
| 343 |
+
Generated by rank_k_operation.py - Do not edit.
|
| 344 |
+
*/
|
| 345 |
+
|
| 346 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 347 |
+
#include "cutlass/cutlass.h"
|
| 348 |
+
#include "cutlass/library/library.h"
|
| 349 |
+
#include "cutlass/library/manifest.h"
|
| 350 |
+
|
| 351 |
+
#include "library_internal.h"
|
| 352 |
+
#include "rank_k_operation.h"
|
| 353 |
+
|
| 354 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 355 |
+
|
| 356 |
+
"""
|
| 357 |
+
|
| 358 |
+
self.initialize_function_template = """
|
| 359 |
+
|
| 360 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 361 |
+
|
| 362 |
+
namespace cutlass {
|
| 363 |
+
namespace library {
|
| 364 |
+
|
| 365 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 366 |
+
|
| 367 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 368 |
+
|
| 369 |
+
"""
|
| 370 |
+
self.epilogue_template = """
|
| 371 |
+
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 375 |
+
|
| 376 |
+
} // namespace library
|
| 377 |
+
} // namespace cutlass
|
| 378 |
+
|
| 379 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 380 |
+
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
def __enter__(self):
|
| 384 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 385 |
+
self.configuration_file.write(self.header_template)
|
| 386 |
+
|
| 387 |
+
self.instance_definitions = []
|
| 388 |
+
self.instance_wrappers = []
|
| 389 |
+
|
| 390 |
+
self.operations = []
|
| 391 |
+
return self
|
| 392 |
+
|
| 393 |
+
def emit(self, operation):
|
| 394 |
+
emitter = self.instance_emitter[operation.rank_k_kind]()
|
| 395 |
+
|
| 396 |
+
self.operations.append(operation)
|
| 397 |
+
|
| 398 |
+
self.instance_definitions.append(emitter.emit(operation))
|
| 399 |
+
|
| 400 |
+
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
|
| 401 |
+
'configuration_name': self.configuration_name,
|
| 402 |
+
'operation_name': operation.procedural_name(),
|
| 403 |
+
'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
|
| 404 |
+
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
|
| 405 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
|
| 406 |
+
'compile_guard_end': "#endif" \
|
| 407 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
| 408 |
+
}))
|
| 409 |
+
|
| 410 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 411 |
+
|
| 412 |
+
# Write instance definitions in top-level namespace
|
| 413 |
+
for instance_definition in self.instance_definitions:
|
| 414 |
+
self.configuration_file.write(instance_definition)
|
| 415 |
+
|
| 416 |
+
# Add wrapper objects within initialize() function
|
| 417 |
+
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
| 418 |
+
'configuration_name': self.configuration_name
|
| 419 |
+
}))
|
| 420 |
+
|
| 421 |
+
for instance_wrapper in self.instance_wrappers:
|
| 422 |
+
self.configuration_file.write(instance_wrapper)
|
| 423 |
+
|
| 424 |
+
self.configuration_file.write(self.epilogue_template)
|
| 425 |
+
self.configuration_file.close()
|
| 426 |
+
|
| 427 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Valid tcgen05 shapes and cluster sizes for SM100, associated with levels.
|
| 35 |
+
These shape and level pairs are defined as dicts, where keys are shapes and values are their
|
| 36 |
+
associated levels. If the user input level for that category (tcgen05 shape, cluster
|
| 37 |
+
size) is smaller than a shape's associated level, it will be excluded, and otherwise, included.
|
| 38 |
+
Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently.
|
| 39 |
+
Level 0 is always emitted.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from .library import DynamicClusterShape
|
| 44 |
+
except:
|
| 45 |
+
from library import DynamicClusterShape
|
| 46 |
+
|
| 47 |
+
SM100_CLUSTER_SHAPES_1SM = {
|
| 48 |
+
tuple(DynamicClusterShape) : 0,
|
| 49 |
+
# size 1 cluster
|
| 50 |
+
(1, 1, 1): 1,
|
| 51 |
+
# size 2 cluster
|
| 52 |
+
(1, 2, 1): 2,
|
| 53 |
+
(2, 1, 1): 5,
|
| 54 |
+
# size 4 clusters
|
| 55 |
+
(2, 2, 1): 6,
|
| 56 |
+
(1, 4, 1): 3,
|
| 57 |
+
(4, 1, 1): 6,
|
| 58 |
+
# size 8 clusters
|
| 59 |
+
(2, 4, 1): 7,
|
| 60 |
+
(4, 2, 1): 7,
|
| 61 |
+
(1, 8, 1): 8,
|
| 62 |
+
(8, 1, 1): 8,
|
| 63 |
+
# size 16 cluster
|
| 64 |
+
(4, 4, 1): 4,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
SM100_CLUSTER_SHAPES_2SM = {
|
| 68 |
+
tuple(DynamicClusterShape) : 0,
|
| 69 |
+
# size 2 cluster
|
| 70 |
+
(2, 1, 1): 1,
|
| 71 |
+
# size 4 clusters
|
| 72 |
+
(2, 2, 1): 2,
|
| 73 |
+
(4, 1, 1): 2,
|
| 74 |
+
# size 8 clusters
|
| 75 |
+
(2, 4, 1): 3,
|
| 76 |
+
(4, 2, 1): 3,
|
| 77 |
+
(8, 1, 1): 6,
|
| 78 |
+
# size 16 cluster
|
| 79 |
+
(4, 4, 1): 4,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
# MMA shapes
|
| 83 |
+
|
| 84 |
+
# 16b Dense
|
| 85 |
+
|
| 86 |
+
SM100_MMA_SHAPES_16b_DENSE_1SM = {
|
| 87 |
+
(64, 8, 16): 5,
|
| 88 |
+
(64, 16, 16): 2,
|
| 89 |
+
(64, 24, 16): 5,
|
| 90 |
+
(64, 32, 16): 2,
|
| 91 |
+
(64, 40, 16): 5,
|
| 92 |
+
(64, 48, 16): 5,
|
| 93 |
+
(64, 56, 16): 5,
|
| 94 |
+
(64, 64, 16): 2,
|
| 95 |
+
(64, 72, 16): 5,
|
| 96 |
+
(64, 80, 16): 5,
|
| 97 |
+
(64, 88, 16): 5,
|
| 98 |
+
(64, 96, 16): 5,
|
| 99 |
+
(64, 104, 16): 5,
|
| 100 |
+
(64, 112, 16): 5,
|
| 101 |
+
(64, 120, 16): 5,
|
| 102 |
+
(64, 128, 16): 0,
|
| 103 |
+
(64, 136, 16): 5,
|
| 104 |
+
(64, 144, 16): 5,
|
| 105 |
+
(64, 152, 16): 5,
|
| 106 |
+
(64, 160, 16): 5,
|
| 107 |
+
(64, 168, 16): 5,
|
| 108 |
+
(64, 176, 16): 5,
|
| 109 |
+
(64, 184, 16): 5,
|
| 110 |
+
(64, 192, 16): 3,
|
| 111 |
+
(64, 200, 16): 5,
|
| 112 |
+
(64, 208, 16): 5,
|
| 113 |
+
(64, 216, 16): 5,
|
| 114 |
+
(64, 224, 16): 5,
|
| 115 |
+
(64, 232, 16): 5,
|
| 116 |
+
(64, 240, 16): 5,
|
| 117 |
+
(64, 248, 16): 5,
|
| 118 |
+
(64, 256, 16): 3,
|
| 119 |
+
|
| 120 |
+
(128, 16, 16): 2,
|
| 121 |
+
(128, 32, 16): 2,
|
| 122 |
+
(128, 48, 16): 5,
|
| 123 |
+
(128, 64, 16): 2,
|
| 124 |
+
(128, 80, 16): 5,
|
| 125 |
+
(128, 96, 16): 5,
|
| 126 |
+
(128, 112, 16): 5,
|
| 127 |
+
(128, 128, 16): 0,
|
| 128 |
+
(128, 144, 16): 5,
|
| 129 |
+
(128, 160, 16): 5,
|
| 130 |
+
(128, 176, 16): 5,
|
| 131 |
+
(128, 192, 16): 3,
|
| 132 |
+
(128, 208, 16): 5,
|
| 133 |
+
(128, 224, 16): 5,
|
| 134 |
+
(128, 240, 16): 5,
|
| 135 |
+
(128, 256, 16): 0,
|
| 136 |
+
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
SM100_MMA_SHAPES_16b_DENSE_2SM = {
|
| 141 |
+
(128, 32, 16): 2,
|
| 142 |
+
(128, 64, 16): 2,
|
| 143 |
+
(128, 96, 16): 5,
|
| 144 |
+
(128, 128, 16): 0,
|
| 145 |
+
(128, 160, 16): 5,
|
| 146 |
+
(128, 192, 16): 5,
|
| 147 |
+
(128, 224, 16): 5,
|
| 148 |
+
(128, 256, 16): 0,
|
| 149 |
+
|
| 150 |
+
(256, 32, 16): 2,
|
| 151 |
+
(256, 64, 16): 2,
|
| 152 |
+
(256, 96, 16): 5,
|
| 153 |
+
(256, 128, 16): 0,
|
| 154 |
+
(256, 160, 16): 5,
|
| 155 |
+
(256, 192, 16): 3,
|
| 156 |
+
(256, 224, 16): 5,
|
| 157 |
+
(256, 256, 16): 0,
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
# TF32 Dense
|
| 161 |
+
|
| 162 |
+
SM100_MMA_SHAPES_TF32_DENSE_1SM = {
|
| 163 |
+
(64, 8, 8): 5,
|
| 164 |
+
(64, 16, 8): 2,
|
| 165 |
+
(64, 24, 8): 5,
|
| 166 |
+
(64, 32, 8): 2,
|
| 167 |
+
(64, 40, 8): 5,
|
| 168 |
+
(64, 48, 8): 5,
|
| 169 |
+
(64, 56, 8): 5,
|
| 170 |
+
(64, 64, 8): 1,
|
| 171 |
+
(64, 72, 8): 5,
|
| 172 |
+
(64, 80, 8): 5,
|
| 173 |
+
(64, 88, 8): 5,
|
| 174 |
+
(64, 96, 8): 5,
|
| 175 |
+
(64, 104, 8): 5,
|
| 176 |
+
(64, 112, 8): 5,
|
| 177 |
+
(64, 120, 8): 5,
|
| 178 |
+
(64, 128, 8): 0,
|
| 179 |
+
(64, 136, 8): 5,
|
| 180 |
+
(64, 144, 8): 5,
|
| 181 |
+
(64, 152, 8): 5,
|
| 182 |
+
(64, 160, 8): 5,
|
| 183 |
+
(64, 168, 8): 5,
|
| 184 |
+
(64, 176, 8): 5,
|
| 185 |
+
(64, 184, 8): 5,
|
| 186 |
+
(64, 192, 8): 3,
|
| 187 |
+
(64, 200, 8): 5,
|
| 188 |
+
(64, 208, 8): 5,
|
| 189 |
+
(64, 216, 8): 5,
|
| 190 |
+
(64, 224, 8): 5,
|
| 191 |
+
(64, 232, 8): 5,
|
| 192 |
+
(64, 240, 8): 5,
|
| 193 |
+
(64, 248, 8): 5,
|
| 194 |
+
(64, 256, 8): 3,
|
| 195 |
+
|
| 196 |
+
(128, 16, 8): 2,
|
| 197 |
+
(128, 32, 8): 2,
|
| 198 |
+
(128, 48, 8): 5,
|
| 199 |
+
(128, 64, 8): 2,
|
| 200 |
+
(128, 80, 8): 5,
|
| 201 |
+
(128, 96, 8): 5,
|
| 202 |
+
(128, 112, 8): 5,
|
| 203 |
+
(128, 128, 8): 0,
|
| 204 |
+
(128, 144, 8): 5,
|
| 205 |
+
(128, 160, 8): 5,
|
| 206 |
+
(128, 176, 8): 5,
|
| 207 |
+
(128, 192, 8): 3,
|
| 208 |
+
(128, 208, 8): 5,
|
| 209 |
+
(128, 224, 8): 5,
|
| 210 |
+
(128, 240, 8): 5,
|
| 211 |
+
(128, 256, 8): 0,
|
| 212 |
+
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
SM100_MMA_SHAPES_TF32_DENSE_2SM = {
|
| 216 |
+
(128, 32, 8): 2,
|
| 217 |
+
(128, 64, 8): 1,
|
| 218 |
+
(128, 96, 8): 5,
|
| 219 |
+
(128, 128, 8): 0,
|
| 220 |
+
(128, 160, 8): 5,
|
| 221 |
+
(128, 192, 8): 5,
|
| 222 |
+
(128, 224, 8): 5,
|
| 223 |
+
(128, 256, 8): 0,
|
| 224 |
+
|
| 225 |
+
(256, 32, 8): 2,
|
| 226 |
+
(256, 64, 8): 1,
|
| 227 |
+
(256, 96, 8): 5,
|
| 228 |
+
(256, 128, 8): 0,
|
| 229 |
+
(256, 160, 8): 5,
|
| 230 |
+
(256, 192, 8): 5,
|
| 231 |
+
(256, 224, 8): 5,
|
| 232 |
+
(256, 256, 8): 0,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
# F8F6F4
|
| 236 |
+
SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = {
|
| 237 |
+
(64, 8, 32): 4,
|
| 238 |
+
(64, 16, 32): 4,
|
| 239 |
+
(64, 24, 32): 5,
|
| 240 |
+
(64, 32, 32): 3,
|
| 241 |
+
(64, 40, 32): 5,
|
| 242 |
+
(64, 48, 32): 5,
|
| 243 |
+
(64, 56, 32): 5,
|
| 244 |
+
(64, 64, 32): 2,
|
| 245 |
+
(64, 72, 32): 5,
|
| 246 |
+
(64, 80, 32): 5,
|
| 247 |
+
(64, 88, 32): 5,
|
| 248 |
+
(64, 96, 32): 5,
|
| 249 |
+
(64, 104, 32): 5,
|
| 250 |
+
(64, 112, 32): 5,
|
| 251 |
+
(64, 120, 32): 5,
|
| 252 |
+
(64, 128, 32): 0,
|
| 253 |
+
(64, 136, 32): 5,
|
| 254 |
+
(64, 144, 32): 5,
|
| 255 |
+
(64, 152, 32): 5,
|
| 256 |
+
(64, 160, 32): 5,
|
| 257 |
+
(64, 168, 32): 5,
|
| 258 |
+
(64, 176, 32): 5,
|
| 259 |
+
(64, 184, 32): 5,
|
| 260 |
+
(64, 192, 32): 5,
|
| 261 |
+
(64, 200, 32): 5,
|
| 262 |
+
(64, 208, 32): 5,
|
| 263 |
+
(64, 216, 32): 5,
|
| 264 |
+
(64, 224, 32): 5,
|
| 265 |
+
(64, 232, 32): 5,
|
| 266 |
+
(64, 240, 32): 5,
|
| 267 |
+
(64, 248, 32): 5,
|
| 268 |
+
(64, 256, 32): 0,
|
| 269 |
+
|
| 270 |
+
(128, 16, 32): 4,
|
| 271 |
+
(128, 32, 32): 3,
|
| 272 |
+
(128, 48, 32): 5,
|
| 273 |
+
(128, 64, 32): 2,
|
| 274 |
+
(128, 80, 32): 5,
|
| 275 |
+
(128, 96, 32): 5,
|
| 276 |
+
(128, 112, 32): 5,
|
| 277 |
+
(128, 128, 32): 0,
|
| 278 |
+
(128, 144, 32): 5,
|
| 279 |
+
(128, 160, 32): 5,
|
| 280 |
+
(128, 176, 32): 5,
|
| 281 |
+
(128, 192, 32): 5,
|
| 282 |
+
(128, 208, 32): 5,
|
| 283 |
+
(128, 224, 32): 5,
|
| 284 |
+
(128, 240, 32): 5,
|
| 285 |
+
(128, 256, 32): 0,
|
| 286 |
+
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = {
|
| 290 |
+
(128, 32, 32): 3,
|
| 291 |
+
(128, 64, 32): 2,
|
| 292 |
+
(128, 96, 32): 5,
|
| 293 |
+
(128, 128, 32): 1,
|
| 294 |
+
(128, 160, 32): 5,
|
| 295 |
+
(128, 192, 32): 5,
|
| 296 |
+
(128, 224, 32): 5,
|
| 297 |
+
(128, 256, 32): 1,
|
| 298 |
+
|
| 299 |
+
(256, 32, 32): 2,
|
| 300 |
+
(256, 64, 32): 2,
|
| 301 |
+
(256, 96, 32): 5,
|
| 302 |
+
(256, 128, 32): 0,
|
| 303 |
+
(256, 160, 32): 5,
|
| 304 |
+
(256, 192, 32): 5,
|
| 305 |
+
(256, 224, 32): 5,
|
| 306 |
+
(256, 256, 32): 0,
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
# MXF8F6F4
|
| 310 |
+
SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = {
|
| 311 |
+
(128, 64, 32): 1,
|
| 312 |
+
(128, 128, 32): 0,
|
| 313 |
+
(128, 192, 32): 1,
|
| 314 |
+
(128, 256, 32): 0,
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = {
|
| 319 |
+
(256, 64, 32): 1,
|
| 320 |
+
(256, 128, 32): 0,
|
| 321 |
+
(256, 192, 32): 1,
|
| 322 |
+
(256, 256, 32): 0,
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
# MXF4NVF4
|
| 328 |
+
SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = {
|
| 329 |
+
(128, 64, 64): 1,
|
| 330 |
+
(128, 128, 64): 0,
|
| 331 |
+
(128, 192, 64): 1,
|
| 332 |
+
(128, 256, 64): 0,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = {
|
| 336 |
+
# Multiples of 16 for N
|
| 337 |
+
(256, 64, 64): 1,
|
| 338 |
+
(256, 128, 64): 0,
|
| 339 |
+
(256, 192, 64): 1,
|
| 340 |
+
(256, 256, 64): 0,
|
| 341 |
+
|
| 342 |
+
}
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for enumerating CUTLASS library SM100 kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import argparse
|
| 38 |
+
import enum
|
| 39 |
+
from itertools import product
|
| 40 |
+
import math
|
| 41 |
+
import logging
|
| 42 |
+
import os.path
|
| 43 |
+
import shutil
|
| 44 |
+
import sys
|
| 45 |
+
import copy
|
| 46 |
+
from typing import Any, Optional, Sequence, Tuple, List, Union, Callable
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
import builtins
|
| 50 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 51 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 52 |
+
from cutlass_library.library import *
|
| 53 |
+
except ImportError:
|
| 54 |
+
from library import *
|
| 55 |
+
|
| 56 |
+
#### Step 0: define levels
|
| 57 |
+
|
| 58 |
+
# One integer level controls multiple "generators" and how many
|
| 59 |
+
# combinations they generate. That is the "global" level.
|
| 60 |
+
# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and
|
| 61 |
+
# anything that is eventually involved in the Cartesian product
|
| 62 |
+
# which yields our kernel configurations.
|
| 63 |
+
# For simplicity, each generator defines their own levels,
|
| 64 |
+
# starting from 0. As a rule we assume 10 or fewer levels, making
|
| 65 |
+
# their level a digit.
|
| 66 |
+
# The "global" level simply stacks these digits and represents them
|
| 67 |
+
# as a single integer.
|
| 68 |
+
#
|
| 69 |
+
# For example, level 500 indicates cluster sizes are at level 5, MMA
|
| 70 |
+
# multipliers are at level 0, and WGMMA shapes are at level 0 as well.
|
| 71 |
+
#
|
| 72 |
+
# Here we define the global level to generator level mappings.
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_tcgen05_level_from_global_level(global_level: int):
|
| 76 |
+
return global_level % 10
|
| 77 |
+
|
| 78 |
+
def get_mma_level_from_global_level(global_level: int):
|
| 79 |
+
return (global_level // 10) % 10
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_cluster_level_from_global_level(global_level: int):
|
| 83 |
+
return (global_level // 100) % 10
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_pruning_level_from_global_level(global_level: int):
|
| 87 |
+
return (global_level // 1000) % 10
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
#### Step 1: generate MMA instruction shapes based on levels
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
from .sm100_shapes import *
|
| 94 |
+
except:
|
| 95 |
+
from sm100_shapes import *
|
| 96 |
+
|
| 97 |
+
###########
|
| 98 |
+
|
| 99 |
+
def generate_tf32_math_instructions_sm100(level: int):
|
| 100 |
+
"""
|
| 101 |
+
Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
level: The global level to generate math instructions for.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
A tuple of two lists of MathInstruction objects.
|
| 108 |
+
The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
|
| 109 |
+
"""
|
| 110 |
+
tcgen05_level = get_tcgen05_level_from_global_level(level)
|
| 111 |
+
math_instructions_1sm = []
|
| 112 |
+
math_instructions_2sm = []
|
| 113 |
+
|
| 114 |
+
shapes_1sm = [
|
| 115 |
+
shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level
|
| 116 |
+
]
|
| 117 |
+
shapes_2sm = [
|
| 118 |
+
shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
for shape in shapes_1sm:
|
| 122 |
+
math_instructions_1sm.append(
|
| 123 |
+
MathInstruction(
|
| 124 |
+
shape,
|
| 125 |
+
DataType.tf32, DataType.tf32, DataType.f32,
|
| 126 |
+
OpcodeClass.TensorOp,
|
| 127 |
+
MathOperation.multiply_add)
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
for shape in shapes_2sm:
|
| 131 |
+
math_instructions_2sm.append(
|
| 132 |
+
MathInstruction(
|
| 133 |
+
shape,
|
| 134 |
+
DataType.tf32, DataType.tf32, DataType.f32,
|
| 135 |
+
OpcodeClass.TensorOp,
|
| 136 |
+
MathOperation.multiply_add)
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
return math_instructions_1sm, math_instructions_2sm
|
| 140 |
+
|
| 141 |
+
def generate_16b_math_instructions_sm100(level: int):
|
| 142 |
+
"""
|
| 143 |
+
Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
level: The global level to generate math instructions for.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
A tuple of two lists of MathInstruction objects.
|
| 150 |
+
The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
|
| 151 |
+
"""
|
| 152 |
+
tcgen05_level = get_tcgen05_level_from_global_level(level)
|
| 153 |
+
math_instructions_1sm = []
|
| 154 |
+
math_instructions_2sm = []
|
| 155 |
+
|
| 156 |
+
shapes_1sm = [
|
| 157 |
+
shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level
|
| 158 |
+
]
|
| 159 |
+
shapes_2sm = [
|
| 160 |
+
shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
for shape in shapes_1sm:
|
| 164 |
+
math_instructions_1sm.append(
|
| 165 |
+
MathInstruction(
|
| 166 |
+
shape,
|
| 167 |
+
DataType.f16, DataType.f16, DataType.f16,
|
| 168 |
+
OpcodeClass.TensorOp,
|
| 169 |
+
MathOperation.multiply_add)
|
| 170 |
+
)
|
| 171 |
+
math_instructions_1sm.append(
|
| 172 |
+
MathInstruction(
|
| 173 |
+
shape,
|
| 174 |
+
DataType.f16, DataType.f16, DataType.f32,
|
| 175 |
+
OpcodeClass.TensorOp,
|
| 176 |
+
MathOperation.multiply_add)
|
| 177 |
+
)
|
| 178 |
+
math_instructions_1sm.append(
|
| 179 |
+
MathInstruction(
|
| 180 |
+
shape,
|
| 181 |
+
DataType.bf16, DataType.bf16, DataType.f32,
|
| 182 |
+
OpcodeClass.TensorOp,
|
| 183 |
+
MathOperation.multiply_add)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
for shape in shapes_2sm:
|
| 188 |
+
math_instructions_2sm.append(
|
| 189 |
+
MathInstruction(
|
| 190 |
+
shape,
|
| 191 |
+
DataType.f16, DataType.f16, DataType.f16,
|
| 192 |
+
OpcodeClass.TensorOp,
|
| 193 |
+
MathOperation.multiply_add)
|
| 194 |
+
)
|
| 195 |
+
math_instructions_2sm.append(
|
| 196 |
+
MathInstruction(
|
| 197 |
+
shape,
|
| 198 |
+
DataType.f16, DataType.f16, DataType.f32,
|
| 199 |
+
OpcodeClass.TensorOp,
|
| 200 |
+
MathOperation.multiply_add)
|
| 201 |
+
)
|
| 202 |
+
math_instructions_2sm.append(
|
| 203 |
+
MathInstruction(
|
| 204 |
+
shape,
|
| 205 |
+
DataType.bf16, DataType.bf16, DataType.f32,
|
| 206 |
+
OpcodeClass.TensorOp,
|
| 207 |
+
MathOperation.multiply_add)
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
return math_instructions_1sm, math_instructions_2sm
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
|
| 214 |
+
"""
|
| 215 |
+
Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
level: The global level to generate math instructions for.
|
| 219 |
+
enable_runtime_dtype: Whether to generate runtime dtype math instructions.
|
| 220 |
+
enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
A tuple of two lists of MathInstruction objects.
|
| 224 |
+
The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
tcgen05_level = get_tcgen05_level_from_global_level(level)
|
| 228 |
+
pruning_level = get_pruning_level_from_global_level(level)
|
| 229 |
+
math_instructions_1sm = []
|
| 230 |
+
math_instructions_2sm = []
|
| 231 |
+
|
| 232 |
+
shapes_1sm = [
|
| 233 |
+
shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level
|
| 234 |
+
]
|
| 235 |
+
shapes_2sm = [
|
| 236 |
+
shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level
|
| 237 |
+
]
|
| 238 |
+
|
| 239 |
+
for shape in shapes_1sm:
|
| 240 |
+
if enable_runtime_dtype:
|
| 241 |
+
math_instructions_1sm.append(
|
| 242 |
+
MathInstruction(
|
| 243 |
+
shape,
|
| 244 |
+
DataType.f8, DataType.f8, DataType.f32,
|
| 245 |
+
OpcodeClass.TensorOp,
|
| 246 |
+
MathOperation.multiply_add)
|
| 247 |
+
)
|
| 248 |
+
if enable_compile_time_dtype:
|
| 249 |
+
math_instructions_1sm.append(
|
| 250 |
+
MathInstruction(
|
| 251 |
+
shape,
|
| 252 |
+
DataType.e4m3, DataType.e4m3, DataType.f32,
|
| 253 |
+
OpcodeClass.TensorOp,
|
| 254 |
+
MathOperation.multiply_add)
|
| 255 |
+
)
|
| 256 |
+
math_instructions_1sm.append(
|
| 257 |
+
MathInstruction(
|
| 258 |
+
shape,
|
| 259 |
+
DataType.e5m2, DataType.e4m3, DataType.f32,
|
| 260 |
+
OpcodeClass.TensorOp,
|
| 261 |
+
MathOperation.multiply_add)
|
| 262 |
+
)
|
| 263 |
+
math_instructions_1sm.append(
|
| 264 |
+
MathInstruction(
|
| 265 |
+
shape,
|
| 266 |
+
DataType.e4m3, DataType.e5m2, DataType.f32,
|
| 267 |
+
OpcodeClass.TensorOp,
|
| 268 |
+
MathOperation.multiply_add)
|
| 269 |
+
)
|
| 270 |
+
if pruning_level >= 2:
|
| 271 |
+
math_instructions_1sm.append(
|
| 272 |
+
MathInstruction(
|
| 273 |
+
shape,
|
| 274 |
+
DataType.e5m2, DataType.e5m2, DataType.f32,
|
| 275 |
+
OpcodeClass.TensorOp,
|
| 276 |
+
MathOperation.multiply_add)
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
for shape in shapes_2sm:
|
| 280 |
+
if enable_runtime_dtype:
|
| 281 |
+
math_instructions_2sm.append(
|
| 282 |
+
MathInstruction(
|
| 283 |
+
shape,
|
| 284 |
+
DataType.f8, DataType.f8, DataType.f32,
|
| 285 |
+
OpcodeClass.TensorOp,
|
| 286 |
+
MathOperation.multiply_add)
|
| 287 |
+
)
|
| 288 |
+
if enable_compile_time_dtype:
|
| 289 |
+
math_instructions_2sm.append(
|
| 290 |
+
MathInstruction(
|
| 291 |
+
shape,
|
| 292 |
+
DataType.e4m3, DataType.e4m3, DataType.f32,
|
| 293 |
+
OpcodeClass.TensorOp,
|
| 294 |
+
MathOperation.multiply_add)
|
| 295 |
+
)
|
| 296 |
+
math_instructions_2sm.append(
|
| 297 |
+
MathInstruction(
|
| 298 |
+
shape,
|
| 299 |
+
DataType.e5m2, DataType.e4m3, DataType.f32,
|
| 300 |
+
OpcodeClass.TensorOp,
|
| 301 |
+
MathOperation.multiply_add)
|
| 302 |
+
)
|
| 303 |
+
math_instructions_2sm.append(
|
| 304 |
+
MathInstruction(
|
| 305 |
+
shape,
|
| 306 |
+
DataType.e4m3, DataType.e5m2, DataType.f32,
|
| 307 |
+
OpcodeClass.TensorOp,
|
| 308 |
+
MathOperation.multiply_add)
|
| 309 |
+
)
|
| 310 |
+
if pruning_level >= 2:
|
| 311 |
+
math_instructions_2sm.append(
|
| 312 |
+
MathInstruction(
|
| 313 |
+
shape,
|
| 314 |
+
DataType.e5m2, DataType.e5m2, DataType.f32,
|
| 315 |
+
OpcodeClass.TensorOp,
|
| 316 |
+
MathOperation.multiply_add)
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return math_instructions_1sm, math_instructions_2sm
|
| 320 |
+
|
| 321 |
+
def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
|
| 322 |
+
"""
|
| 323 |
+
Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
level: The global level to generate math instructions for.
|
| 327 |
+
enable_runtime_dtype: Whether to generate runtime dtype math instructions.
|
| 328 |
+
enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
A tuple of two lists of MathInstruction objects.
|
| 332 |
+
The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
tcgen05_level = get_tcgen05_level_from_global_level(level)
|
| 336 |
+
math_instructions_1sm = []
|
| 337 |
+
math_instructions_2sm = []
|
| 338 |
+
|
| 339 |
+
shapes_1sm = [
|
| 340 |
+
shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level
|
| 341 |
+
]
|
| 342 |
+
shapes_2sm = [
|
| 343 |
+
shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level
|
| 344 |
+
]
|
| 345 |
+
|
| 346 |
+
for shape in shapes_1sm:
|
| 347 |
+
if enable_runtime_dtype:
|
| 348 |
+
|
| 349 |
+
runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
|
| 350 |
+
|
| 351 |
+
for a_type, b_type in product(runtime_types, repeat=2):
|
| 352 |
+
math_instructions_1sm.append(
|
| 353 |
+
MathInstruction(
|
| 354 |
+
shape,
|
| 355 |
+
a_type, b_type, DataType.f32,
|
| 356 |
+
OpcodeClass.TensorOp,
|
| 357 |
+
MathOperation.multiply_add)
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if enable_compile_time_dtype:
|
| 361 |
+
compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ]
|
| 362 |
+
|
| 363 |
+
for a_type, b_type in product(compile_time_types, repeat=2):
|
| 364 |
+
math_instructions_1sm.append(
|
| 365 |
+
MathInstruction(
|
| 366 |
+
shape,
|
| 367 |
+
a_type, b_type, DataType.f32,
|
| 368 |
+
OpcodeClass.TensorOp,
|
| 369 |
+
MathOperation.multiply_add)
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
for shape in shapes_2sm:
|
| 374 |
+
if enable_runtime_dtype:
|
| 375 |
+
|
| 376 |
+
runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
|
| 377 |
+
|
| 378 |
+
for a_type, b_type in product(runtime_types, repeat=2):
|
| 379 |
+
math_instructions_2sm.append(
|
| 380 |
+
MathInstruction(
|
| 381 |
+
shape,
|
| 382 |
+
a_type, b_type, DataType.f32,
|
| 383 |
+
OpcodeClass.TensorOp,
|
| 384 |
+
MathOperation.multiply_add)
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
if enable_compile_time_dtype:
|
| 388 |
+
compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ]
|
| 389 |
+
|
| 390 |
+
for a_type, b_type in product(compile_time_types, repeat=2):
|
| 391 |
+
math_instructions_2sm.append(
|
| 392 |
+
MathInstruction(
|
| 393 |
+
shape,
|
| 394 |
+
a_type, b_type, DataType.f32,
|
| 395 |
+
OpcodeClass.TensorOp,
|
| 396 |
+
MathOperation.multiply_add)
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
return math_instructions_1sm, math_instructions_2sm
|
| 400 |
+
|
| 401 |
+
def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
|
| 402 |
+
"""
|
| 403 |
+
Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
level: The global level to generate math instructions for.
|
| 407 |
+
enable_runtime_dtype: Whether to generate runtime dtype math instructions.
|
| 408 |
+
enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
A tuple of two lists of MathInstruction objects.
|
| 412 |
+
The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
|
| 413 |
+
"""
|
| 414 |
+
|
| 415 |
+
tcgen05_level = get_tcgen05_level_from_global_level(level)
|
| 416 |
+
pruning_level = get_pruning_level_from_global_level(level)
|
| 417 |
+
|
| 418 |
+
math_instructions_1sm = []
|
| 419 |
+
math_instructions_2sm = []
|
| 420 |
+
|
| 421 |
+
shapes_1sm = [
|
| 422 |
+
shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level
|
| 423 |
+
]
|
| 424 |
+
shapes_2sm = [
|
| 425 |
+
shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level
|
| 426 |
+
]
|
| 427 |
+
|
| 428 |
+
for shape in shapes_1sm:
|
| 429 |
+
if enable_runtime_dtype:
|
| 430 |
+
|
| 431 |
+
runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
|
| 432 |
+
|
| 433 |
+
for a_type, b_type in product(runtime_types, repeat=2):
|
| 434 |
+
|
| 435 |
+
if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)):
|
| 436 |
+
continue
|
| 437 |
+
|
| 438 |
+
math_instructions_1sm.append(
|
| 439 |
+
MathInstruction(
|
| 440 |
+
shape,
|
| 441 |
+
a_type, b_type, DataType.f32,
|
| 442 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 443 |
+
MathOperation.multiply_add,
|
| 444 |
+
DataType.ue8m0)
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
if enable_compile_time_dtype:
|
| 448 |
+
compile_time_types = [ DataType.e4m3,
|
| 449 |
+
DataType.e5m2,
|
| 450 |
+
DataType.e3m2,
|
| 451 |
+
DataType.e2m3,
|
| 452 |
+
DataType.e2m1 ]
|
| 453 |
+
|
| 454 |
+
for a_type, b_type in product(compile_time_types, repeat=2):
|
| 455 |
+
math_instructions_1sm.append(
|
| 456 |
+
MathInstruction(
|
| 457 |
+
shape,
|
| 458 |
+
a_type, b_type, DataType.f32,
|
| 459 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 460 |
+
MathOperation.multiply_add,
|
| 461 |
+
DataType.ue8m0)
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
for shape in shapes_2sm:
|
| 466 |
+
if enable_runtime_dtype:
|
| 467 |
+
|
| 468 |
+
runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
|
| 469 |
+
|
| 470 |
+
for a_type, b_type in product(runtime_types, repeat=2):
|
| 471 |
+
|
| 472 |
+
if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)):
|
| 473 |
+
continue
|
| 474 |
+
|
| 475 |
+
math_instructions_2sm.append(
|
| 476 |
+
MathInstruction(
|
| 477 |
+
shape,
|
| 478 |
+
a_type, b_type, DataType.f32,
|
| 479 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 480 |
+
MathOperation.multiply_add,
|
| 481 |
+
DataType.ue8m0)
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if enable_compile_time_dtype:
|
| 485 |
+
compile_time_types = [ DataType.e4m3,
|
| 486 |
+
DataType.e5m2,
|
| 487 |
+
DataType.e3m2,
|
| 488 |
+
DataType.e2m3,
|
| 489 |
+
DataType.e2m1 ]
|
| 490 |
+
|
| 491 |
+
for a_type, b_type in product(compile_time_types, repeat=2):
|
| 492 |
+
math_instructions_2sm.append(
|
| 493 |
+
MathInstruction(
|
| 494 |
+
shape,
|
| 495 |
+
a_type, b_type, DataType.f32,
|
| 496 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 497 |
+
MathOperation.multiply_add,
|
| 498 |
+
DataType.ue8m0)
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
return math_instructions_1sm, math_instructions_2sm
|
| 502 |
+
|
| 503 |
+
def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
|
| 504 |
+
"""
|
| 505 |
+
Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level.
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
level: The global level to generate math instructions for.
|
| 509 |
+
enable_runtime_dtype: Whether to generate runtime dtype math instructions.
|
| 510 |
+
enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
|
| 511 |
+
|
| 512 |
+
Returns:
|
| 513 |
+
A tuple of two lists of MathInstruction objects.
|
| 514 |
+
The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
|
| 515 |
+
"""
|
| 516 |
+
tcgen05_level = get_tcgen05_level_from_global_level(level)
|
| 517 |
+
math_instructions_1sm = []
|
| 518 |
+
math_instructions_2sm = []
|
| 519 |
+
|
| 520 |
+
shapes_1sm = [
|
| 521 |
+
shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level
|
| 522 |
+
]
|
| 523 |
+
shapes_2sm = [
|
| 524 |
+
shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level
|
| 525 |
+
]
|
| 526 |
+
|
| 527 |
+
for shape in shapes_1sm:
|
| 528 |
+
if enable_runtime_dtype:
|
| 529 |
+
|
| 530 |
+
runtime_types = [ DataType.f4 ]
|
| 531 |
+
|
| 532 |
+
for a_type, b_type in product(runtime_types, repeat=2):
|
| 533 |
+
math_instructions_1sm.append(
|
| 534 |
+
MathInstruction(
|
| 535 |
+
shape,
|
| 536 |
+
a_type, b_type, DataType.f32,
|
| 537 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 538 |
+
MathOperation.multiply_add,
|
| 539 |
+
DataType.ue8m0)
|
| 540 |
+
)
|
| 541 |
+
math_instructions_1sm.append(
|
| 542 |
+
MathInstruction(
|
| 543 |
+
shape,
|
| 544 |
+
a_type, b_type, DataType.f32,
|
| 545 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 546 |
+
MathOperation.multiply_add,
|
| 547 |
+
DataType.ue4m3)
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
if enable_compile_time_dtype:
|
| 552 |
+
compile_time_types = [ DataType.e2m1,
|
| 553 |
+
]
|
| 554 |
+
|
| 555 |
+
for a_type, b_type in product(compile_time_types, repeat=2):
|
| 556 |
+
math_instructions_1sm.append(
|
| 557 |
+
MathInstruction(
|
| 558 |
+
shape,
|
| 559 |
+
a_type, b_type, DataType.f32,
|
| 560 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 561 |
+
MathOperation.multiply_add,
|
| 562 |
+
DataType.ue8m0)
|
| 563 |
+
)
|
| 564 |
+
math_instructions_1sm.append(
|
| 565 |
+
MathInstruction(
|
| 566 |
+
shape,
|
| 567 |
+
a_type, b_type, DataType.f32,
|
| 568 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 569 |
+
MathOperation.multiply_add,
|
| 570 |
+
DataType.ue4m3)
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
for shape in shapes_2sm:
|
| 575 |
+
if enable_runtime_dtype:
|
| 576 |
+
|
| 577 |
+
runtime_types = [ DataType.f4 ]
|
| 578 |
+
|
| 579 |
+
for a_type, b_type in product(runtime_types, repeat=2):
|
| 580 |
+
math_instructions_2sm.append(
|
| 581 |
+
MathInstruction(
|
| 582 |
+
shape,
|
| 583 |
+
a_type, b_type, DataType.f32,
|
| 584 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 585 |
+
MathOperation.multiply_add,
|
| 586 |
+
DataType.ue8m0)
|
| 587 |
+
)
|
| 588 |
+
math_instructions_2sm.append(
|
| 589 |
+
MathInstruction(
|
| 590 |
+
shape,
|
| 591 |
+
a_type, b_type, DataType.f32,
|
| 592 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 593 |
+
MathOperation.multiply_add,
|
| 594 |
+
DataType.ue4m3)
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
if enable_compile_time_dtype:
|
| 599 |
+
compile_time_types = [ DataType.e2m1,
|
| 600 |
+
]
|
| 601 |
+
|
| 602 |
+
for a_type, b_type in product(compile_time_types, repeat=2):
|
| 603 |
+
math_instructions_2sm.append(
|
| 604 |
+
MathInstruction(
|
| 605 |
+
shape,
|
| 606 |
+
a_type, b_type, DataType.f32,
|
| 607 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 608 |
+
MathOperation.multiply_add,
|
| 609 |
+
DataType.ue8m0)
|
| 610 |
+
)
|
| 611 |
+
math_instructions_2sm.append(
|
| 612 |
+
MathInstruction(
|
| 613 |
+
shape,
|
| 614 |
+
a_type, b_type, DataType.f32,
|
| 615 |
+
OpcodeClass.BlockScaledTensorOp,
|
| 616 |
+
MathOperation.multiply_add,
|
| 617 |
+
DataType.ue4m3)
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
return math_instructions_1sm, math_instructions_2sm
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None):
|
| 625 |
+
"""
|
| 626 |
+
Generate all cluster shapes for SM100 at or above the given level.
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
level: The global level to generate cluster shapes for.
|
| 630 |
+
|
| 631 |
+
Returns:
|
| 632 |
+
A tuple of two lists of cluster shapes.
|
| 633 |
+
The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM.
|
| 634 |
+
"""
|
| 635 |
+
cluster_level = get_cluster_level_from_global_level(level)
|
| 636 |
+
|
| 637 |
+
assert cluster_level >= 4
|
| 638 |
+
|
| 639 |
+
if change_priority_func is not None:
|
| 640 |
+
SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM)
|
| 641 |
+
SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM)
|
| 642 |
+
change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY)
|
| 643 |
+
shapes_1sm = [
|
| 644 |
+
list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level
|
| 645 |
+
]
|
| 646 |
+
shapes_2sm = [
|
| 647 |
+
list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level
|
| 648 |
+
]
|
| 649 |
+
|
| 650 |
+
return shapes_1sm, shapes_2sm
|
| 651 |
+
|
| 652 |
+
else:
|
| 653 |
+
|
| 654 |
+
shapes_1sm = [
|
| 655 |
+
list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level
|
| 656 |
+
]
|
| 657 |
+
shapes_2sm = [
|
| 658 |
+
list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level
|
| 659 |
+
]
|
| 660 |
+
|
| 661 |
+
return shapes_1sm, shapes_2sm
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels.
|
| 35 |
+
These shape and level pairs are defined as dicts, where keys are shapes and values are their
|
| 36 |
+
associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster
|
| 37 |
+
size) is smaller than a shape's associated level, it will be excluded, and otherwise, included.
|
| 38 |
+
Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently.
|
| 39 |
+
Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted
|
| 40 |
+
when the `--kernel` argument is non-empty.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
# NOTE: more combinations are possible here.
|
| 44 |
+
# Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes.
|
| 45 |
+
# The rest are only used in the exhaustive mode (when the corresponding level digit is 9).
|
| 46 |
+
# MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes.
|
| 47 |
+
SM90_MMA_MULTIPLIERS = {
|
| 48 |
+
(2, 1, 4): 0,
|
| 49 |
+
(1, 1, 4): 1,
|
| 50 |
+
(4, 1, 4): 2,
|
| 51 |
+
(2, 2, 4): 3,
|
| 52 |
+
(2, 1, 8): 4,
|
| 53 |
+
(4, 1, 8): 4,
|
| 54 |
+
(1, 1, 8): 4,
|
| 55 |
+
(2, 2, 8): 4,
|
| 56 |
+
(2, 1, 16): 5,
|
| 57 |
+
(4, 1, 16): 5,
|
| 58 |
+
(1, 1, 16): 5,
|
| 59 |
+
(2, 2, 16): 5,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case
|
| 63 |
+
# Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case
|
| 64 |
+
# Level 2: clusters with 1 or 2 CTAs
|
| 65 |
+
# Level 3: clusters with 1, 2, or 4 CTAs
|
| 66 |
+
# Level 4: clusters with 1, 2, 4, or 8 CTAs
|
| 67 |
+
# Level 5: clusters with 1, 2, 4, 8, or 16 CTAs
|
| 68 |
+
SM90_CLUSTER_SIZES = {
|
| 69 |
+
(1, 2, 1): 0,
|
| 70 |
+
(2, 1, 1): 1,
|
| 71 |
+
(1, 1, 1): 2,
|
| 72 |
+
(2, 2, 1): 3,
|
| 73 |
+
(1, 4, 1): 3,
|
| 74 |
+
(4, 1, 1): 3,
|
| 75 |
+
(2, 4, 1): 4,
|
| 76 |
+
(4, 2, 1): 4,
|
| 77 |
+
(1, 8, 1): 4,
|
| 78 |
+
(8, 1, 1): 4,
|
| 79 |
+
(4, 4, 1): 5,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# WGMMA shapes
|
| 84 |
+
# Level 0: "default" shape only,
|
| 85 |
+
# Level 1: additional shapes for the unpruned case (tf32 only)
|
| 86 |
+
# Level 2: shapes that are all powers of 2
|
| 87 |
+
# Level 3: all other shapes
|
| 88 |
+
SM90_WGMMA_SHAPES_FP16_BF16_DENSE = {
|
| 89 |
+
(64, 8, 16): 2,
|
| 90 |
+
(64, 16, 16): 2,
|
| 91 |
+
(64, 24, 16): 3,
|
| 92 |
+
(64, 32, 16): 2,
|
| 93 |
+
(64, 40, 16): 3,
|
| 94 |
+
(64, 48, 16): 3,
|
| 95 |
+
(64, 56, 16): 3,
|
| 96 |
+
(64, 64, 16): 2,
|
| 97 |
+
(64, 72, 16): 3,
|
| 98 |
+
(64, 80, 16): 3,
|
| 99 |
+
(64, 88, 16): 3,
|
| 100 |
+
(64, 96, 16): 3,
|
| 101 |
+
(64, 104, 16): 3,
|
| 102 |
+
(64, 112, 16): 3,
|
| 103 |
+
(64, 120, 16): 3,
|
| 104 |
+
(64, 128, 16): 0,
|
| 105 |
+
(64, 136, 16): 3,
|
| 106 |
+
(64, 144, 16): 3,
|
| 107 |
+
(64, 152, 16): 3,
|
| 108 |
+
(64, 160, 16): 3,
|
| 109 |
+
(64, 168, 16): 3,
|
| 110 |
+
(64, 176, 16): 3,
|
| 111 |
+
(64, 184, 16): 3,
|
| 112 |
+
(64, 192, 16): 3,
|
| 113 |
+
(64, 200, 16): 3,
|
| 114 |
+
(64, 208, 16): 3,
|
| 115 |
+
(64, 216, 16): 3,
|
| 116 |
+
(64, 224, 16): 3,
|
| 117 |
+
(64, 232, 16): 3,
|
| 118 |
+
(64, 240, 16): 3,
|
| 119 |
+
(64, 248, 16): 3,
|
| 120 |
+
(64, 256, 16): 1,
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
SM90_WGMMA_SHAPES_TF32_DENSE = {
|
| 124 |
+
(64, 8, 8): 2,
|
| 125 |
+
(64, 16, 8): 2,
|
| 126 |
+
(64, 24, 8): 3,
|
| 127 |
+
(64, 32, 8): 2,
|
| 128 |
+
(64, 40, 8): 3,
|
| 129 |
+
(64, 48, 8): 3,
|
| 130 |
+
(64, 56, 8): 3,
|
| 131 |
+
(64, 64, 8): 2,
|
| 132 |
+
(64, 72, 8): 3,
|
| 133 |
+
(64, 80, 8): 3,
|
| 134 |
+
(64, 88, 8): 3,
|
| 135 |
+
(64, 96, 8): 3,
|
| 136 |
+
(64, 104, 8): 3,
|
| 137 |
+
(64, 112, 8): 3,
|
| 138 |
+
(64, 120, 8): 3,
|
| 139 |
+
(64, 128, 8): 0,
|
| 140 |
+
(64, 136, 8): 3,
|
| 141 |
+
(64, 144, 8): 3,
|
| 142 |
+
(64, 152, 8): 3,
|
| 143 |
+
(64, 160, 8): 3,
|
| 144 |
+
(64, 168, 8): 3,
|
| 145 |
+
(64, 176, 8): 3,
|
| 146 |
+
(64, 184, 8): 3,
|
| 147 |
+
(64, 192, 8): 3,
|
| 148 |
+
(64, 200, 8): 3,
|
| 149 |
+
(64, 208, 8): 3,
|
| 150 |
+
(64, 216, 8): 3,
|
| 151 |
+
(64, 224, 8): 3,
|
| 152 |
+
(64, 232, 8): 3,
|
| 153 |
+
(64, 240, 8): 3,
|
| 154 |
+
(64, 248, 8): 3,
|
| 155 |
+
(64, 256, 8): 1,
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
SM90_WGMMA_SHAPES_FP8_DENSE = {
|
| 159 |
+
(64, 8, 32): 2,
|
| 160 |
+
(64, 16, 32): 2,
|
| 161 |
+
(64, 24, 32): 3,
|
| 162 |
+
(64, 32, 32): 2,
|
| 163 |
+
(64, 40, 32): 3,
|
| 164 |
+
(64, 48, 32): 3,
|
| 165 |
+
(64, 56, 32): 3,
|
| 166 |
+
(64, 64, 32): 2,
|
| 167 |
+
(64, 72, 32): 3,
|
| 168 |
+
(64, 80, 32): 3,
|
| 169 |
+
(64, 88, 32): 3,
|
| 170 |
+
(64, 96, 32): 3,
|
| 171 |
+
(64, 104, 32): 3,
|
| 172 |
+
(64, 112, 32): 3,
|
| 173 |
+
(64, 120, 32): 3,
|
| 174 |
+
(64, 128, 32): 0,
|
| 175 |
+
(64, 136, 32): 3,
|
| 176 |
+
(64, 144, 32): 3,
|
| 177 |
+
(64, 152, 32): 3,
|
| 178 |
+
(64, 160, 32): 3,
|
| 179 |
+
(64, 168, 32): 3,
|
| 180 |
+
(64, 176, 32): 3,
|
| 181 |
+
(64, 184, 32): 3,
|
| 182 |
+
(64, 192, 32): 3,
|
| 183 |
+
(64, 200, 32): 3,
|
| 184 |
+
(64, 208, 32): 3,
|
| 185 |
+
(64, 216, 32): 3,
|
| 186 |
+
(64, 224, 32): 3,
|
| 187 |
+
(64, 232, 32): 3,
|
| 188 |
+
(64, 240, 32): 3,
|
| 189 |
+
(64, 248, 32): 3,
|
| 190 |
+
(64, 256, 32): 1,
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
SM90_WGMMA_SHAPES_INT8_DENSE = {
|
| 194 |
+
(64, 8, 32): 2,
|
| 195 |
+
(64, 16, 32): 2,
|
| 196 |
+
(64, 24, 32): 3,
|
| 197 |
+
(64, 32, 32): 2,
|
| 198 |
+
(64, 48, 32): 3,
|
| 199 |
+
(64, 64, 32): 2,
|
| 200 |
+
(64, 80, 32): 3,
|
| 201 |
+
(64, 96, 32): 3,
|
| 202 |
+
(64, 112, 32): 3,
|
| 203 |
+
(64, 128, 32): 0,
|
| 204 |
+
(64, 144, 32): 3,
|
| 205 |
+
(64, 160, 32): 3,
|
| 206 |
+
(64, 176, 32): 3,
|
| 207 |
+
(64, 192, 32): 3,
|
| 208 |
+
(64, 208, 32): 3,
|
| 209 |
+
(64, 224, 32): 3,
|
| 210 |
+
(64, 240, 32): 3,
|
| 211 |
+
(64, 256, 32): 1,
|
| 212 |
+
}
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for enumerating CUTLASS library SM90 kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import argparse
|
| 38 |
+
import enum
|
| 39 |
+
from itertools import product
|
| 40 |
+
import math
|
| 41 |
+
import logging
|
| 42 |
+
import os.path
|
| 43 |
+
import shutil
|
| 44 |
+
import sys
|
| 45 |
+
import copy
|
| 46 |
+
from typing import Any, Optional, Sequence, Tuple, List
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
import builtins
|
| 50 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 51 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 52 |
+
from cutlass_library.library import *
|
| 53 |
+
except ImportError:
|
| 54 |
+
from library import *
|
| 55 |
+
|
| 56 |
+
# NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py
|
| 57 |
+
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
|
| 58 |
+
|
| 59 |
+
# by default, use the latest CUDA Toolkit version
|
| 60 |
+
cuda_version = [11, 0, 132]
|
| 61 |
+
|
| 62 |
+
# Update cuda_version based on parsed string
|
| 63 |
+
if semantic_ver_string != '':
|
| 64 |
+
for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]):
|
| 65 |
+
if i < len(cuda_version):
|
| 66 |
+
cuda_version[i] = x
|
| 67 |
+
else:
|
| 68 |
+
cuda_version.append(x)
|
| 69 |
+
return cuda_version >= [major, minor, patch]
|
| 70 |
+
|
| 71 |
+
#### Step 0: define levels
|
| 72 |
+
|
| 73 |
+
# One integer level controls multiple "generators" and how many
|
| 74 |
+
# combinations they generate. That is the "global" level.
|
| 75 |
+
# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and
|
| 76 |
+
# anything that is eventually involved in the Cartesian product
|
| 77 |
+
# which yields our kernel configurations.
|
| 78 |
+
# For simplicity, each generator defines their own levels,
|
| 79 |
+
# starting from 0. As a rule we assume 10 or fewer levels, making
|
| 80 |
+
# their level a digit.
|
| 81 |
+
# The "global" level simply stacks these digits and represents them
|
| 82 |
+
# as a single integer.
|
| 83 |
+
#
|
| 84 |
+
# For example, level 500 indicates cluster sizes are at level 5, MMA
|
| 85 |
+
# multipliers are at level 0, and WGMMA shapes are at level 0 as well.
|
| 86 |
+
#
|
| 87 |
+
# Here we define the global level to generator level mappings.
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_wgmma_level_from_global_level(global_level: int):
|
| 91 |
+
return global_level % 10
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_mma_level_from_global_level(global_level: int):
|
| 95 |
+
return (global_level // 10) % 10
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_cluster_level_from_global_level(global_level: int):
|
| 99 |
+
return (global_level // 100) % 10
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def get_pruning_level_from_global_level(global_level: int):
|
| 103 |
+
return (global_level // 1000) % 10
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
#### Step 1: generate MMA instruction shapes based on levels
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
from .sm90_shapes import (
|
| 110 |
+
SM90_MMA_MULTIPLIERS,
|
| 111 |
+
SM90_CLUSTER_SIZES,
|
| 112 |
+
SM90_WGMMA_SHAPES_TF32_DENSE,
|
| 113 |
+
SM90_WGMMA_SHAPES_FP16_BF16_DENSE,
|
| 114 |
+
SM90_WGMMA_SHAPES_FP8_DENSE,
|
| 115 |
+
SM90_WGMMA_SHAPES_INT8_DENSE,
|
| 116 |
+
)
|
| 117 |
+
except:
|
| 118 |
+
from sm90_shapes import (
|
| 119 |
+
SM90_MMA_MULTIPLIERS,
|
| 120 |
+
SM90_CLUSTER_SIZES,
|
| 121 |
+
SM90_WGMMA_SHAPES_TF32_DENSE,
|
| 122 |
+
SM90_WGMMA_SHAPES_FP16_BF16_DENSE,
|
| 123 |
+
SM90_WGMMA_SHAPES_FP8_DENSE,
|
| 124 |
+
SM90_WGMMA_SHAPES_INT8_DENSE,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def generate_tf32_math_instruction_shapes_sm90(level: int):
|
| 129 |
+
assert isinstance(level, int) and level >= 0
|
| 130 |
+
filtered_list_of_wgmma_shapes = [
|
| 131 |
+
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level
|
| 132 |
+
]
|
| 133 |
+
return filtered_list_of_wgmma_shapes
|
| 134 |
+
|
| 135 |
+
def generate_fp16_bf16_math_instruction_shapes_sm90(level: int):
|
| 136 |
+
assert isinstance(level, int) and level >= 0
|
| 137 |
+
filtered_list_of_wgmma_shapes = [
|
| 138 |
+
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level
|
| 139 |
+
]
|
| 140 |
+
return filtered_list_of_wgmma_shapes
|
| 141 |
+
|
| 142 |
+
def generate_fp8_math_instruction_shapes_sm90(level: int):
|
| 143 |
+
assert isinstance(level, int) and level >= 0
|
| 144 |
+
filtered_list_of_wgmma_shapes = [
|
| 145 |
+
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level
|
| 146 |
+
]
|
| 147 |
+
return filtered_list_of_wgmma_shapes
|
| 148 |
+
|
| 149 |
+
def generate_int8_math_instruction_shapes_sm90(level: int):
|
| 150 |
+
assert isinstance(level, int) and level >= 0
|
| 151 |
+
filtered_list_of_wgmma_shapes = [
|
| 152 |
+
wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level
|
| 153 |
+
]
|
| 154 |
+
return filtered_list_of_wgmma_shapes
|
| 155 |
+
|
| 156 |
+
def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType):
|
| 157 |
+
# DataTypeSize are in the unit of bits
|
| 158 |
+
a_bytes = DataTypeSize[a_type] // 8
|
| 159 |
+
b_bytes = DataTypeSize[b_type] // 8
|
| 160 |
+
if a_bytes == 4 or b_bytes == 4:
|
| 161 |
+
return generate_tf32_math_instruction_shapes_sm90(wgmma_level)
|
| 162 |
+
elif a_bytes == 2 or b_bytes == 2:
|
| 163 |
+
return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level)
|
| 164 |
+
else:
|
| 165 |
+
return generate_fp8_math_instruction_shapes_sm90(wgmma_level)
|
| 166 |
+
|
| 167 |
+
###########
|
| 168 |
+
|
| 169 |
+
def generate_tf32_math_instructions_sm90(level: int):
|
| 170 |
+
wgmma_level = get_wgmma_level_from_global_level(level)
|
| 171 |
+
math_instructions = []
|
| 172 |
+
for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level):
|
| 173 |
+
math_instructions.append(
|
| 174 |
+
MathInstruction(
|
| 175 |
+
math_instruction_shape,
|
| 176 |
+
DataType.tf32, DataType.tf32, DataType.f32,
|
| 177 |
+
OpcodeClass.TensorOp,
|
| 178 |
+
MathOperation.multiply_add)
|
| 179 |
+
)
|
| 180 |
+
return math_instructions
|
| 181 |
+
|
| 182 |
+
def generate_fp16_bf16_math_instructions_sm90(level: int):
|
| 183 |
+
wgmma_level = get_wgmma_level_from_global_level(level)
|
| 184 |
+
math_instructions = []
|
| 185 |
+
for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level):
|
| 186 |
+
math_instructions += [
|
| 187 |
+
MathInstruction(
|
| 188 |
+
math_instruction_shape,
|
| 189 |
+
DataType.f16, DataType.f16, DataType.f16,
|
| 190 |
+
OpcodeClass.TensorOp,
|
| 191 |
+
MathOperation.multiply_add),
|
| 192 |
+
MathInstruction(
|
| 193 |
+
math_instruction_shape,
|
| 194 |
+
DataType.f16, DataType.f16, DataType.f32,
|
| 195 |
+
OpcodeClass.TensorOp,
|
| 196 |
+
MathOperation.multiply_add),
|
| 197 |
+
MathInstruction(
|
| 198 |
+
math_instruction_shape,
|
| 199 |
+
DataType.bf16, DataType.bf16, DataType.f32,
|
| 200 |
+
OpcodeClass.TensorOp,
|
| 201 |
+
MathOperation.multiply_add),
|
| 202 |
+
]
|
| 203 |
+
return math_instructions
|
| 204 |
+
|
| 205 |
+
def generate_fp8_math_instructions_sm90(level: int):
|
| 206 |
+
wgmma_level = get_wgmma_level_from_global_level(level)
|
| 207 |
+
math_instructions = []
|
| 208 |
+
for math_instruction_shape in generate_fp8_math_instruction_shapes_sm90(wgmma_level):
|
| 209 |
+
math_instructions += [
|
| 210 |
+
MathInstruction(
|
| 211 |
+
math_instruction_shape,
|
| 212 |
+
DataType.e4m3, DataType.e4m3, DataType.f32,
|
| 213 |
+
OpcodeClass.TensorOp,
|
| 214 |
+
MathOperation.multiply_add),
|
| 215 |
+
MathInstruction(
|
| 216 |
+
math_instruction_shape,
|
| 217 |
+
DataType.e4m3, DataType.e5m2, DataType.f32,
|
| 218 |
+
OpcodeClass.TensorOp,
|
| 219 |
+
MathOperation.multiply_add),
|
| 220 |
+
MathInstruction(
|
| 221 |
+
math_instruction_shape,
|
| 222 |
+
DataType.e5m2, DataType.e4m3, DataType.f32,
|
| 223 |
+
OpcodeClass.TensorOp,
|
| 224 |
+
MathOperation.multiply_add),
|
| 225 |
+
MathInstruction(
|
| 226 |
+
math_instruction_shape,
|
| 227 |
+
DataType.e5m2, DataType.e5m2, DataType.f32,
|
| 228 |
+
OpcodeClass.TensorOp,
|
| 229 |
+
MathOperation.multiply_add),
|
| 230 |
+
]
|
| 231 |
+
return math_instructions
|
| 232 |
+
|
| 233 |
+
def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]):
|
| 234 |
+
wgmma_level = get_wgmma_level_from_global_level(level)
|
| 235 |
+
math_instructions = []
|
| 236 |
+
for a_type, b_type, acc_type in types_of_a_b_acc:
|
| 237 |
+
math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type)
|
| 238 |
+
for math_instruction_shape in math_instruction_shapes:
|
| 239 |
+
math_instructions += [
|
| 240 |
+
MathInstruction(
|
| 241 |
+
math_instruction_shape,
|
| 242 |
+
a_type, b_type, acc_type,
|
| 243 |
+
OpcodeClass.TensorOp,
|
| 244 |
+
MathOperation.multiply_add
|
| 245 |
+
),
|
| 246 |
+
]
|
| 247 |
+
return math_instructions
|
| 248 |
+
|
| 249 |
+
def generate_int8_math_instructions_sm90(level: int):
|
| 250 |
+
wgmma_level = get_wgmma_level_from_global_level(level)
|
| 251 |
+
math_instructions = []
|
| 252 |
+
for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level):
|
| 253 |
+
math_instructions += [
|
| 254 |
+
MathInstruction(
|
| 255 |
+
math_instruction_shape,
|
| 256 |
+
DataType.s8, DataType.s8, DataType.s32,
|
| 257 |
+
OpcodeClass.TensorOp,
|
| 258 |
+
MathOperation.multiply_add),
|
| 259 |
+
MathInstruction(
|
| 260 |
+
math_instruction_shape,
|
| 261 |
+
DataType.u8, DataType.u8, DataType.s32,
|
| 262 |
+
OpcodeClass.TensorOp,
|
| 263 |
+
MathOperation.multiply_add),
|
| 264 |
+
]
|
| 265 |
+
return math_instructions
|
| 266 |
+
|
| 267 |
+
def make_sparse_math_instructions(math_instructions):
|
| 268 |
+
sparse_instructions = []
|
| 269 |
+
for inst in math_instructions:
|
| 270 |
+
if inst.opcode_class == OpcodeClass.TensorOp:
|
| 271 |
+
sparse_instructions.append(MathInstruction(
|
| 272 |
+
(inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2),
|
| 273 |
+
inst.element_a, inst.element_b, inst.element_accumulator,
|
| 274 |
+
OpcodeClass.SparseTensorOp,
|
| 275 |
+
inst.math_operation),)
|
| 276 |
+
return sparse_instructions
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
#### Step 2: generate tile descriptions from math instruction shapes
|
| 280 |
+
|
| 281 |
+
def is_tile_desc_valid(tile_description):
|
| 282 |
+
if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90:
|
| 283 |
+
return False
|
| 284 |
+
|
| 285 |
+
element_a, element_b, element_accum = (
|
| 286 |
+
tile_description.math_instruction.element_a,
|
| 287 |
+
tile_description.math_instruction.element_b,
|
| 288 |
+
tile_description.math_instruction.element_accumulator
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
cluster_size, cta_shape = (
|
| 292 |
+
tile_description.cluster_shape,
|
| 293 |
+
tile_description.threadblock_shape,
|
| 294 |
+
)
|
| 295 |
+
grid_size = (
|
| 296 |
+
cta_shape[0] * cluster_size[0] +
|
| 297 |
+
cta_shape[1] * cluster_size[1] +
|
| 298 |
+
cta_shape[2] * cluster_size[2]
|
| 299 |
+
)
|
| 300 |
+
num_ctas_in_cluster = cluster_size[0] * cluster_size[1] * cluster_size[2]
|
| 301 |
+
cluster_shape = (
|
| 302 |
+
cluster_size[0] * cta_shape[0],
|
| 303 |
+
cluster_size[1] * cta_shape[1],
|
| 304 |
+
cluster_size[2] * cta_shape[2]
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
FP32_TYPES = [DataType.f32, DataType.tf32]
|
| 308 |
+
FP16_TYPES = [DataType.f16, DataType.bf16]
|
| 309 |
+
is_fp32 = element_a in FP32_TYPES and element_b in FP32_TYPES
|
| 310 |
+
is_fp16 = element_a in FP16_TYPES and element_b in FP16_TYPES
|
| 311 |
+
|
| 312 |
+
# Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is
|
| 313 |
+
# allowed for non portable clusters.
|
| 314 |
+
if num_ctas_in_cluster > 16 or num_ctas_in_cluster < 1:
|
| 315 |
+
return False
|
| 316 |
+
|
| 317 |
+
if grid_size < 1:
|
| 318 |
+
return False
|
| 319 |
+
|
| 320 |
+
# SM90 WGMMA shapes are always 64 across M, therefore
|
| 321 |
+
# CTA shape across M must always be a multiple of 64.
|
| 322 |
+
if cta_shape[0] < 64 or cta_shape[0] % 64 != 0:
|
| 323 |
+
return False
|
| 324 |
+
|
| 325 |
+
# The minimum WGMMA shape across N is 8, and increments
|
| 326 |
+
# vary across different dtypes, but they're never smaller
|
| 327 |
+
# than 8. The minimum CTA shape allowed across N though is 16.
|
| 328 |
+
if cta_shape[1] < 16 or cta_shape[1] % 8 != 0:
|
| 329 |
+
return False
|
| 330 |
+
|
| 331 |
+
# SM90 WGMMA shapes across K are always 8 for 32 bit dense
|
| 332 |
+
# operations, 16 for 16 bit, and 32 for 8 bit. In any case,
|
| 333 |
+
# the CTA shape across K should be a multiple of 8 and at least
|
| 334 |
+
# twice the WGMMA shape across K.
|
| 335 |
+
if cta_shape[2] < 16 or cta_shape[2] % 8 != 0:
|
| 336 |
+
return False
|
| 337 |
+
|
| 338 |
+
# Minimum of 2 stages (very rough heuristic that may filter out valid kernel configs)
|
| 339 |
+
if (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 256:
|
| 340 |
+
return False
|
| 341 |
+
|
| 342 |
+
if is_fp32 and (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 128:
|
| 343 |
+
return False
|
| 344 |
+
|
| 345 |
+
if is_fp32 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 64:
|
| 346 |
+
return False
|
| 347 |
+
|
| 348 |
+
if is_fp16 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 128:
|
| 349 |
+
return False
|
| 350 |
+
|
| 351 |
+
# CTA shape upper bound: <256, 256, 256>
|
| 352 |
+
if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256:
|
| 353 |
+
return False
|
| 354 |
+
|
| 355 |
+
return True
|
| 356 |
+
|
| 357 |
+
def get_mma_multipliers(level: int):
|
| 358 |
+
assert isinstance(level, int) and level >= 0
|
| 359 |
+
mma_level = get_mma_level_from_global_level(level)
|
| 360 |
+
return [
|
| 361 |
+
mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level
|
| 362 |
+
]
|
| 363 |
+
|
| 364 |
+
def get_cluster_sizes(level: int, is_aligned: bool):
|
| 365 |
+
if not is_aligned:
|
| 366 |
+
return [(1, 1, 1)]
|
| 367 |
+
assert isinstance(level, int) and level >= 0
|
| 368 |
+
cluster_level = get_cluster_level_from_global_level(level)
|
| 369 |
+
return [
|
| 370 |
+
cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level
|
| 371 |
+
]
|
| 372 |
+
|
| 373 |
+
def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int):
|
| 374 |
+
tile_descriptions = set()
|
| 375 |
+
mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned)
|
| 376 |
+
for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes):
|
| 377 |
+
|
| 378 |
+
# generator can stamp out duplicate kernels, because it doesn't explicitly set instruction
|
| 379 |
+
# shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using
|
| 380 |
+
# the auto kernel schedule.
|
| 381 |
+
|
| 382 |
+
math_inst_stub = copy.deepcopy(math_inst)
|
| 383 |
+
math_inst_stub.instruction_shape = [0, 0, 0]
|
| 384 |
+
|
| 385 |
+
tile_desc = TileDescription(
|
| 386 |
+
threadblock_shape=[
|
| 387 |
+
math_inst.instruction_shape[0] * mma_mul[0],
|
| 388 |
+
math_inst.instruction_shape[1] * mma_mul[1],
|
| 389 |
+
math_inst.instruction_shape[2] * mma_mul[2]
|
| 390 |
+
],
|
| 391 |
+
stages=0,
|
| 392 |
+
warp_count=[4, 1, 1],
|
| 393 |
+
math_instruction=math_inst_stub,
|
| 394 |
+
min_compute=90,
|
| 395 |
+
max_compute=90,
|
| 396 |
+
cluster_shape=cluster_size)
|
| 397 |
+
# For sparse kernels K-tile is twice as large (due to 2x MMA-K size)
|
| 398 |
+
# Reduce it to same size as dense to afford more smem stages
|
| 399 |
+
if math_inst.opcode_class == OpcodeClass.SparseTensorOp:
|
| 400 |
+
tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2
|
| 401 |
+
if is_tile_desc_valid(tile_desc):
|
| 402 |
+
tile_descriptions.add(tile_desc)
|
| 403 |
+
|
| 404 |
+
return tile_descriptions
|
| 405 |
+
|
| 406 |
+
#### Step 3: map tile description to valid schedules
|
| 407 |
+
|
| 408 |
+
def is_tile_desc_compatible_with_cooperative(tile_description):
|
| 409 |
+
# Cooperative kernels require a minimum CTA-M of 128
|
| 410 |
+
return tile_description.threadblock_shape[0] % 128 == 0
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):
|
| 414 |
+
dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = (
|
| 415 |
+
data_types["a_type"],
|
| 416 |
+
data_types["b_type"],
|
| 417 |
+
data_types["c_type"],
|
| 418 |
+
data_types["d_type"],
|
| 419 |
+
data_types["acc_type"],
|
| 420 |
+
data_types["epi_type"]
|
| 421 |
+
)
|
| 422 |
+
mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1]
|
| 423 |
+
bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d]
|
| 424 |
+
|
| 425 |
+
shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn
|
| 426 |
+
shmem_bits_total = shmem_bits_c + shmem_bits_d
|
| 427 |
+
# Magic number: 2^20
|
| 428 |
+
# Existing logic suggested that tile shape 256x128 (or 128x256)
|
| 429 |
+
# would run out of shmem if D is FP32, and source is needed.
|
| 430 |
+
# That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit.
|
| 431 |
+
# Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB.
|
| 432 |
+
# Since epilogue can't possibly use ALL of the shmem available
|
| 433 |
+
# we can just settle on 2^20 bits (~ 131 KB) being the upper bound
|
| 434 |
+
# we would allow for epilogue.
|
| 435 |
+
# This can be different for non-persistent kernels where epilogue and
|
| 436 |
+
# mainloop shmem is shared.
|
| 437 |
+
if shmem_bits_total > 2 ** 20:
|
| 438 |
+
return False
|
| 439 |
+
|
| 440 |
+
return True
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout,
|
| 444 |
+
instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x):
|
| 445 |
+
# Level 0: prune according to existing generator.py behavior
|
| 446 |
+
# Level >= 1: no pruning
|
| 447 |
+
level = get_pruning_level_from_global_level(instantiation_level)
|
| 448 |
+
schedules = []
|
| 449 |
+
stream_k_schedules = []
|
| 450 |
+
|
| 451 |
+
if not is_tile_desc_valid(tile_description):
|
| 452 |
+
return schedules, stream_k_schedules
|
| 453 |
+
|
| 454 |
+
FP16_TYPES = [DataType.f16, DataType.bf16]
|
| 455 |
+
is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES
|
| 456 |
+
|
| 457 |
+
FP8_TYPES = [DataType.e4m3, DataType.e5m2]
|
| 458 |
+
is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES
|
| 459 |
+
can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc
|
| 460 |
+
|
| 461 |
+
FP32_TYPES = [DataType.f32, DataType.tf32]
|
| 462 |
+
is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES
|
| 463 |
+
requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor
|
| 464 |
+
|
| 465 |
+
can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description)
|
| 466 |
+
can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types)
|
| 467 |
+
|
| 468 |
+
default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed
|
| 469 |
+
auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed
|
| 470 |
+
|
| 471 |
+
cta_m, cta_n, cta_k = (
|
| 472 |
+
tile_description.threadblock_shape[0],
|
| 473 |
+
tile_description.threadblock_shape[1],
|
| 474 |
+
tile_description.threadblock_shape[2]
|
| 475 |
+
)
|
| 476 |
+
c_type = data_types["c_type"]
|
| 477 |
+
d_type = data_types["d_type"]
|
| 478 |
+
is_void_c = c_type == DataType.void
|
| 479 |
+
|
| 480 |
+
# Filter out invalid kernels
|
| 481 |
+
is_nt = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.RowMajor
|
| 482 |
+
is_tn = layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.ColumnMajor
|
| 483 |
+
is_nn = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.ColumnMajor
|
| 484 |
+
|
| 485 |
+
# static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0,
|
| 486 |
+
# "Copy size must evenly divide SMEM tile.");
|
| 487 |
+
if is_fp32 and is_nt and (cta_n % cta_k != 0):
|
| 488 |
+
return [], []
|
| 489 |
+
|
| 490 |
+
# static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits<InternalElementB>::value))) == 128,
|
| 491 |
+
# "SmemLayoutB K must be 128bytes to be transposed.")
|
| 492 |
+
if is_fp32 and is_nt and cta_k != 32:
|
| 493 |
+
return [], []
|
| 494 |
+
|
| 495 |
+
# Static assert failure when instantiating SmemLayoutB
|
| 496 |
+
if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0):
|
| 497 |
+
return [], []
|
| 498 |
+
|
| 499 |
+
grouped = is_grouped(gemm_kind)
|
| 500 |
+
if grouped:
|
| 501 |
+
# the following cases are unsupported by grouped GEMM
|
| 502 |
+
if not is_aligned:
|
| 503 |
+
return [], []
|
| 504 |
+
if requires_transposed_epilogue:
|
| 505 |
+
return [], []
|
| 506 |
+
|
| 507 |
+
# Early pruning
|
| 508 |
+
if level < 1:
|
| 509 |
+
# Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64
|
| 510 |
+
if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64:
|
| 511 |
+
return [], []
|
| 512 |
+
|
| 513 |
+
# FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules
|
| 514 |
+
is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128
|
| 515 |
+
if is_large_fp8_tile:
|
| 516 |
+
# Only void-C, and only FP8 outputs allowed
|
| 517 |
+
if not is_void_c or d_type not in FP8_TYPES:
|
| 518 |
+
return [], []
|
| 519 |
+
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
|
| 520 |
+
schedules = []
|
| 521 |
+
if is_blockwise(gemm_kind):
|
| 522 |
+
schedules.append(
|
| 523 |
+
[
|
| 524 |
+
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
| 525 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
| 526 |
+
])
|
| 527 |
+
else:
|
| 528 |
+
schedules.append(
|
| 529 |
+
[
|
| 530 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
| 531 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
| 532 |
+
])
|
| 533 |
+
schedules.append(
|
| 534 |
+
[
|
| 535 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
| 536 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
| 537 |
+
])
|
| 538 |
+
return schedules, []
|
| 539 |
+
return [], []
|
| 540 |
+
|
| 541 |
+
if is_fp8 and not is_large_fp8_tile:
|
| 542 |
+
valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void]
|
| 543 |
+
# Prune all configs with fp8 source, and all configs with non-fp8 output
|
| 544 |
+
# that have different dtypes for source and output.
|
| 545 |
+
if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type):
|
| 546 |
+
return [], []
|
| 547 |
+
|
| 548 |
+
# FP32/TF32 kernels don't stamp out void-C
|
| 549 |
+
if is_fp32 and is_void_c:
|
| 550 |
+
return [], []
|
| 551 |
+
|
| 552 |
+
# Void-c only makes a difference for TMA epilogues
|
| 553 |
+
if is_void_c and not can_do_tma_epilogue:
|
| 554 |
+
return [], []
|
| 555 |
+
|
| 556 |
+
# For mixed input data types
|
| 557 |
+
a_type_size = DataTypeSize[data_types["a_type"]]
|
| 558 |
+
b_type_size = DataTypeSize[data_types["b_type"]]
|
| 559 |
+
if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
| 560 |
+
schedules = []
|
| 561 |
+
stream_k_schedules = []
|
| 562 |
+
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
|
| 563 |
+
if a_type_size > b_type_size:
|
| 564 |
+
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
| 565 |
+
|
| 566 |
+
if not is_blockwise(gemm_kind):
|
| 567 |
+
schedules.append([
|
| 568 |
+
KernelScheduleType.TmaWarpSpecialized,
|
| 569 |
+
epilogue_schedule
|
| 570 |
+
])
|
| 571 |
+
schedules.append([
|
| 572 |
+
KernelScheduleType.TmaWarpSpecializedPingpong,
|
| 573 |
+
epilogue_schedule
|
| 574 |
+
])
|
| 575 |
+
if cta_m >= 128:
|
| 576 |
+
if a_type_size > b_type_size:
|
| 577 |
+
epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
|
| 578 |
+
else:
|
| 579 |
+
epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
| 580 |
+
if is_blockwise(gemm_kind):
|
| 581 |
+
schedules.append([
|
| 582 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
| 583 |
+
epilogue_schedule
|
| 584 |
+
])
|
| 585 |
+
else:
|
| 586 |
+
schedules.append([
|
| 587 |
+
KernelScheduleType.TmaWarpSpecializedCooperative,
|
| 588 |
+
epilogue_schedule
|
| 589 |
+
])
|
| 590 |
+
stream_k_schedules.append([
|
| 591 |
+
KernelScheduleType.TmaWarpSpecializedCooperative,
|
| 592 |
+
epilogue_schedule
|
| 593 |
+
])
|
| 594 |
+
return schedules, stream_k_schedules
|
| 595 |
+
|
| 596 |
+
if not is_aligned and not is_blockwise(gemm_kind):
|
| 597 |
+
schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
|
| 598 |
+
default_epilogue]]
|
| 599 |
+
stream_k_schedules = []
|
| 600 |
+
|
| 601 |
+
if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative:
|
| 602 |
+
schedules.append([
|
| 603 |
+
KernelScheduleType.CpAsyncWarpSpecializedCooperative,
|
| 604 |
+
default_epilogue
|
| 605 |
+
])
|
| 606 |
+
stream_k_schedules.append([
|
| 607 |
+
KernelScheduleType.CpAsyncWarpSpecializedCooperative,
|
| 608 |
+
default_epilogue
|
| 609 |
+
])
|
| 610 |
+
|
| 611 |
+
return schedules, stream_k_schedules
|
| 612 |
+
|
| 613 |
+
schedules = []
|
| 614 |
+
# Pruning: emit Void-C and Grouped kernels with persistent kernels only
|
| 615 |
+
if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind):
|
| 616 |
+
# Pruning: don't stamp out fp8 kernels with auto schedule
|
| 617 |
+
if not is_fp8:
|
| 618 |
+
schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
|
| 619 |
+
schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue])
|
| 620 |
+
stream_k_schedules = []
|
| 621 |
+
|
| 622 |
+
if CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
| 623 |
+
if can_do_tma_epilogue:
|
| 624 |
+
assert not requires_transposed_epilogue
|
| 625 |
+
# Inconsistency: fp8 pingpong only gets stamped out with fast accum
|
| 626 |
+
if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind):
|
| 627 |
+
schedules.append([
|
| 628 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
|
| 629 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
|
| 630 |
+
])
|
| 631 |
+
if can_do_fp8_fast_accum:
|
| 632 |
+
schedules.append([
|
| 633 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
|
| 634 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
|
| 635 |
+
])
|
| 636 |
+
|
| 637 |
+
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
| 638 |
+
# Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
|
| 639 |
+
if not is_fp8 or level >= 1:
|
| 640 |
+
if not is_blockwise(gemm_kind):
|
| 641 |
+
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
| 642 |
+
else:
|
| 643 |
+
schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
| 644 |
+
|
| 645 |
+
if can_do_fp8_fast_accum:
|
| 646 |
+
if not grouped:
|
| 647 |
+
schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
|
| 648 |
+
schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
|
| 649 |
+
|
| 650 |
+
if can_do_cooperative:
|
| 651 |
+
if is_blockwise(gemm_kind):
|
| 652 |
+
schedules.append([
|
| 653 |
+
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
| 654 |
+
to_grouped_schedule(default_epilogue, grouped)
|
| 655 |
+
])
|
| 656 |
+
stream_k_schedules.append([
|
| 657 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
| 658 |
+
default_epilogue
|
| 659 |
+
])
|
| 660 |
+
else:
|
| 661 |
+
schedules.append([
|
| 662 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
| 663 |
+
to_grouped_schedule(default_epilogue, grouped)
|
| 664 |
+
])
|
| 665 |
+
stream_k_schedules.append([
|
| 666 |
+
KernelScheduleType.TmaWarpSpecializedCooperative,
|
| 667 |
+
default_epilogue
|
| 668 |
+
])
|
| 669 |
+
if can_do_fp8_fast_accum:
|
| 670 |
+
schedules.append([
|
| 671 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
| 672 |
+
to_grouped_schedule(default_epilogue, grouped)
|
| 673 |
+
])
|
| 674 |
+
stream_k_schedules.append([
|
| 675 |
+
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
|
| 676 |
+
default_epilogue
|
| 677 |
+
])
|
| 678 |
+
|
| 679 |
+
# persistent kernels with TMA epilogues
|
| 680 |
+
if can_do_tma_epilogue:
|
| 681 |
+
assert not requires_transposed_epilogue
|
| 682 |
+
if can_do_cooperative:
|
| 683 |
+
if is_blockwise(gemm_kind):
|
| 684 |
+
schedules.append([
|
| 685 |
+
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
|
| 686 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
| 687 |
+
])
|
| 688 |
+
stream_k_schedules.append([
|
| 689 |
+
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
|
| 690 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
| 691 |
+
])
|
| 692 |
+
else:
|
| 693 |
+
schedules.append([
|
| 694 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
|
| 695 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
| 696 |
+
])
|
| 697 |
+
stream_k_schedules.append([
|
| 698 |
+
KernelScheduleType.TmaWarpSpecializedCooperative,
|
| 699 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
| 700 |
+
])
|
| 701 |
+
if can_do_fp8_fast_accum:
|
| 702 |
+
schedules.append([
|
| 703 |
+
to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
|
| 704 |
+
to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
|
| 705 |
+
])
|
| 706 |
+
stream_k_schedules.append([
|
| 707 |
+
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
|
| 708 |
+
EpilogueScheduleType.TmaWarpSpecializedCooperative
|
| 709 |
+
])
|
| 710 |
+
# Grouped GEMM do not support Stream-K scheduler
|
| 711 |
+
if grouped:
|
| 712 |
+
return schedules, []
|
| 713 |
+
return schedules, stream_k_schedules
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
#### Misc: helpers
|
| 717 |
+
|
| 718 |
+
def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None):
|
| 719 |
+
element_a, element_b = math_instruction.element_a, math_instruction.element_b
|
| 720 |
+
element_accumulator = math_instruction.element_accumulator
|
| 721 |
+
element_c = element_source or element_accumulator
|
| 722 |
+
element_d = element_dest or element_accumulator
|
| 723 |
+
element_epilogue = element_epilogue or element_accumulator
|
| 724 |
+
data_types = {
|
| 725 |
+
"a_type" : element_a,
|
| 726 |
+
"b_type" : element_b,
|
| 727 |
+
"c_type" : element_c,
|
| 728 |
+
"d_type" : element_d,
|
| 729 |
+
"acc_type" : element_accumulator,
|
| 730 |
+
"epi_type" : element_epilogue
|
| 731 |
+
}
|
| 732 |
+
return data_types
|
| 733 |
+
|
| 734 |
+
def fix_alignments(data_types, layout, alignment_bits = 128):
|
| 735 |
+
operand_keys = ["a_type", "b_type", "c_type"]
|
| 736 |
+
operands_to_fix = ["c_type"]
|
| 737 |
+
new_layout = []
|
| 738 |
+
assert len(layout) == len(operand_keys)
|
| 739 |
+
for i, k in enumerate(operand_keys):
|
| 740 |
+
assert k in data_types and data_types[k] in DataTypeSize
|
| 741 |
+
dtype = data_types[k]
|
| 742 |
+
dtype_size_bits = DataTypeSize[dtype]
|
| 743 |
+
|
| 744 |
+
layout_type = layout[i][0]
|
| 745 |
+
layout_alignment = layout[i][1]
|
| 746 |
+
|
| 747 |
+
# Don't modify alignment if dtype's been changed to void
|
| 748 |
+
if k in operands_to_fix and dtype_size_bits >= 1:
|
| 749 |
+
layout_alignment = alignment_bits // dtype_size_bits
|
| 750 |
+
|
| 751 |
+
new_layout.append([layout_type, layout_alignment])
|
| 752 |
+
|
| 753 |
+
return new_layout
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting Symm kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import functools
|
| 39 |
+
import operator
|
| 40 |
+
import os.path
|
| 41 |
+
import shutil
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
except ImportError:
|
| 49 |
+
from library import *
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
###################################################################################################
|
| 53 |
+
#
|
| 54 |
+
# Data structure modeling a Symm update operation
|
| 55 |
+
#
|
| 56 |
+
###################################################################################################
|
| 57 |
+
|
| 58 |
+
#
|
| 59 |
+
class SymmOperation:
|
| 60 |
+
#
|
| 61 |
+
def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
| 62 |
+
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
|
| 63 |
+
blas_mode = BlasMode.symmetric):
|
| 64 |
+
|
| 65 |
+
self.blas_mode = blas_mode
|
| 66 |
+
self.operation_kind = OperationKind.Symm
|
| 67 |
+
self.arch = arch
|
| 68 |
+
self.tile_description = tile_description
|
| 69 |
+
self.symm_kind = symm_kind
|
| 70 |
+
# tensor A and B have same data type and layout
|
| 71 |
+
self.A = A
|
| 72 |
+
self.B = B
|
| 73 |
+
self.C = C
|
| 74 |
+
self.element_epilogue = element_epilogue
|
| 75 |
+
self.epilogue_functor = epilogue_functor
|
| 76 |
+
self.swizzling_functor = swizzling_functor
|
| 77 |
+
|
| 78 |
+
#
|
| 79 |
+
def is_complex(self):
|
| 80 |
+
complex_operators = [
|
| 81 |
+
MathOperation.multiply_add_complex,
|
| 82 |
+
MathOperation.multiply_add_complex_gaussian,
|
| 83 |
+
MathOperation.multiply_add_complex_fast_f32
|
| 84 |
+
]
|
| 85 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 86 |
+
return False
|
| 87 |
+
|
| 88 |
+
#
|
| 89 |
+
def is_mixed_input(self):
|
| 90 |
+
return self.A.element != self.B.element
|
| 91 |
+
|
| 92 |
+
#
|
| 93 |
+
def is_planar_complex(self):
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
#
|
| 97 |
+
def accumulator_type(self):
|
| 98 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 99 |
+
|
| 100 |
+
if self.is_complex():
|
| 101 |
+
return get_complex_from_real(accum)
|
| 102 |
+
|
| 103 |
+
return accum
|
| 104 |
+
|
| 105 |
+
#
|
| 106 |
+
def short_math_name(self):
|
| 107 |
+
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 108 |
+
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 109 |
+
return ShortDataTypeNames[self.accumulator_type()]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
#
|
| 113 |
+
def core_name(self):
|
| 114 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 115 |
+
|
| 116 |
+
inst_shape = ''
|
| 117 |
+
inst_operation = ''
|
| 118 |
+
intermediate_type = ''
|
| 119 |
+
|
| 120 |
+
math_operations_map = {
|
| 121 |
+
MathOperation.xor_popc: 'xor',
|
| 122 |
+
MathOperation.and_popc: 'and'
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
|
| 126 |
+
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
|
| 127 |
+
|
| 128 |
+
math_op = self.tile_description.math_instruction.math_operation
|
| 129 |
+
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
| 130 |
+
|
| 131 |
+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 132 |
+
inst_shape += math_op_string
|
| 133 |
+
|
| 134 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 135 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 136 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 137 |
+
|
| 138 |
+
operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm'
|
| 139 |
+
|
| 140 |
+
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
|
| 141 |
+
|
| 142 |
+
#
|
| 143 |
+
def extended_name(self):
|
| 144 |
+
''' Append data types if they differ from compute type. '''
|
| 145 |
+
if self.is_complex():
|
| 146 |
+
extended_name = "${core_name}"
|
| 147 |
+
else:
|
| 148 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 149 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 150 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 151 |
+
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 152 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 153 |
+
extended_name = "${core_name}_${element_a}"
|
| 154 |
+
else:
|
| 155 |
+
extended_name = "${core_name}"
|
| 156 |
+
|
| 157 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 158 |
+
'element_a': DataTypeNames[self.A.element],
|
| 159 |
+
'element_c': DataTypeNames[self.C.element],
|
| 160 |
+
'core_name': self.core_name()
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
return extended_name
|
| 164 |
+
|
| 165 |
+
#
|
| 166 |
+
def layout_name(self):
|
| 167 |
+
if self.is_complex() or self.is_planar_complex():
|
| 168 |
+
return "%s" % (
|
| 169 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
|
| 170 |
+
)
|
| 171 |
+
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
| 172 |
+
|
| 173 |
+
#
|
| 174 |
+
def side_mode_name(self):
|
| 175 |
+
return "%s" % (ShortSideModeNames[self.A.side_mode])
|
| 176 |
+
|
| 177 |
+
#
|
| 178 |
+
def fill_mode_name(self):
|
| 179 |
+
return "%s" % (ShortFillModeNames[self.A.fill_mode])
|
| 180 |
+
|
| 181 |
+
#
|
| 182 |
+
def procedural_name(self):
|
| 183 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 184 |
+
threadblock = self.tile_description.procedural_name()
|
| 185 |
+
|
| 186 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 187 |
+
|
| 188 |
+
alignment = self.C.alignment
|
| 189 |
+
|
| 190 |
+
return SubstituteTemplate(
|
| 191 |
+
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}",
|
| 192 |
+
{
|
| 193 |
+
'opcode_class': opcode_class_name,
|
| 194 |
+
'extended_name': self.extended_name(),
|
| 195 |
+
'threadblock': threadblock,
|
| 196 |
+
'layout': self.layout_name(),
|
| 197 |
+
'side_mode': self.side_mode_name(),
|
| 198 |
+
'fill_mode': self.fill_mode_name(),
|
| 199 |
+
'alignment': "%d" % alignment,
|
| 200 |
+
}
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
#
|
| 204 |
+
def configuration_name(self):
|
| 205 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 206 |
+
return self.procedural_name()
|
| 207 |
+
|
| 208 |
+
###################################################################################################
|
| 209 |
+
#
|
| 210 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 211 |
+
#
|
| 212 |
+
###################################################################################################
|
| 213 |
+
|
| 214 |
+
#
|
| 215 |
+
class EmitSymmUniversalInstance:
|
| 216 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 217 |
+
|
| 218 |
+
def __init__(self):
|
| 219 |
+
self.symm_template = """
|
| 220 |
+
// Symm operator ${operation_name}
|
| 221 |
+
using Operation_${operation_name} =
|
| 222 |
+
typename cutlass::gemm::device::Symm<
|
| 223 |
+
${element_a}, ${layout_a}, ${side_mode}, ${fill_mode},
|
| 224 |
+
${element_b}, ${layout_b},
|
| 225 |
+
${element_c}, ${layout_c},
|
| 226 |
+
${element_accumulator},
|
| 227 |
+
${opcode_class},
|
| 228 |
+
${arch},
|
| 229 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 230 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 231 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 232 |
+
${epilogue_functor}<
|
| 233 |
+
${element_c},
|
| 234 |
+
${epilogue_vector_length},
|
| 235 |
+
${element_accumulator},
|
| 236 |
+
${element_epilogue}
|
| 237 |
+
>,
|
| 238 |
+
${swizzling_functor},
|
| 239 |
+
${stages},
|
| 240 |
+
${align_a},
|
| 241 |
+
${align_b},
|
| 242 |
+
${split_k_serial},
|
| 243 |
+
${math_operation}
|
| 244 |
+
>;
|
| 245 |
+
"""
|
| 246 |
+
self.symm_complex_template = """
|
| 247 |
+
// Symm operator ${operation_name}
|
| 248 |
+
using Operation_${operation_name} =
|
| 249 |
+
typename cutlass::gemm::device::Symm<
|
| 250 |
+
${element_a}, ${layout_a}, ${side_mode}, ${fill_mode},
|
| 251 |
+
${element_b}, ${layout_b},
|
| 252 |
+
${element_c}, ${layout_c},
|
| 253 |
+
${element_accumulator},
|
| 254 |
+
${opcode_class},
|
| 255 |
+
${arch},
|
| 256 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 257 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 258 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 259 |
+
${epilogue_functor}<
|
| 260 |
+
${element_c},
|
| 261 |
+
${epilogue_vector_length},
|
| 262 |
+
${element_accumulator},
|
| 263 |
+
${element_epilogue}
|
| 264 |
+
>,
|
| 265 |
+
${swizzling_functor},
|
| 266 |
+
${stages},
|
| 267 |
+
${align_a},
|
| 268 |
+
${align_b},
|
| 269 |
+
${split_k_serial},
|
| 270 |
+
${math_operation},
|
| 271 |
+
${blas_mode}
|
| 272 |
+
>;
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def emit(self, operation):
|
| 276 |
+
|
| 277 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 278 |
+
|
| 279 |
+
warp_count = operation.tile_description.warp_count
|
| 280 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 281 |
+
|
| 282 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 283 |
+
|
| 284 |
+
values = {
|
| 285 |
+
'operation_name': operation.procedural_name(),
|
| 286 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 287 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 288 |
+
'side_mode': SideModeTag[operation.A.side_mode],
|
| 289 |
+
'fill_mode': FillModeTag[operation.A.fill_mode],
|
| 290 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 291 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 292 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 293 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 294 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 295 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 296 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 297 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 298 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 299 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 300 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 301 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 302 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 303 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 304 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 305 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 306 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 307 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 308 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 309 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 310 |
+
'stages': str(operation.tile_description.stages),
|
| 311 |
+
'align_a': str(operation.A.alignment),
|
| 312 |
+
'align_b': str(operation.B.alignment),
|
| 313 |
+
'split_k_serial': 'false',
|
| 314 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 315 |
+
'blas_mode': BlasModeTag[operation.blas_mode]
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
symm_template = self.symm_complex_template if operation.is_complex() else self.symm_template
|
| 319 |
+
|
| 320 |
+
return SubstituteTemplate(symm_template, values)
|
| 321 |
+
|
| 322 |
+
###################################################################################################
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
###################################################################################################
|
| 326 |
+
#
|
| 327 |
+
# Emitters functions for all targets
|
| 328 |
+
#
|
| 329 |
+
###################################################################################################
|
| 330 |
+
|
| 331 |
+
class EmitSymmConfigurationLibrary:
|
| 332 |
+
def __init__(self, operation_path, configuration_name):
|
| 333 |
+
self.configuration_name = configuration_name
|
| 334 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
|
| 335 |
+
|
| 336 |
+
self.instance_emitter = {
|
| 337 |
+
SymmKind.Universal: EmitSymmUniversalInstance,
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
self.symm_kind_wrappers = {
|
| 341 |
+
SymmKind.Universal: 'SymmOperation',
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
self.instance_template = {
|
| 345 |
+
SymmKind.Universal: """
|
| 346 |
+
${compile_guard_start}
|
| 347 |
+
manifest.append(new ${symm_kind}<
|
| 348 |
+
Operation_${operation_name}
|
| 349 |
+
>("${operation_name}"));
|
| 350 |
+
${compile_guard_end}
|
| 351 |
+
"""
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
self.header_template = """
|
| 355 |
+
/*
|
| 356 |
+
Generated by symm_operation.py - Do not edit.
|
| 357 |
+
*/
|
| 358 |
+
|
| 359 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 360 |
+
#include "cutlass/cutlass.h"
|
| 361 |
+
#include "cutlass/library/library.h"
|
| 362 |
+
#include "cutlass/library/manifest.h"
|
| 363 |
+
|
| 364 |
+
#include "library_internal.h"
|
| 365 |
+
#include "symm_operation.h"
|
| 366 |
+
|
| 367 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 368 |
+
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
self.initialize_function_template = """
|
| 372 |
+
|
| 373 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 374 |
+
|
| 375 |
+
namespace cutlass {
|
| 376 |
+
namespace library {
|
| 377 |
+
|
| 378 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 379 |
+
|
| 380 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 381 |
+
|
| 382 |
+
"""
|
| 383 |
+
self.epilogue_template = """
|
| 384 |
+
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 388 |
+
|
| 389 |
+
} // namespace library
|
| 390 |
+
} // namespace cutlass
|
| 391 |
+
|
| 392 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 393 |
+
|
| 394 |
+
"""
|
| 395 |
+
|
| 396 |
+
def __enter__(self):
|
| 397 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 398 |
+
self.configuration_file.write(self.header_template)
|
| 399 |
+
|
| 400 |
+
self.instance_definitions = []
|
| 401 |
+
self.instance_wrappers = []
|
| 402 |
+
|
| 403 |
+
self.operations = []
|
| 404 |
+
return self
|
| 405 |
+
|
| 406 |
+
def emit(self, operation):
|
| 407 |
+
emitter = self.instance_emitter[operation.symm_kind]()
|
| 408 |
+
|
| 409 |
+
self.operations.append(operation)
|
| 410 |
+
|
| 411 |
+
self.instance_definitions.append(emitter.emit(operation))
|
| 412 |
+
|
| 413 |
+
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.symm_kind], {
|
| 414 |
+
'configuration_name': self.configuration_name,
|
| 415 |
+
'operation_name': operation.procedural_name(),
|
| 416 |
+
'symm_kind': self.symm_kind_wrappers[operation.symm_kind],
|
| 417 |
+
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
|
| 418 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
|
| 419 |
+
'compile_guard_end': "#endif" \
|
| 420 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
| 421 |
+
}))
|
| 422 |
+
|
| 423 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 424 |
+
|
| 425 |
+
# Write instance definitions in top-level namespace
|
| 426 |
+
for instance_definition in self.instance_definitions:
|
| 427 |
+
self.configuration_file.write(instance_definition)
|
| 428 |
+
|
| 429 |
+
# Add wrapper objects within initialize() function
|
| 430 |
+
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
| 431 |
+
'configuration_name': self.configuration_name
|
| 432 |
+
}))
|
| 433 |
+
|
| 434 |
+
for instance_wrapper in self.instance_wrappers:
|
| 435 |
+
self.configuration_file.write(instance_wrapper)
|
| 436 |
+
|
| 437 |
+
self.configuration_file.write(self.epilogue_template)
|
| 438 |
+
self.configuration_file.close()
|
| 439 |
+
|
| 440 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for emitting Trmm kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
import functools
|
| 39 |
+
import operator
|
| 40 |
+
import os.path
|
| 41 |
+
import shutil
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
+
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
+
from cutlass_library.library import *
|
| 48 |
+
except ImportError:
|
| 49 |
+
from library import *
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
###################################################################################################
|
| 53 |
+
#
|
| 54 |
+
# Data structure modeling a TRMM operation
|
| 55 |
+
#
|
| 56 |
+
###################################################################################################
|
| 57 |
+
|
| 58 |
+
#
|
| 59 |
+
class TrmmOperation:
|
| 60 |
+
#
|
| 61 |
+
def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
| 62 |
+
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8):
|
| 63 |
+
|
| 64 |
+
self.operation_kind = OperationKind.Trmm
|
| 65 |
+
self.arch = arch
|
| 66 |
+
self.tile_description = tile_description
|
| 67 |
+
self.trmm_kind = trmm_kind
|
| 68 |
+
self.A = A
|
| 69 |
+
self.B = B
|
| 70 |
+
self.C = C
|
| 71 |
+
self.element_epilogue = element_epilogue
|
| 72 |
+
self.epilogue_functor = epilogue_functor
|
| 73 |
+
self.swizzling_functor = swizzling_functor
|
| 74 |
+
|
| 75 |
+
#
|
| 76 |
+
def is_complex(self):
|
| 77 |
+
complex_operators = [
|
| 78 |
+
MathOperation.multiply_add_complex,
|
| 79 |
+
MathOperation.multiply_add_complex_gaussian,
|
| 80 |
+
MathOperation.multiply_add_complex_fast_f32
|
| 81 |
+
]
|
| 82 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
#
|
| 86 |
+
def is_planar_complex(self):
|
| 87 |
+
# return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray)
|
| 88 |
+
return False
|
| 89 |
+
|
| 90 |
+
#
|
| 91 |
+
def is_mixed_input(self):
|
| 92 |
+
return self.A.element != self.B.element
|
| 93 |
+
|
| 94 |
+
#
|
| 95 |
+
def accumulator_type(self):
|
| 96 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 97 |
+
|
| 98 |
+
if self.is_complex():
|
| 99 |
+
return get_complex_from_real(accum)
|
| 100 |
+
|
| 101 |
+
return accum
|
| 102 |
+
|
| 103 |
+
#
|
| 104 |
+
def short_math_name(self):
|
| 105 |
+
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 106 |
+
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 107 |
+
return ShortDataTypeNames[self.accumulator_type()]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
#
|
| 111 |
+
def core_name(self):
|
| 112 |
+
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 113 |
+
|
| 114 |
+
inst_shape = ''
|
| 115 |
+
inst_operation = ''
|
| 116 |
+
intermediate_type = ''
|
| 117 |
+
|
| 118 |
+
math_operations_map = {
|
| 119 |
+
MathOperation.xor_popc: 'xor',
|
| 120 |
+
MathOperation.and_popc: 'and'
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
|
| 124 |
+
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
|
| 125 |
+
|
| 126 |
+
math_op = self.tile_description.math_instruction.math_operation
|
| 127 |
+
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
|
| 128 |
+
|
| 129 |
+
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 130 |
+
inst_shape += math_op_string
|
| 131 |
+
|
| 132 |
+
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 133 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 134 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 135 |
+
|
| 136 |
+
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind])
|
| 137 |
+
|
| 138 |
+
#
|
| 139 |
+
def extended_name(self):
|
| 140 |
+
''' Append data types if they differ from compute type. '''
|
| 141 |
+
if self.is_complex():
|
| 142 |
+
extended_name = "${core_name}"
|
| 143 |
+
else:
|
| 144 |
+
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 145 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 146 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 147 |
+
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 148 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 149 |
+
extended_name = "${core_name}_${element_a}"
|
| 150 |
+
else:
|
| 151 |
+
extended_name = "${core_name}"
|
| 152 |
+
|
| 153 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 154 |
+
'element_a': DataTypeNames[self.A.element],
|
| 155 |
+
'element_c': DataTypeNames[self.C.element],
|
| 156 |
+
'core_name': self.core_name()
|
| 157 |
+
})
|
| 158 |
+
|
| 159 |
+
return extended_name
|
| 160 |
+
|
| 161 |
+
#
|
| 162 |
+
def layout_name(self):
|
| 163 |
+
if self.is_complex() or self.is_planar_complex():
|
| 164 |
+
return "%s%s" % (
|
| 165 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 166 |
+
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
|
| 167 |
+
)
|
| 168 |
+
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
|
| 169 |
+
|
| 170 |
+
#
|
| 171 |
+
def side_mode_name(self):
|
| 172 |
+
return "%s" % (ShortSideModeNames[self.A.side_mode])
|
| 173 |
+
|
| 174 |
+
#
|
| 175 |
+
def fill_mode_name(self):
|
| 176 |
+
return "%s" % (ShortFillModeNames[self.A.fill_mode])
|
| 177 |
+
|
| 178 |
+
#
|
| 179 |
+
def diag_type_name(self):
|
| 180 |
+
return "%s" % (ShortDiagTypeNames[self.A.diag_type])
|
| 181 |
+
|
| 182 |
+
#
|
| 183 |
+
def procedural_name(self):
|
| 184 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 185 |
+
threadblock = self.tile_description.procedural_name()
|
| 186 |
+
|
| 187 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 188 |
+
|
| 189 |
+
alignment = max([self.C.alignment])
|
| 190 |
+
|
| 191 |
+
return SubstituteTemplate(
|
| 192 |
+
"cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}",
|
| 193 |
+
{
|
| 194 |
+
'opcode_class': opcode_class_name,
|
| 195 |
+
'extended_name': self.extended_name(),
|
| 196 |
+
'threadblock': threadblock,
|
| 197 |
+
'layout': self.layout_name(),
|
| 198 |
+
'side_mode': self.side_mode_name(),
|
| 199 |
+
'fill_mode': self.fill_mode_name(),
|
| 200 |
+
'diag_type': self.diag_type_name(),
|
| 201 |
+
'alignment': "%d" % self.C.alignment,
|
| 202 |
+
}
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
#
|
| 206 |
+
def configuration_name(self):
|
| 207 |
+
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 208 |
+
return self.procedural_name()
|
| 209 |
+
|
| 210 |
+
###################################################################################################
|
| 211 |
+
#
|
| 212 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 213 |
+
#
|
| 214 |
+
###################################################################################################
|
| 215 |
+
|
| 216 |
+
#
|
| 217 |
+
class EmitTrmmUniversalInstance:
|
| 218 |
+
''' Responsible for emitting a CUTLASS template definition'''
|
| 219 |
+
|
| 220 |
+
def __init__(self):
|
| 221 |
+
self.trmm_template = """
|
| 222 |
+
// Trmm operator ${operation_name}
|
| 223 |
+
using Operation_${operation_name} =
|
| 224 |
+
typename cutlass::gemm::device::Trmm<
|
| 225 |
+
${element_a}, ${layout_a},
|
| 226 |
+
${side_mode}, ${fill_mode}, ${diag_type},
|
| 227 |
+
${element_b}, ${layout_b},
|
| 228 |
+
${element_c}, ${layout_c},
|
| 229 |
+
${element_accumulator},
|
| 230 |
+
${opcode_class},
|
| 231 |
+
${arch},
|
| 232 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 233 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 234 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 235 |
+
${epilogue_functor}<
|
| 236 |
+
${element_c},
|
| 237 |
+
${epilogue_vector_length},
|
| 238 |
+
${element_accumulator},
|
| 239 |
+
${element_epilogue},
|
| 240 |
+
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
| 241 |
+
>,
|
| 242 |
+
${swizzling_functor},
|
| 243 |
+
${stages},
|
| 244 |
+
${align_a},
|
| 245 |
+
${align_b},
|
| 246 |
+
${split_k_serial},
|
| 247 |
+
${math_operation}
|
| 248 |
+
>;
|
| 249 |
+
"""
|
| 250 |
+
self.trmm_complex_template = """
|
| 251 |
+
// Trmm operator ${operation_name}
|
| 252 |
+
using Operation_${operation_name} =
|
| 253 |
+
typename cutlass::gemm::device::Trmm<
|
| 254 |
+
${element_a}, ${layout_a},
|
| 255 |
+
${side_mode}, ${fill_mode}, ${diag_type},
|
| 256 |
+
${element_b}, ${layout_b},
|
| 257 |
+
${element_c}, ${layout_c},
|
| 258 |
+
${element_accumulator},
|
| 259 |
+
${opcode_class},
|
| 260 |
+
${arch},
|
| 261 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 262 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 263 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 264 |
+
${epilogue_functor}<
|
| 265 |
+
${element_c},
|
| 266 |
+
${epilogue_vector_length},
|
| 267 |
+
${element_accumulator},
|
| 268 |
+
${element_epilogue},
|
| 269 |
+
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
| 270 |
+
>,
|
| 271 |
+
${swizzling_functor},
|
| 272 |
+
${stages},
|
| 273 |
+
${align_a},
|
| 274 |
+
${align_b},
|
| 275 |
+
${split_k_serial},
|
| 276 |
+
${math_operation},
|
| 277 |
+
${transform_a}
|
| 278 |
+
>;
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def emit(self, operation):
|
| 282 |
+
|
| 283 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 284 |
+
warp_count = operation.tile_description.warp_count
|
| 285 |
+
|
| 286 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 287 |
+
|
| 288 |
+
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 289 |
+
|
| 290 |
+
values = {
|
| 291 |
+
'operation_name': operation.procedural_name(),
|
| 292 |
+
'element_a': DataTypeTag[operation.A.element],
|
| 293 |
+
'layout_a': LayoutTag[operation.A.layout],
|
| 294 |
+
'side_mode' : SideModeTag[operation.A.side_mode],
|
| 295 |
+
'fill_mode': FillModeTag[operation.A.fill_mode],
|
| 296 |
+
'diag_type' : DiagTypeTag[operation.A.diag_type],
|
| 297 |
+
'element_b': DataTypeTag[operation.B.element],
|
| 298 |
+
'layout_b': LayoutTag[operation.B.layout],
|
| 299 |
+
'element_c': DataTypeTag[operation.C.element],
|
| 300 |
+
'layout_c': LayoutTag[operation.C.layout],
|
| 301 |
+
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 302 |
+
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 303 |
+
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 304 |
+
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 305 |
+
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 306 |
+
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 307 |
+
'warp_shape_m': str(warp_shape[0]),
|
| 308 |
+
'warp_shape_n': str(warp_shape[1]),
|
| 309 |
+
'warp_shape_k': str(warp_shape[2]),
|
| 310 |
+
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 311 |
+
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 312 |
+
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 313 |
+
'epilogue_vector_length': str(epilogue_vector_length),
|
| 314 |
+
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 315 |
+
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 316 |
+
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 317 |
+
'stages': str(operation.tile_description.stages),
|
| 318 |
+
'align_a': str(1), # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes
|
| 319 |
+
'align_b': str(operation.B.alignment),
|
| 320 |
+
'split_k_serial': 'false',
|
| 321 |
+
'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 322 |
+
'transform_a': ComplexTransformTag[operation.A.complex_transform]
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
trmm_template = self.trmm_complex_template if operation.is_complex() else self.trmm_template
|
| 326 |
+
|
| 327 |
+
return SubstituteTemplate(trmm_template, values)
|
| 328 |
+
|
| 329 |
+
###################################################################################################
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
###################################################################################################
|
| 333 |
+
#
|
| 334 |
+
# Emitters functions for all targets
|
| 335 |
+
#
|
| 336 |
+
###################################################################################################
|
| 337 |
+
|
| 338 |
+
class EmitTrmmConfigurationLibrary:
|
| 339 |
+
def __init__(self, operation_path, configuration_name):
|
| 340 |
+
self.configuration_name = configuration_name
|
| 341 |
+
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
|
| 342 |
+
|
| 343 |
+
self.instance_emitter = {
|
| 344 |
+
TrmmKind.Universal: EmitTrmmUniversalInstance,
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
self.trmm_kind_wrappers = {
|
| 348 |
+
TrmmKind.Universal: 'TrmmOperation',
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
self.instance_template = {
|
| 352 |
+
TrmmKind.Universal: """
|
| 353 |
+
${compile_guard_start}
|
| 354 |
+
manifest.append(new ${trmm_kind}<
|
| 355 |
+
Operation_${operation_name}
|
| 356 |
+
>("${operation_name}"));
|
| 357 |
+
${compile_guard_end}
|
| 358 |
+
"""
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
self.header_template = """
|
| 362 |
+
/*
|
| 363 |
+
Generated by trmm_operation.py - Do not edit.
|
| 364 |
+
*/
|
| 365 |
+
|
| 366 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 367 |
+
#include "cutlass/cutlass.h"
|
| 368 |
+
#include "cutlass/library/library.h"
|
| 369 |
+
#include "cutlass/library/manifest.h"
|
| 370 |
+
|
| 371 |
+
#include "library_internal.h"
|
| 372 |
+
#include "trmm_operation.h"
|
| 373 |
+
|
| 374 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 375 |
+
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
self.initialize_function_template = """
|
| 379 |
+
|
| 380 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 381 |
+
|
| 382 |
+
namespace cutlass {
|
| 383 |
+
namespace library {
|
| 384 |
+
|
| 385 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 386 |
+
|
| 387 |
+
void initialize_${configuration_name}(Manifest &manifest) {
|
| 388 |
+
|
| 389 |
+
"""
|
| 390 |
+
self.epilogue_template = """
|
| 391 |
+
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 395 |
+
|
| 396 |
+
} // namespace library
|
| 397 |
+
} // namespace cutlass
|
| 398 |
+
|
| 399 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 400 |
+
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
def __enter__(self):
|
| 404 |
+
self.configuration_file = open(self.configuration_path, "w")
|
| 405 |
+
self.configuration_file.write(self.header_template)
|
| 406 |
+
|
| 407 |
+
self.instance_definitions = []
|
| 408 |
+
self.instance_wrappers = []
|
| 409 |
+
|
| 410 |
+
self.operations = []
|
| 411 |
+
return self
|
| 412 |
+
|
| 413 |
+
def emit(self, operation):
|
| 414 |
+
emitter = self.instance_emitter[operation.trmm_kind]()
|
| 415 |
+
|
| 416 |
+
self.operations.append(operation)
|
| 417 |
+
|
| 418 |
+
self.instance_definitions.append(emitter.emit(operation))
|
| 419 |
+
|
| 420 |
+
self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], {
|
| 421 |
+
'configuration_name': self.configuration_name,
|
| 422 |
+
'operation_name': operation.procedural_name(),
|
| 423 |
+
'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind],
|
| 424 |
+
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
|
| 425 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
|
| 426 |
+
'compile_guard_end': "#endif" \
|
| 427 |
+
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
| 428 |
+
}))
|
| 429 |
+
|
| 430 |
+
def __exit__(self, exception_type, exception_value, traceback):
|
| 431 |
+
|
| 432 |
+
# Write instance definitions in top-level namespace
|
| 433 |
+
for instance_definition in self.instance_definitions:
|
| 434 |
+
self.configuration_file.write(instance_definition)
|
| 435 |
+
|
| 436 |
+
# Add wrapper objects within initialize() function
|
| 437 |
+
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
| 438 |
+
'configuration_name': self.configuration_name
|
| 439 |
+
}))
|
| 440 |
+
|
| 441 |
+
for instance_wrapper in self.instance_wrappers:
|
| 442 |
+
self.configuration_file.write(instance_wrapper)
|
| 443 |
+
|
| 444 |
+
self.configuration_file.write(self.epilogue_template)
|
| 445 |
+
self.configuration_file.close()
|
| 446 |
+
|
| 447 |
+
###################################################################################################
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
# Configuration file for the Sphinx documentation builder.
|
| 34 |
+
#
|
| 35 |
+
# For the full list of built-in configuration values, see the documentation:
|
| 36 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
| 37 |
+
|
| 38 |
+
# -- Path setup --------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
| 41 |
+
# add these directories to sys.path here. If the directory is relative to the
|
| 42 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
| 43 |
+
#
|
| 44 |
+
import os
|
| 45 |
+
import sys
|
| 46 |
+
|
| 47 |
+
sys.path.insert(0, os.path.abspath('..'))
|
| 48 |
+
sys.path.insert(0, os.path.abspath('../..'))
|
| 49 |
+
sys.path.insert(0, os.path.abspath('../../media/docs'))
|
| 50 |
+
|
| 51 |
+
# -- Project information -----------------------------------------------------
|
| 52 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
| 53 |
+
|
| 54 |
+
project = 'CUTLASS Python interface'
|
| 55 |
+
copyright = '2023, NVIDIA'
|
| 56 |
+
author = 'NVIDIA'
|
| 57 |
+
release = '3.1.0'
|
| 58 |
+
|
| 59 |
+
# -- General configuration ---------------------------------------------------
|
| 60 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
| 64 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
| 65 |
+
# ones.
|
| 66 |
+
extensions = [
|
| 67 |
+
'myst_parser',
|
| 68 |
+
'nbsphinx',
|
| 69 |
+
'nbsphinx_link',
|
| 70 |
+
'sphinx_copybutton',
|
| 71 |
+
'sphinx.ext.autodoc',
|
| 72 |
+
'sphinx.ext.autosectionlabel',
|
| 73 |
+
'sphinx.ext.autosummary',
|
| 74 |
+
'sphinx.ext.coverage',
|
| 75 |
+
'sphinx.ext.extlinks',
|
| 76 |
+
'sphinx.ext.ifconfig',
|
| 77 |
+
'sphinx.ext.intersphinx',
|
| 78 |
+
'sphinx.ext.mathjax',
|
| 79 |
+
'sphinx.ext.napoleon',
|
| 80 |
+
'sphinx.ext.viewcode',
|
| 81 |
+
'sphinx_inline_tabs',
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
source_suffix = {
|
| 85 |
+
'.rst': 'restructuredtext',
|
| 86 |
+
'.md': 'markdown',
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
autodoc_typehints = 'description'
|
| 90 |
+
|
| 91 |
+
pygments_style = "sphinx"
|
| 92 |
+
pygments_dark_style = "monokai"
|
| 93 |
+
|
| 94 |
+
templates_path = ['_templates']
|
| 95 |
+
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
| 96 |
+
|
| 97 |
+
# Ignore errors when converting notebooks
|
| 98 |
+
nbsphinx_allow_errors = True
|
| 99 |
+
|
| 100 |
+
language = 'en'
|
| 101 |
+
# -- Options for HTML output -------------------------------------------------
|
| 102 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
| 103 |
+
|
| 104 |
+
html_static_path = ['_static']
|
| 105 |
+
|
| 106 |
+
html_title = "CUTLASS Python"
|
| 107 |
+
html_baseurl = 'docs'
|
| 108 |
+
html_theme = 'furo'
|
| 109 |
+
html_theme_options = {
|
| 110 |
+
"light_logo": "cutlass-logo-small.png",
|
| 111 |
+
"dark_logo": "cutlass-logo-small.png",
|
| 112 |
+
"light_css_variables": {
|
| 113 |
+
"color-brand-primary": "#76B900",
|
| 114 |
+
"color-brand-content": "#76B900",
|
| 115 |
+
},
|
| 116 |
+
"dark_css_variables": {
|
| 117 |
+
"color-brand-primary": "#76B900",
|
| 118 |
+
"color-brand-content": "#76B900",
|
| 119 |
+
},
|
| 120 |
+
"footer_icons": [
|
| 121 |
+
{
|
| 122 |
+
"name": "GitHub",
|
| 123 |
+
"url": "https://github.com/NVIDIA/cutlass",
|
| 124 |
+
"html": """
|
| 125 |
+
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
|
| 126 |
+
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
|
| 127 |
+
</svg>
|
| 128 |
+
""",
|
| 129 |
+
"class": "",
|
| 130 |
+
},
|
| 131 |
+
],
|
| 132 |
+
}
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from .int_tuple import *
|
| 34 |
+
from .layout import *
|
| 35 |
+
from .swizzle import *
|
| 36 |
+
from .typing import *
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Functions for manipulating IntTuples
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import reduce
|
| 38 |
+
from itertools import chain
|
| 39 |
+
from typing import Union
|
| 40 |
+
from .typing import Integer
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def is_int(x):
|
| 44 |
+
return isinstance(x, Integer)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def is_tuple(x):
|
| 48 |
+
return isinstance(x, tuple)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def flatten(t):
|
| 52 |
+
if is_tuple(t):
|
| 53 |
+
if len(t) == 0:
|
| 54 |
+
return ()
|
| 55 |
+
else:
|
| 56 |
+
return tuple(i for a in t for i in flatten(a))
|
| 57 |
+
else:
|
| 58 |
+
return (t,)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def signum(a):
|
| 62 |
+
return bool(a > 0) - bool(a < 0)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def product(a):
|
| 66 |
+
if is_tuple(a):
|
| 67 |
+
return reduce(lambda val,elem : val*product(elem), a, 1)
|
| 68 |
+
else:
|
| 69 |
+
return a
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def inner_product(a, b):
|
| 73 |
+
if is_tuple(a): # tuple tuple
|
| 74 |
+
assert len(a) == len(b)
|
| 75 |
+
return sum(inner_product(x,y) for x,y in zip(a,b))
|
| 76 |
+
else: # "int" "int"
|
| 77 |
+
assert not is_tuple(b)
|
| 78 |
+
return a * b
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def tuple_max(a):
|
| 82 |
+
if is_tuple(a):
|
| 83 |
+
return max(tuple_max(x) for x in a)
|
| 84 |
+
else:
|
| 85 |
+
return a
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def elem_scale(a, b):
|
| 89 |
+
if is_tuple(a):
|
| 90 |
+
if is_tuple(b): # tuple tuple
|
| 91 |
+
assert len(a) == len(b)
|
| 92 |
+
return tuple(elem_scale(x,y) for x,y in zip(a,b))
|
| 93 |
+
else: # tuple "int"
|
| 94 |
+
assert False # Error
|
| 95 |
+
else:
|
| 96 |
+
if is_tuple(b): # "int" tuple
|
| 97 |
+
return elem_scale(a, product(b))
|
| 98 |
+
else: # "int" "int"
|
| 99 |
+
return a * b
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Inclusive prefix ceil div with output congruent to input a
|
| 103 |
+
def shape_div(a, b):
|
| 104 |
+
if is_tuple(a):
|
| 105 |
+
if is_tuple(b): # tuple tuple
|
| 106 |
+
assert len(a) == len(b)
|
| 107 |
+
return tuple(shape_div(x,y) for x,y in zip(a,b))
|
| 108 |
+
else: # tuple "int"
|
| 109 |
+
#r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))]
|
| 110 |
+
r = []
|
| 111 |
+
for v in a:
|
| 112 |
+
r.append(shape_div(v,b))
|
| 113 |
+
b = shape_div(b,product(v))
|
| 114 |
+
return tuple(r)
|
| 115 |
+
else:
|
| 116 |
+
if is_tuple(b): # "int" tuple
|
| 117 |
+
return shape_div(a, product(b))
|
| 118 |
+
else: # "int" "int"
|
| 119 |
+
assert a % b == 0 or b % a == 0
|
| 120 |
+
return (a + b - 1) // b
|
| 121 |
+
|
| 122 |
+
# Exclusive prefix product with output congruent to input a
|
| 123 |
+
def prefix_product(a, init=1):
|
| 124 |
+
if is_tuple(a):
|
| 125 |
+
if is_tuple(init): # tuple tuple
|
| 126 |
+
assert len(a) == len(init)
|
| 127 |
+
return tuple(prefix_product(x,i) for x,i in zip(a,init))
|
| 128 |
+
else: # tuple "int"
|
| 129 |
+
#r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))]
|
| 130 |
+
r = []
|
| 131 |
+
for v in a:
|
| 132 |
+
r.append(prefix_product(v,init))
|
| 133 |
+
init = init * product(v)
|
| 134 |
+
return tuple(r)
|
| 135 |
+
else:
|
| 136 |
+
if is_tuple(init): # "int" tuple
|
| 137 |
+
assert False # Error
|
| 138 |
+
else: # "int" "int"
|
| 139 |
+
return init
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def idx2crd(idx, shape, stride=None):
|
| 143 |
+
if stride is None:
|
| 144 |
+
stride = prefix_product(shape)
|
| 145 |
+
|
| 146 |
+
if is_tuple(idx):
|
| 147 |
+
if is_tuple(shape): # tuple tuple tuple
|
| 148 |
+
assert len(idx) == len(shape) and len(idx) == len(stride)
|
| 149 |
+
return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride))
|
| 150 |
+
else: # tuple "int" "int"
|
| 151 |
+
assert False # Error
|
| 152 |
+
else:
|
| 153 |
+
if is_tuple(shape): # "int" tuple tuple
|
| 154 |
+
assert len(shape) == len(stride)
|
| 155 |
+
return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride))
|
| 156 |
+
else: # "int" "int" "int"
|
| 157 |
+
return (idx // stride) % shape
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def crd2idx(crd, shape, stride=None):
|
| 161 |
+
if stride is None:
|
| 162 |
+
stride = prefix_product(shape)
|
| 163 |
+
|
| 164 |
+
if is_tuple(crd):
|
| 165 |
+
if is_tuple(shape): # tuple tuple tuple
|
| 166 |
+
assert len(crd) == len(shape) and len(crd) == len(stride)
|
| 167 |
+
return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride))
|
| 168 |
+
else: # tuple "int" "int"
|
| 169 |
+
assert False, f"crd={crd}, shape={shape}" # Error
|
| 170 |
+
else:
|
| 171 |
+
if crd is None:
|
| 172 |
+
crd = 0
|
| 173 |
+
|
| 174 |
+
if is_tuple(shape): # "int" tuple tuple
|
| 175 |
+
assert len(shape) == len(stride)
|
| 176 |
+
result = 0
|
| 177 |
+
for i in range(len(shape)-1):
|
| 178 |
+
result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
|
| 179 |
+
crd = crd // product(shape[i])
|
| 180 |
+
return result + crd2idx(crd, shape[-1], stride[-1])
|
| 181 |
+
else: # "int" "int" "int"
|
| 182 |
+
return crd * stride
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Transform crd into the dst_shape's iteration space
|
| 186 |
+
def crd2crd(crd, dst_shape, src_shape=None):
|
| 187 |
+
if is_tuple(crd):
|
| 188 |
+
if is_tuple(dst_shape): # tuple tuple
|
| 189 |
+
assert len(crd) == len(dst_shape)
|
| 190 |
+
return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape))
|
| 191 |
+
else: # tuple "int"
|
| 192 |
+
# Ambiguous unless we have src_shape
|
| 193 |
+
assert src_shape is not None
|
| 194 |
+
return crd2idx(crd, src_shape)
|
| 195 |
+
else:
|
| 196 |
+
if is_tuple(dst_shape): # "int" tuple
|
| 197 |
+
return idx2crd(crd, dst_shape)
|
| 198 |
+
else: # "int" "int"
|
| 199 |
+
assert crd < dst_shape
|
| 200 |
+
return crd
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Filter trg according to crd: keep only elements of trg that are paired with None
|
| 204 |
+
def slice_(crd: Union[None, tuple, int],
|
| 205 |
+
trg: Union[tuple, int]):
|
| 206 |
+
if is_tuple(crd):
|
| 207 |
+
if is_tuple(trg): # tuple tuple
|
| 208 |
+
assert len(crd) == len(trg)
|
| 209 |
+
# match C++ behavior of `filter_tuple` using `tuple_cat(...)`
|
| 210 |
+
return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)])))
|
| 211 |
+
else:
|
| 212 |
+
assert False # tuple "int" : Error
|
| 213 |
+
elif crd is None:
|
| 214 |
+
# match C++ behavior `return cute::tuple<B>{b};`
|
| 215 |
+
return (trg,)
|
| 216 |
+
else:
|
| 217 |
+
return ()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Determine if None appears at any of an int_tuples' terminals
|
| 221 |
+
def has_none(a: Union[None, tuple, int]):
|
| 222 |
+
if is_tuple(a):
|
| 223 |
+
return any(has_none(v) for v in a)
|
| 224 |
+
else:
|
| 225 |
+
return a is None
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Definition of CuTe Layouts and functions to manipulate them
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from itertools import chain
|
| 38 |
+
from typing import Union
|
| 39 |
+
|
| 40 |
+
from .int_tuple import *
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class LayoutBase:
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def is_layout(x):
|
| 48 |
+
return isinstance(x, LayoutBase)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Layout(LayoutBase):
|
| 52 |
+
def __init__(self, _shape, _stride=None):
|
| 53 |
+
self.shape = _shape
|
| 54 |
+
if _stride is None:
|
| 55 |
+
self.stride = prefix_product(self.shape)
|
| 56 |
+
else:
|
| 57 |
+
self.stride = _stride
|
| 58 |
+
|
| 59 |
+
# operator ==
|
| 60 |
+
def __eq__(self, other):
|
| 61 |
+
return self.shape == other.shape and self.stride == other.stride
|
| 62 |
+
|
| 63 |
+
# operator len(L) (len [rank] like tuples)
|
| 64 |
+
def __len__(self):
|
| 65 |
+
if is_tuple(self.shape):
|
| 66 |
+
return len(self.shape)
|
| 67 |
+
else:
|
| 68 |
+
return 1
|
| 69 |
+
|
| 70 |
+
# operator () (map coord to idx)
|
| 71 |
+
def __call__(self, *args):
|
| 72 |
+
"""
|
| 73 |
+
Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
|
| 74 |
+
OR
|
| 75 |
+
Slice the layout and return the sublayout (Coord has an Underscore slice op)
|
| 76 |
+
|
| 77 |
+
Follow the same behavior of `Layout::operator(Coord const&)` in cute C++
|
| 78 |
+
"""
|
| 79 |
+
if has_none(args):
|
| 80 |
+
if len(args) == 1:
|
| 81 |
+
return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
|
| 82 |
+
else:
|
| 83 |
+
return Layout(slice_(args, self.shape), slice_(args, self.stride))
|
| 84 |
+
else:
|
| 85 |
+
if len(args) == 1:
|
| 86 |
+
return crd2idx(args[0], self.shape, self.stride)
|
| 87 |
+
else:
|
| 88 |
+
return crd2idx(args, self.shape, self.stride)
|
| 89 |
+
|
| 90 |
+
# operator [] (get-i like tuples)
|
| 91 |
+
def __getitem__(self, i):
|
| 92 |
+
if is_tuple(self.shape):
|
| 93 |
+
return Layout(self.shape[i], self.stride[i])
|
| 94 |
+
else:
|
| 95 |
+
assert i == 0
|
| 96 |
+
return Layout(self.shape, self.stride)
|
| 97 |
+
|
| 98 |
+
# size(layout) Size of the domain
|
| 99 |
+
def size(self):
|
| 100 |
+
return product(self.shape)
|
| 101 |
+
|
| 102 |
+
# cosize(layout) Size of the codomain
|
| 103 |
+
def cosize(self):
|
| 104 |
+
return self(self.size() - 1) + 1
|
| 105 |
+
|
| 106 |
+
# print and str
|
| 107 |
+
def __str__(self):
|
| 108 |
+
return f"{self.shape}:{self.stride}"
|
| 109 |
+
|
| 110 |
+
# error msgs and representation
|
| 111 |
+
def __repr__(self):
|
| 112 |
+
return f"Layout({self.shape},{self.stride})"
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Make Layout from a list of layouts (each layout it's own mode in the result)
|
| 116 |
+
def make_layout(*layouts):
|
| 117 |
+
if len(layouts) == 1 and not is_layout(layouts[0]):
|
| 118 |
+
layouts = layouts[0]
|
| 119 |
+
|
| 120 |
+
shape, stride = zip(*((a.shape,a.stride) for a in layouts))
|
| 121 |
+
return Layout(shape, stride)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# Size of the domain
|
| 125 |
+
def size(layout):
|
| 126 |
+
if is_layout(layout):
|
| 127 |
+
return layout.size()
|
| 128 |
+
return product(layout)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Size of the codomain
|
| 132 |
+
def cosize(layout):
|
| 133 |
+
return layout.cosize()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
|
| 137 |
+
def coalesce(layout, profile=None):
|
| 138 |
+
if is_tuple(profile):
|
| 139 |
+
assert len(layout) >= len(profile)
|
| 140 |
+
return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))),
|
| 141 |
+
(layout[i] for i in range(len(profile),len(layout)))))
|
| 142 |
+
|
| 143 |
+
result_shape = [1]
|
| 144 |
+
result_stride = [0]
|
| 145 |
+
for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)):
|
| 146 |
+
# skip their shape-1s
|
| 147 |
+
if shape == 1:
|
| 148 |
+
continue
|
| 149 |
+
# replace our shape-1 with anything
|
| 150 |
+
elif result_shape[-1] == 1:
|
| 151 |
+
result_shape[-1] = shape
|
| 152 |
+
result_stride[-1] = stride
|
| 153 |
+
# merge modes if the shape*stride match
|
| 154 |
+
elif result_shape[-1] * result_stride[-1] == stride:
|
| 155 |
+
result_shape[-1] = result_shape[-1] * shape
|
| 156 |
+
# append a new mode
|
| 157 |
+
else:
|
| 158 |
+
result_shape.append(shape)
|
| 159 |
+
result_stride.append(stride)
|
| 160 |
+
|
| 161 |
+
if len(result_shape) == 1:
|
| 162 |
+
return Layout(result_shape[0], result_stride[0])
|
| 163 |
+
else:
|
| 164 |
+
return Layout(tuple(result_shape), tuple(result_stride))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
|
| 168 |
+
def filter(layout, profile=None):
|
| 169 |
+
if is_tuple(profile):
|
| 170 |
+
assert len(layout) >= len(profile)
|
| 171 |
+
return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))),
|
| 172 |
+
(layout[i] for i in range(len(profile),len(layout)))))
|
| 173 |
+
|
| 174 |
+
result_shape = []
|
| 175 |
+
result_stride = []
|
| 176 |
+
for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)):
|
| 177 |
+
# skip their shape-1s and stride-0s
|
| 178 |
+
if not (shape == 1 or stride == 0):
|
| 179 |
+
result_shape.append(shape)
|
| 180 |
+
result_stride.append(stride)
|
| 181 |
+
|
| 182 |
+
if len(result_shape) == 0:
|
| 183 |
+
return Layout(1,0)
|
| 184 |
+
else:
|
| 185 |
+
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# Layout composition
|
| 189 |
+
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
| 190 |
+
def composition(layoutA, layoutB):
|
| 191 |
+
if layoutB is None:
|
| 192 |
+
return layoutA
|
| 193 |
+
elif is_int(layoutB):
|
| 194 |
+
return composition(layoutA, Layout(layoutB))
|
| 195 |
+
elif is_tuple(layoutB):
|
| 196 |
+
assert len(layoutA) >= len(layoutB)
|
| 197 |
+
return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
|
| 198 |
+
(layoutA[i] for i in range(len(layoutB),len(layoutA)))))
|
| 199 |
+
elif is_tuple(layoutB.shape):
|
| 200 |
+
return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)
|
| 201 |
+
|
| 202 |
+
if layoutB.stride == 0:
|
| 203 |
+
return Layout(layoutB.shape, 0)
|
| 204 |
+
else:
|
| 205 |
+
result_shape = []
|
| 206 |
+
result_stride = []
|
| 207 |
+
rest_shape = layoutB.shape
|
| 208 |
+
rest_stride = layoutB.stride
|
| 209 |
+
flat_A = coalesce(layoutA)
|
| 210 |
+
for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]):
|
| 211 |
+
assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0
|
| 212 |
+
new_shape = min(max(1, curr_shape // rest_stride), rest_shape)
|
| 213 |
+
|
| 214 |
+
if new_shape != 1:
|
| 215 |
+
result_shape.append(new_shape)
|
| 216 |
+
result_stride.append(rest_stride * curr_stride)
|
| 217 |
+
|
| 218 |
+
rest_shape = rest_shape // new_shape
|
| 219 |
+
rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
|
| 220 |
+
|
| 221 |
+
if rest_shape != 1 or len(result_shape) == 0:
|
| 222 |
+
result_shape.append(rest_shape)
|
| 223 |
+
result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
|
| 224 |
+
|
| 225 |
+
if len(result_shape) == 1:
|
| 226 |
+
return Layout(result_shape[0], result_stride[0])
|
| 227 |
+
else:
|
| 228 |
+
return Layout(tuple(result_shape), tuple(result_stride))
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Layout complement
|
| 232 |
+
def complement(layout, max_idx=1):
|
| 233 |
+
if is_int(layout):
|
| 234 |
+
return complement(Layout(layout))
|
| 235 |
+
|
| 236 |
+
result_shape = []
|
| 237 |
+
result_stride = []
|
| 238 |
+
current_idx = 1
|
| 239 |
+
|
| 240 |
+
sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
|
| 241 |
+
for (stride, shape) in sorted_DS:
|
| 242 |
+
if stride == 0 or shape == 1:
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
in_bound = current_idx <= shape * stride
|
| 246 |
+
# To support symbolic value which can't be evaluated now
|
| 247 |
+
assert (type(in_bound) is not bool) or in_bound
|
| 248 |
+
|
| 249 |
+
result_shape.append(stride // current_idx)
|
| 250 |
+
result_stride.append(current_idx)
|
| 251 |
+
current_idx = shape * stride
|
| 252 |
+
|
| 253 |
+
result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
|
| 254 |
+
result_stride.append(current_idx)
|
| 255 |
+
|
| 256 |
+
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Layout right inverse
|
| 260 |
+
def right_inverse(layout):
|
| 261 |
+
if layout is None:
|
| 262 |
+
return None
|
| 263 |
+
elif is_int(layout):
|
| 264 |
+
return Layout(layout)
|
| 265 |
+
|
| 266 |
+
result_shape = []
|
| 267 |
+
result_stride = []
|
| 268 |
+
current_idx = 1
|
| 269 |
+
|
| 270 |
+
flat_shape = flatten(layout.shape)
|
| 271 |
+
flat_stride = flatten(layout.stride)
|
| 272 |
+
sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))
|
| 273 |
+
for (stride,shape,rstride) in sorted_DSA:
|
| 274 |
+
if shape == 1:
|
| 275 |
+
continue
|
| 276 |
+
if current_idx != stride:
|
| 277 |
+
break
|
| 278 |
+
|
| 279 |
+
result_shape.append(shape)
|
| 280 |
+
result_stride.append(rstride)
|
| 281 |
+
current_idx = shape * stride
|
| 282 |
+
|
| 283 |
+
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Layout left inverse
|
| 287 |
+
def left_inverse(layout):
|
| 288 |
+
if layout is None:
|
| 289 |
+
return None
|
| 290 |
+
elif is_int(layout):
|
| 291 |
+
return Layout(layout)
|
| 292 |
+
return right_inverse(make_layout(layout, complement(layout)))
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Split a layout by the composition of B and the "rest"
|
| 296 |
+
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
| 297 |
+
def logical_divide(layoutA, layoutB):
|
| 298 |
+
if layoutB is None:
|
| 299 |
+
return layoutA
|
| 300 |
+
elif is_int(layoutB):
|
| 301 |
+
return logical_divide(layoutA, Layout(layoutB))
|
| 302 |
+
elif is_tuple(layoutB):
|
| 303 |
+
assert len(layoutA) >= len(layoutB)
|
| 304 |
+
return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
|
| 305 |
+
(layoutA[i] for i in range(len(layoutB),len(layoutA)))))
|
| 306 |
+
|
| 307 |
+
return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA))))
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Reproduce a layoutA over a layoutB
|
| 311 |
+
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
| 312 |
+
def logical_product(layoutA, layoutB):
|
| 313 |
+
if layoutB is None:
|
| 314 |
+
return layoutA
|
| 315 |
+
elif is_int(layoutB):
|
| 316 |
+
return logical_divide(layoutA, Layout(layoutB))
|
| 317 |
+
elif is_tuple(layoutB):
|
| 318 |
+
assert len(layoutA) >= len(layoutB)
|
| 319 |
+
return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
|
| 320 |
+
(layoutA[i] for i in range(len(layoutB),len(layoutA)))))
|
| 321 |
+
|
| 322 |
+
return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB));
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Gather the modes from a hierarchical logical_divide or logical_product
|
| 326 |
+
def hier_unzip(splitter, layoutA, layoutB):
|
| 327 |
+
if layoutB is None:
|
| 328 |
+
return make_layout(Layout(1,0), layoutA)
|
| 329 |
+
elif is_tuple(layoutB):
|
| 330 |
+
assert len(layoutA) >= len(layoutB)
|
| 331 |
+
# A layout with shape ((A,a),(B,b),(C,c))
|
| 332 |
+
split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB)))
|
| 333 |
+
# Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
|
| 334 |
+
return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))),
|
| 335 |
+
make_layout(chain((split[i][1] for i in range( 0,len(layoutB))),
|
| 336 |
+
(layoutA[i] for i in range(len(layoutB),len(layoutA))))))
|
| 337 |
+
|
| 338 |
+
# splitter must return a rank-2 layout
|
| 339 |
+
return splitter(layoutA, layoutB)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# Apply logical divide hierarchically and gather the split modes into two modes
|
| 343 |
+
def zipped_divide(layoutA, layoutB):
|
| 344 |
+
return hier_unzip(logical_divide, layoutA, layoutB)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
|
| 348 |
+
def tiled_divide(layoutA, layoutB):
|
| 349 |
+
result = zipped_divide(layoutA, layoutB)
|
| 350 |
+
return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# Apply logical product hierarchically and gather the split modes into two modes
|
| 354 |
+
def zipped_product(layoutA, layoutB):
|
| 355 |
+
return hier_unzip(logical_product, layoutA, layoutB)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
|
| 359 |
+
def tiled_product(layoutA, layoutB):
|
| 360 |
+
result = zipped_product(layoutA, layoutB)
|
| 361 |
+
return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def slice_and_offset(crd: tuple,
|
| 365 |
+
layout: Layout):
|
| 366 |
+
return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)),
|
| 367 |
+
crd2idx(crd, layout.shape, layout.stride))
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Methods for layout swizzling
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from .layout import *
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def shiftr(a, s):
|
| 41 |
+
return a >> s if s > 0 else shiftl(a, -s)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def shiftl(a, s):
|
| 45 |
+
return a << s if s > 0 else shiftr(a, -s)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
## A generic Swizzle functor
|
| 49 |
+
# 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
|
| 50 |
+
# ^--^ Base is the number of least-sig bits to keep constant
|
| 51 |
+
# ^-^ ^-^ Bits is the number of bits in the mask
|
| 52 |
+
# ^---------^ Shift is the distance to shift the YYY mask
|
| 53 |
+
# (pos shifts YYY to the right, neg shifts YYY to the left)
|
| 54 |
+
#
|
| 55 |
+
# e.g. Given
|
| 56 |
+
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
|
| 57 |
+
# the result is
|
| 58 |
+
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
|
| 59 |
+
#
|
| 60 |
+
class Swizzle:
|
| 61 |
+
def __init__(self, bits, base, shift):
|
| 62 |
+
assert bits >= 0
|
| 63 |
+
assert base >= 0
|
| 64 |
+
assert abs(shift) >= bits
|
| 65 |
+
self.bits = bits
|
| 66 |
+
self.base = base
|
| 67 |
+
self.shift = shift
|
| 68 |
+
bit_msk = (1 << bits) - 1
|
| 69 |
+
self.yyy_msk = bit_msk << (base + max(0,shift))
|
| 70 |
+
self.zzz_msk = bit_msk << (base - min(0,shift))
|
| 71 |
+
|
| 72 |
+
# operator () (transform integer)
|
| 73 |
+
def __call__(self, offset):
|
| 74 |
+
return offset ^ shiftr(offset & self.yyy_msk, self.shift)
|
| 75 |
+
|
| 76 |
+
# Size of the domain
|
| 77 |
+
def size(self):
|
| 78 |
+
return 1 << (self.bits + self.base + abs(self.shift))
|
| 79 |
+
|
| 80 |
+
# Size of the codomain
|
| 81 |
+
def cosize(self):
|
| 82 |
+
return self.size()
|
| 83 |
+
|
| 84 |
+
# print and str
|
| 85 |
+
def __str__(self):
|
| 86 |
+
return f"SW_{self.bits}_{self.base}_{self.shift}"
|
| 87 |
+
|
| 88 |
+
# error msgs and representation
|
| 89 |
+
def __repr__(self):
|
| 90 |
+
return f"Swizzle({self.bits},{self.base},{self.shift})"
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ComposedLayout(LayoutBase):
|
| 94 |
+
def __init__(self, layoutB, offset, layoutA):
|
| 95 |
+
self.layoutB = layoutB
|
| 96 |
+
self.offset = offset
|
| 97 |
+
self.layoutA = layoutA
|
| 98 |
+
|
| 99 |
+
# operator ==
|
| 100 |
+
def __eq__(self, other):
|
| 101 |
+
return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA
|
| 102 |
+
|
| 103 |
+
# operator len(L) (len [rank] like tuples)
|
| 104 |
+
def __len__(self):
|
| 105 |
+
return len(self.layoutA)
|
| 106 |
+
|
| 107 |
+
# operator () (map coord to idx)
|
| 108 |
+
def __call__(self, *args):
|
| 109 |
+
return self.layoutB(self.offset + self.layoutA(*args))
|
| 110 |
+
|
| 111 |
+
# operator [] (get-i like tuples)
|
| 112 |
+
def __getitem__(self, i):
|
| 113 |
+
return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])
|
| 114 |
+
|
| 115 |
+
# size(layout) Size of the domain
|
| 116 |
+
def size(self):
|
| 117 |
+
return size(self.layoutA)
|
| 118 |
+
|
| 119 |
+
# cosize(layout) Size of the codomain
|
| 120 |
+
def cosize(self):
|
| 121 |
+
return cosize(self.layoutB)
|
| 122 |
+
|
| 123 |
+
# print and str
|
| 124 |
+
def __str__(self):
|
| 125 |
+
return f"{self.layoutB} o {self.offset} o {self.layoutA}"
|
| 126 |
+
|
| 127 |
+
# error msgs and representation
|
| 128 |
+
def __repr__(self):
|
| 129 |
+
return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from abc import ABC
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Integer(ABC):
|
| 37 |
+
@classmethod
|
| 38 |
+
def __subclasshook__(cls, c):
|
| 39 |
+
if c in [bool, float]:
|
| 40 |
+
return False
|
| 41 |
+
|
| 42 |
+
return issubclass(c, int)
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
import copy
|
| 35 |
+
import os
|
| 36 |
+
import setuptools
|
| 37 |
+
from setuptools import setup
|
| 38 |
+
from setuptools.command.build_ext import build_ext
|
| 39 |
+
|
| 40 |
+
import setup_pycute
|
| 41 |
+
import setup_library
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Install cutlass_library package
|
| 45 |
+
setup_library.perform_setup()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Install the PyCuTe package
|
| 49 |
+
setup_pycute.perform_setup()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
setup(
|
| 53 |
+
name='cutlass_cppgen',
|
| 54 |
+
version='4.2.0',
|
| 55 |
+
description='CUTLASS Pythonic Interface',
|
| 56 |
+
package_dir={'': '.'},
|
| 57 |
+
packages=[
|
| 58 |
+
'cutlass_cppgen',
|
| 59 |
+
'cutlass_cppgen.emit',
|
| 60 |
+
'cutlass_cppgen.op',
|
| 61 |
+
'cutlass_cppgen.utils',
|
| 62 |
+
'cutlass_cppgen.backend',
|
| 63 |
+
'cutlass_cppgen.backend.utils'
|
| 64 |
+
],
|
| 65 |
+
setup_requires=['pybind11'],
|
| 66 |
+
install_requires=[
|
| 67 |
+
'bfloat16',
|
| 68 |
+
'cuda-python>=11.8.0',
|
| 69 |
+
'pybind11',
|
| 70 |
+
'scikit-build',
|
| 71 |
+
'treelib',
|
| 72 |
+
'pydot'
|
| 73 |
+
]
|
| 74 |
+
)
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from setuptools import setup
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def perform_setup():
|
| 37 |
+
setup(
|
| 38 |
+
name='cutlass_library',
|
| 39 |
+
version='4.2.1',
|
| 40 |
+
description='CUTLASS library generation scripts',
|
| 41 |
+
packages=['cutlass_library']
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == '__main__':
|
| 46 |
+
perform_setup()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from setuptools import setup
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def perform_setup():
|
| 37 |
+
setup(
|
| 38 |
+
name='pycute',
|
| 39 |
+
version='4.2.1',
|
| 40 |
+
description='Python implementation of CuTe',
|
| 41 |
+
packages=['pycute'],
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == '__main__':
|
| 46 |
+
perform_setup()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for defining Conv2D problem sizes for testing.
|
| 35 |
+
|
| 36 |
+
This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
from cutlass_library import ConvMode
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.shape import Conv2DProblemSize
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestbedConv2dProblemSizes:
|
| 46 |
+
def __init__(self, minimum_channel_size: int):
|
| 47 |
+
conv2d_default_sizes = self.initialize_conv2d_default_sizes(minimum_channel_size)
|
| 48 |
+
conv2d_rigorous_sizes = self.initialize_conv2d_rigorous_sizes(minimum_channel_size)
|
| 49 |
+
conv2d_resnet50_sizes = self.initialize_conv2d_resnet50_sizes(1)
|
| 50 |
+
conv2d_resnet50_sizes_perf = self.initialize_conv2d_resnet50_sizes(34)
|
| 51 |
+
grouped_sizes = self.initialize_conv2d_grouped_sizes()
|
| 52 |
+
|
| 53 |
+
# Filter all problems
|
| 54 |
+
self.all = []
|
| 55 |
+
for size_list in [conv2d_default_sizes, conv2d_rigorous_sizes, conv2d_resnet50_sizes, conv2d_resnet50_sizes_perf, grouped_sizes]:
|
| 56 |
+
for size in size_list:
|
| 57 |
+
if (size.C // size.groups) % minimum_channel_size == 0:
|
| 58 |
+
self.all.append(size)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def initialize_conv2d_default_sizes(self, minimum_channel_size):
|
| 62 |
+
# Small input size x stride (1,1)
|
| 63 |
+
# C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 64 |
+
|
| 65 |
+
conv2d_default_sizes = []
|
| 66 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 67 |
+
1, 1, 1, minimum_channel_size,
|
| 68 |
+
8, 1, 1, minimum_channel_size,
|
| 69 |
+
1, 1,
|
| 70 |
+
1, 1,
|
| 71 |
+
1, 1,
|
| 72 |
+
))
|
| 73 |
+
|
| 74 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 75 |
+
1, 1, 8, minimum_channel_size,
|
| 76 |
+
8, 1, 3, minimum_channel_size,
|
| 77 |
+
1, 1,
|
| 78 |
+
1, 1,
|
| 79 |
+
1, 1,
|
| 80 |
+
))
|
| 81 |
+
|
| 82 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 83 |
+
1, 7, 8, minimum_channel_size,
|
| 84 |
+
8, 3, 3, minimum_channel_size,
|
| 85 |
+
1, 1,
|
| 86 |
+
1, 1,
|
| 87 |
+
1, 1,
|
| 88 |
+
))
|
| 89 |
+
|
| 90 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 91 |
+
1, 7, 9, minimum_channel_size,
|
| 92 |
+
8, 4, 4, minimum_channel_size,
|
| 93 |
+
1, 1,
|
| 94 |
+
1, 1,
|
| 95 |
+
1, 1,
|
| 96 |
+
))
|
| 97 |
+
|
| 98 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 99 |
+
2, 7, 9, minimum_channel_size,
|
| 100 |
+
8, 5, 5, minimum_channel_size,
|
| 101 |
+
1, 1,
|
| 102 |
+
1, 1,
|
| 103 |
+
1, 1,
|
| 104 |
+
))
|
| 105 |
+
|
| 106 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 107 |
+
3, 7, 9, minimum_channel_size,
|
| 108 |
+
8, 6, 5, minimum_channel_size,
|
| 109 |
+
1, 1,
|
| 110 |
+
1, 1,
|
| 111 |
+
1, 1,
|
| 112 |
+
))
|
| 113 |
+
|
| 114 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 115 |
+
3, 7, 9, minimum_channel_size,
|
| 116 |
+
8, 6, 6, minimum_channel_size,
|
| 117 |
+
1, 1,
|
| 118 |
+
1, 1,
|
| 119 |
+
1, 1,
|
| 120 |
+
))
|
| 121 |
+
|
| 122 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 123 |
+
3, 7, 9, minimum_channel_size,
|
| 124 |
+
8, 7, 7, minimum_channel_size,
|
| 125 |
+
1, 1,
|
| 126 |
+
1, 1,
|
| 127 |
+
1, 1,
|
| 128 |
+
))
|
| 129 |
+
|
| 130 |
+
##############################################
|
| 131 |
+
# Small input size x stride (2,2)
|
| 132 |
+
# C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 133 |
+
##############################################
|
| 134 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 135 |
+
1, 11, 7, minimum_channel_size,
|
| 136 |
+
8, 1, 1, minimum_channel_size,
|
| 137 |
+
0, 0,
|
| 138 |
+
2, 2,
|
| 139 |
+
1, 1,
|
| 140 |
+
))
|
| 141 |
+
|
| 142 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 143 |
+
1, 11, 7, minimum_channel_size,
|
| 144 |
+
8, 3, 3, minimum_channel_size,
|
| 145 |
+
1, 1,
|
| 146 |
+
2, 2,
|
| 147 |
+
1, 1,
|
| 148 |
+
))
|
| 149 |
+
|
| 150 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 151 |
+
1, 13, 11, minimum_channel_size,
|
| 152 |
+
8, 1, 1, minimum_channel_size,
|
| 153 |
+
1, 1,
|
| 154 |
+
2, 2,
|
| 155 |
+
1, 1,
|
| 156 |
+
))
|
| 157 |
+
|
| 158 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 159 |
+
1, 17, 19, minimum_channel_size,
|
| 160 |
+
16, 2, 2, minimum_channel_size,
|
| 161 |
+
1, 1,
|
| 162 |
+
2, 2,
|
| 163 |
+
1, 1,
|
| 164 |
+
))
|
| 165 |
+
|
| 166 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 167 |
+
1, 23, 5, minimum_channel_size,
|
| 168 |
+
16, 3, 3, minimum_channel_size,
|
| 169 |
+
1, 1,
|
| 170 |
+
2, 2,
|
| 171 |
+
1, 1,
|
| 172 |
+
))
|
| 173 |
+
|
| 174 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 175 |
+
1, 13, 17, 8,
|
| 176 |
+
24, 3, 3, 8,
|
| 177 |
+
0, 0,
|
| 178 |
+
2, 2,
|
| 179 |
+
1, 1,
|
| 180 |
+
))
|
| 181 |
+
|
| 182 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 183 |
+
1, 23, 21, 8,
|
| 184 |
+
24, 3, 3, 8,
|
| 185 |
+
1, 1,
|
| 186 |
+
3, 3,
|
| 187 |
+
1, 1,
|
| 188 |
+
))
|
| 189 |
+
|
| 190 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 191 |
+
1, 20, 24, 8,
|
| 192 |
+
40, 3, 3, 8,
|
| 193 |
+
3, 3,
|
| 194 |
+
3, 3,
|
| 195 |
+
1, 1,
|
| 196 |
+
))
|
| 197 |
+
|
| 198 |
+
##########################################
|
| 199 |
+
# Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1)
|
| 200 |
+
##########################################
|
| 201 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 202 |
+
1, 15, 19, 160,
|
| 203 |
+
224, 1, 1, 160,
|
| 204 |
+
0, 0,
|
| 205 |
+
1, 1,
|
| 206 |
+
1, 1,
|
| 207 |
+
))
|
| 208 |
+
|
| 209 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 210 |
+
1, 19, 37, 160,
|
| 211 |
+
224, 3, 3, 160,
|
| 212 |
+
1, 1,
|
| 213 |
+
2, 2,
|
| 214 |
+
1, 1,
|
| 215 |
+
))
|
| 216 |
+
|
| 217 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 218 |
+
1, 16, 16, 160,
|
| 219 |
+
224, 2, 3, 160,
|
| 220 |
+
1, 1,
|
| 221 |
+
1, 1,
|
| 222 |
+
1, 1,
|
| 223 |
+
))
|
| 224 |
+
|
| 225 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 226 |
+
1, 23, 21, 128,
|
| 227 |
+
224, 3, 3, 128,
|
| 228 |
+
1, 1,
|
| 229 |
+
1, 1,
|
| 230 |
+
1, 1,
|
| 231 |
+
))
|
| 232 |
+
|
| 233 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 234 |
+
1, 29, 37, 160,
|
| 235 |
+
224, 5, 5, 160,
|
| 236 |
+
2, 2,
|
| 237 |
+
1, 1,
|
| 238 |
+
1, 1,
|
| 239 |
+
))
|
| 240 |
+
|
| 241 |
+
##########################################
|
| 242 |
+
# C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 243 |
+
##########################################
|
| 244 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 245 |
+
1, 15, 19, 32 + minimum_channel_size,
|
| 246 |
+
96, 3, 3, 32 + minimum_channel_size,
|
| 247 |
+
1, 1,
|
| 248 |
+
1, 1,
|
| 249 |
+
1, 1,
|
| 250 |
+
))
|
| 251 |
+
|
| 252 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 253 |
+
1, 16, 24, 64 + minimum_channel_size,
|
| 254 |
+
96, 3, 3, 64 + minimum_channel_size,
|
| 255 |
+
1, 1,
|
| 256 |
+
1, 1,
|
| 257 |
+
1, 1,
|
| 258 |
+
))
|
| 259 |
+
|
| 260 |
+
##########################################
|
| 261 |
+
# Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2)
|
| 262 |
+
##########################################
|
| 263 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 264 |
+
1, 13, 16, 288,
|
| 265 |
+
160, 5, 5, 288,
|
| 266 |
+
2, 2,
|
| 267 |
+
2, 2,
|
| 268 |
+
1, 1,
|
| 269 |
+
))
|
| 270 |
+
|
| 271 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 272 |
+
1, 55, 51, 256,
|
| 273 |
+
512, 1, 1, 256,
|
| 274 |
+
0, 0,
|
| 275 |
+
2, 2,
|
| 276 |
+
1, 1,
|
| 277 |
+
))
|
| 278 |
+
|
| 279 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 280 |
+
1, 71, 80, 32,
|
| 281 |
+
64, 5, 5, 32,
|
| 282 |
+
2, 2,
|
| 283 |
+
2, 2,
|
| 284 |
+
1, 1,
|
| 285 |
+
))
|
| 286 |
+
|
| 287 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 288 |
+
1, 224, 224, 8,
|
| 289 |
+
64, 7, 7, 8,
|
| 290 |
+
3, 3,
|
| 291 |
+
2, 2,
|
| 292 |
+
1, 1,
|
| 293 |
+
))
|
| 294 |
+
|
| 295 |
+
##########################################
|
| 296 |
+
# Medium input size stride (3, 3), filter (3, 3), non-default padding
|
| 297 |
+
##########################################
|
| 298 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 299 |
+
1, 27, 23, 256,
|
| 300 |
+
512, 3, 3, 256,
|
| 301 |
+
0, 0,
|
| 302 |
+
3, 3,
|
| 303 |
+
1, 1,
|
| 304 |
+
))
|
| 305 |
+
|
| 306 |
+
##########################################
|
| 307 |
+
# Medium input size padding > stride, asymmetric filter, padding and striding
|
| 308 |
+
##########################################
|
| 309 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 310 |
+
1, 27, 31, 256,
|
| 311 |
+
512, 3, 3, 256,
|
| 312 |
+
5, 7,
|
| 313 |
+
3, 4,
|
| 314 |
+
1, 1,
|
| 315 |
+
))
|
| 316 |
+
|
| 317 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 318 |
+
1, 27, 35, 256,
|
| 319 |
+
512, 7, 5, 256,
|
| 320 |
+
11, 7,
|
| 321 |
+
3, 5,
|
| 322 |
+
1, 1,
|
| 323 |
+
))
|
| 324 |
+
|
| 325 |
+
##########################################
|
| 326 |
+
# Medium input size *mixed* stride (1, 2) and (2, 1),
|
| 327 |
+
# filter (3, 3), default padding
|
| 328 |
+
##########################################
|
| 329 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 330 |
+
1, 27, 27, 256,
|
| 331 |
+
512, 3, 3, 256,
|
| 332 |
+
1, 1,
|
| 333 |
+
1, 2,
|
| 334 |
+
1, 1,
|
| 335 |
+
))
|
| 336 |
+
|
| 337 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 338 |
+
1, 27, 27, 256,
|
| 339 |
+
512, 3, 3, 256,
|
| 340 |
+
1, 1,
|
| 341 |
+
2, 1,
|
| 342 |
+
1, 1,
|
| 343 |
+
))
|
| 344 |
+
|
| 345 |
+
######################################/
|
| 346 |
+
# Additional input size
|
| 347 |
+
######################################/
|
| 348 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 349 |
+
3, 28, 28, 256,
|
| 350 |
+
256, 2, 2, 256,
|
| 351 |
+
0, 0,
|
| 352 |
+
2, 2,
|
| 353 |
+
1, 1,
|
| 354 |
+
))
|
| 355 |
+
|
| 356 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 357 |
+
1, 32, 32, 16,
|
| 358 |
+
32, 3, 3, 16,
|
| 359 |
+
1, 1,
|
| 360 |
+
6, 2,
|
| 361 |
+
1, 1,
|
| 362 |
+
))
|
| 363 |
+
|
| 364 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 365 |
+
32, 24, 32, 32,
|
| 366 |
+
32, 1, 2, 32,
|
| 367 |
+
0, 0,
|
| 368 |
+
1, 1,
|
| 369 |
+
1, 1,
|
| 370 |
+
))
|
| 371 |
+
|
| 372 |
+
conv2d_default_sizes.append(Conv2DProblemSize(
|
| 373 |
+
4, 2, 3, 256,
|
| 374 |
+
328, 3, 5, 256,
|
| 375 |
+
1, 1,
|
| 376 |
+
1, 1,
|
| 377 |
+
1, 1,
|
| 378 |
+
))
|
| 379 |
+
return conv2d_default_sizes
|
| 380 |
+
|
| 381 |
+
# Add a few large and rigorous convolution problem sizes
|
| 382 |
+
def initialize_conv2d_rigorous_sizes(self, minimum_channel_size):
|
| 383 |
+
sizes = []
|
| 384 |
+
if False:
|
| 385 |
+
sizes.append(Conv2DProblemSize.from_sizes(
|
| 386 |
+
(1, 124, 224, 2 * minimum_channel_size),
|
| 387 |
+
(24, 7, 7, 2 * minimum_channel_size),
|
| 388 |
+
))
|
| 389 |
+
|
| 390 |
+
sizes.append(Conv2DProblemSize.from_sizes(
|
| 391 |
+
(1, 233, 35, minimum_channel_size),
|
| 392 |
+
(24, 7, 5, minimum_channel_size),
|
| 393 |
+
))
|
| 394 |
+
return sizes
|
| 395 |
+
|
| 396 |
+
# Add resent50 layers to unit testing sizes
|
| 397 |
+
def initialize_conv2d_resnet50_sizes(self, batch_size):
|
| 398 |
+
conv2d_problem_vector = []
|
| 399 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 400 |
+
batch_size, 56, 56, 64,
|
| 401 |
+
256, 1, 1, 64,
|
| 402 |
+
0, 0,
|
| 403 |
+
1, 1,
|
| 404 |
+
1, 1,
|
| 405 |
+
))
|
| 406 |
+
|
| 407 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 408 |
+
batch_size, 56, 56, 64,
|
| 409 |
+
64, 1, 1, 64,
|
| 410 |
+
0, 0,
|
| 411 |
+
1, 1,
|
| 412 |
+
1, 1,
|
| 413 |
+
))
|
| 414 |
+
|
| 415 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 416 |
+
batch_size, 56, 56, 64,
|
| 417 |
+
64, 3, 3, 64,
|
| 418 |
+
1, 1,
|
| 419 |
+
1, 1,
|
| 420 |
+
1, 1,
|
| 421 |
+
))
|
| 422 |
+
|
| 423 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 424 |
+
batch_size, 56, 56, 256,
|
| 425 |
+
64, 1, 1, 256,
|
| 426 |
+
0, 0,
|
| 427 |
+
1, 1,
|
| 428 |
+
1, 1,
|
| 429 |
+
))
|
| 430 |
+
|
| 431 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 432 |
+
batch_size, 56, 56, 256,
|
| 433 |
+
512, 1, 1, 256,
|
| 434 |
+
0, 0,
|
| 435 |
+
2, 2,
|
| 436 |
+
1, 1,
|
| 437 |
+
))
|
| 438 |
+
|
| 439 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 440 |
+
batch_size, 56, 56, 256,
|
| 441 |
+
128, 1, 1, 256,
|
| 442 |
+
0, 0,
|
| 443 |
+
2, 2,
|
| 444 |
+
1, 1,
|
| 445 |
+
))
|
| 446 |
+
|
| 447 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 448 |
+
batch_size, 28, 28, 128,
|
| 449 |
+
128, 3, 3, 128,
|
| 450 |
+
1, 1,
|
| 451 |
+
1, 1,
|
| 452 |
+
1, 1,
|
| 453 |
+
))
|
| 454 |
+
|
| 455 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 456 |
+
batch_size, 28, 28, 128,
|
| 457 |
+
512, 1, 1, 128,
|
| 458 |
+
0, 0,
|
| 459 |
+
1, 1,
|
| 460 |
+
1, 1,
|
| 461 |
+
))
|
| 462 |
+
|
| 463 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 464 |
+
batch_size, 28, 28, 512,
|
| 465 |
+
128, 1, 1, 512,
|
| 466 |
+
0, 0,
|
| 467 |
+
1, 1,
|
| 468 |
+
1, 1,
|
| 469 |
+
))
|
| 470 |
+
|
| 471 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 472 |
+
batch_size, 28, 28, 512,
|
| 473 |
+
1024, 1, 1, 512,
|
| 474 |
+
0, 0,
|
| 475 |
+
2, 2,
|
| 476 |
+
1, 1,
|
| 477 |
+
))
|
| 478 |
+
|
| 479 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 480 |
+
batch_size, 28, 28, 512,
|
| 481 |
+
256, 1, 1, 512,
|
| 482 |
+
0, 0,
|
| 483 |
+
2, 2,
|
| 484 |
+
1, 1,
|
| 485 |
+
))
|
| 486 |
+
|
| 487 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 488 |
+
batch_size, 14, 14, 256,
|
| 489 |
+
256, 3, 3, 256,
|
| 490 |
+
1, 1,
|
| 491 |
+
1, 1,
|
| 492 |
+
1, 1,
|
| 493 |
+
))
|
| 494 |
+
|
| 495 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 496 |
+
batch_size, 14, 14, 256,
|
| 497 |
+
1024, 1, 1, 256,
|
| 498 |
+
0, 0,
|
| 499 |
+
1, 1,
|
| 500 |
+
1, 1,
|
| 501 |
+
))
|
| 502 |
+
|
| 503 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 504 |
+
batch_size, 14, 14, 1024,
|
| 505 |
+
256, 1, 1, 1024,
|
| 506 |
+
0, 0,
|
| 507 |
+
1, 1,
|
| 508 |
+
1, 1,
|
| 509 |
+
))
|
| 510 |
+
|
| 511 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 512 |
+
batch_size, 14, 14, 1024,
|
| 513 |
+
2048, 1, 1, 1024,
|
| 514 |
+
0, 0,
|
| 515 |
+
2, 2,
|
| 516 |
+
1, 1,
|
| 517 |
+
))
|
| 518 |
+
|
| 519 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 520 |
+
batch_size, 14, 14, 1024,
|
| 521 |
+
512, 1, 1, 1024,
|
| 522 |
+
0, 0,
|
| 523 |
+
2, 2,
|
| 524 |
+
1, 1,
|
| 525 |
+
))
|
| 526 |
+
|
| 527 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 528 |
+
batch_size, 7, 7, 512,
|
| 529 |
+
512, 3, 3, 512,
|
| 530 |
+
1, 1,
|
| 531 |
+
1, 1,
|
| 532 |
+
1, 1,
|
| 533 |
+
))
|
| 534 |
+
|
| 535 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 536 |
+
batch_size, 7, 7, 512,
|
| 537 |
+
2048, 1, 1, 512,
|
| 538 |
+
0, 0,
|
| 539 |
+
1, 1,
|
| 540 |
+
1, 1,
|
| 541 |
+
))
|
| 542 |
+
|
| 543 |
+
conv2d_problem_vector.append(Conv2DProblemSize(
|
| 544 |
+
batch_size, 7, 7, 2048,
|
| 545 |
+
512, 1, 1, 2048,
|
| 546 |
+
0, 0,
|
| 547 |
+
1, 1,
|
| 548 |
+
1, 1,
|
| 549 |
+
))
|
| 550 |
+
|
| 551 |
+
return conv2d_problem_vector
|
| 552 |
+
|
| 553 |
+
def initialize_conv2d_grouped_sizes(self):
|
| 554 |
+
threadblock_n = 128
|
| 555 |
+
threadblock_k = 32
|
| 556 |
+
|
| 557 |
+
sizes = []
|
| 558 |
+
##########################################
|
| 559 |
+
# One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0
|
| 560 |
+
# One CTA calculates a single group
|
| 561 |
+
##########################################
|
| 562 |
+
for cta_per_group_k in range(1, 4):
|
| 563 |
+
for groups in range(2, 5):
|
| 564 |
+
conv_k = cta_per_group_k * threadblock_n * groups
|
| 565 |
+
sizes.append(Conv2DProblemSize(
|
| 566 |
+
1, 8, 8, threadblock_k * 2 * groups,
|
| 567 |
+
conv_k, 3, 3, threadblock_k * 2,
|
| 568 |
+
1, 1,
|
| 569 |
+
1, 1,
|
| 570 |
+
1, 1,
|
| 571 |
+
ConvMode.CrossCorrelation,
|
| 572 |
+
1,
|
| 573 |
+
groups
|
| 574 |
+
))
|
| 575 |
+
|
| 576 |
+
# Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K
|
| 577 |
+
sizes.append(Conv2DProblemSize(
|
| 578 |
+
1, 8, 8, threadblock_k,
|
| 579 |
+
threadblock_n * 2, 3, 3, threadblock_k // 2,
|
| 580 |
+
1, 1,
|
| 581 |
+
1, 1,
|
| 582 |
+
1, 1,
|
| 583 |
+
ConvMode.CrossCorrelation,
|
| 584 |
+
1,
|
| 585 |
+
2
|
| 586 |
+
))
|
| 587 |
+
|
| 588 |
+
sizes.append(Conv2DProblemSize(
|
| 589 |
+
1, 56, 56, 696,
|
| 590 |
+
768, 3, 3, 232,
|
| 591 |
+
1, 1,
|
| 592 |
+
2, 2,
|
| 593 |
+
1, 1,
|
| 594 |
+
ConvMode.CrossCorrelation,
|
| 595 |
+
1,
|
| 596 |
+
3
|
| 597 |
+
))
|
| 598 |
+
sizes.append(Conv2DProblemSize(
|
| 599 |
+
1, 14, 14, 1392,
|
| 600 |
+
1536, 3, 3, 232,
|
| 601 |
+
1, 1,
|
| 602 |
+
1, 1,
|
| 603 |
+
1, 1,
|
| 604 |
+
ConvMode.CrossCorrelation,
|
| 605 |
+
1,
|
| 606 |
+
3
|
| 607 |
+
))
|
| 608 |
+
|
| 609 |
+
##########################################
|
| 610 |
+
# One CTA calculate multiple groups: CTA::N % k_per_group = 0
|
| 611 |
+
##########################################
|
| 612 |
+
|
| 613 |
+
# 2 groups per CTA
|
| 614 |
+
sizes.append(Conv2DProblemSize(
|
| 615 |
+
1, 8, 8, threadblock_k * 4,
|
| 616 |
+
threadblock_n, 3, 3, threadblock_k * 2,
|
| 617 |
+
1, 1,
|
| 618 |
+
1, 1,
|
| 619 |
+
1, 1,
|
| 620 |
+
ConvMode.CrossCorrelation,
|
| 621 |
+
1,
|
| 622 |
+
2
|
| 623 |
+
))
|
| 624 |
+
|
| 625 |
+
# 2 groups per CTA and partial gemm_k
|
| 626 |
+
sizes.append(Conv2DProblemSize(
|
| 627 |
+
1, 8, 8, threadblock_k,
|
| 628 |
+
threadblock_n, 3, 3, threadblock_k // 2,
|
| 629 |
+
1, 1,
|
| 630 |
+
1, 1,
|
| 631 |
+
1, 1,
|
| 632 |
+
ConvMode.CrossCorrelation,
|
| 633 |
+
1,
|
| 634 |
+
2
|
| 635 |
+
))
|
| 636 |
+
|
| 637 |
+
# 4 groups per CTA
|
| 638 |
+
sizes.append(Conv2DProblemSize(
|
| 639 |
+
1, 8, 8, threadblock_k * 8,
|
| 640 |
+
threadblock_n // 2, 3, 3, threadblock_k * 2,
|
| 641 |
+
1, 1,
|
| 642 |
+
1, 1,
|
| 643 |
+
1, 1,
|
| 644 |
+
ConvMode.CrossCorrelation,
|
| 645 |
+
1,
|
| 646 |
+
4
|
| 647 |
+
))
|
| 648 |
+
|
| 649 |
+
# 4 groups per CTA and partial gemm_k
|
| 650 |
+
sizes.append(Conv2DProblemSize(
|
| 651 |
+
1, 8, 8, threadblock_k * 2,
|
| 652 |
+
threadblock_n // 2, 3, 3, threadblock_k // 2,
|
| 653 |
+
1, 1,
|
| 654 |
+
1, 1,
|
| 655 |
+
1, 1,
|
| 656 |
+
ConvMode.CrossCorrelation,
|
| 657 |
+
1,
|
| 658 |
+
4
|
| 659 |
+
))
|
| 660 |
+
|
| 661 |
+
return sizes
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for Conv2d opreations on SM80
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 42 |
+
|
| 43 |
+
from conv2d_test_utils import *
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
cc = 80
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is invalid for SM80 tests.')
|
| 51 |
+
class Conv2dSm80(unittest.TestCase):
|
| 52 |
+
"""
|
| 53 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 54 |
+
"""
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
conv_problems = get_conv_problems()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Tests for optimized & analytic
|
| 62 |
+
for conv_kind in ["fprop", "wgrad", "dgrad"]:
|
| 63 |
+
# F16, simt
|
| 64 |
+
add_test(
|
| 65 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 66 |
+
opclass="simt", threadblock_shape=[128, 128, 8],
|
| 67 |
+
warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1])
|
| 68 |
+
# F16, tensor op
|
| 69 |
+
add_test(
|
| 70 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 71 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 72 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16])
|
| 73 |
+
# F16, tensor op, analytic iterator
|
| 74 |
+
add_test(
|
| 75 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16,
|
| 76 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 77 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic")
|
| 78 |
+
# F16, tensor op, f32 output
|
| 79 |
+
add_test(
|
| 80 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
|
| 81 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 82 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16])
|
| 83 |
+
# F16, tensor op, different tile description
|
| 84 |
+
add_test(
|
| 85 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 86 |
+
opclass="tensor_op", threadblock_shape=[128, 64, 32],
|
| 87 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8])
|
| 88 |
+
# F32, simt
|
| 89 |
+
add_test(
|
| 90 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
|
| 91 |
+
opclass="simt", threadblock_shape=[128, 128, 8],
|
| 92 |
+
warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1])
|
| 93 |
+
# Tf32, tensorop
|
| 94 |
+
add_test(
|
| 95 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
|
| 96 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 16],
|
| 97 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]
|
| 98 |
+
)
|
| 99 |
+
# Split-K
|
| 100 |
+
add_test(
|
| 101 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 102 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 103 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial",
|
| 104 |
+
split_k_slices=2)
|
| 105 |
+
add_test(
|
| 106 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 107 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 108 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel",
|
| 109 |
+
split_k_slices=5)
|
| 110 |
+
# Swizzling functor
|
| 111 |
+
add_test(
|
| 112 |
+
Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 113 |
+
opclass="tensor_op", threadblock_shape=[128, 64, 32],
|
| 114 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4)
|
| 115 |
+
|
| 116 |
+
# Tests for few channels and fixed channels
|
| 117 |
+
# F16, tensor op, few channels
|
| 118 |
+
for c, tb, stage, inst in zip([2, 1],
|
| 119 |
+
[[128, 128, 64], [128, 128, 32]],
|
| 120 |
+
[3, 2],
|
| 121 |
+
[[16, 8, 16], [16, 8, 8]]):
|
| 122 |
+
add_test(
|
| 123 |
+
Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 124 |
+
opclass="tensor_op", threadblock_shape=tb,
|
| 125 |
+
warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels"
|
| 126 |
+
)
|
| 127 |
+
# F16, tensor op, fixed channels
|
| 128 |
+
for c in [8, 4, 2]:
|
| 129 |
+
add_test(
|
| 130 |
+
Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 131 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 132 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Test activations
|
| 136 |
+
for activation in ["relu", "leaky_relu"]:
|
| 137 |
+
for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]):
|
| 138 |
+
add_test(
|
| 139 |
+
Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
|
| 140 |
+
opclass="tensor_op", threadblock_shape=[128, 128, 64],
|
| 141 |
+
warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode,
|
| 142 |
+
split_k_slices=split_k_slices, activation=activation)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == '__main__':
|
| 146 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utility functions for Conv2d tests.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_library import SubstituteTemplate
|
| 38 |
+
import torch
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_library import (
|
| 42 |
+
ConvKind,
|
| 43 |
+
ConvMode,
|
| 44 |
+
DataType,
|
| 45 |
+
DataTypeNames,
|
| 46 |
+
EpilogueScheduleSuffixes,
|
| 47 |
+
KernelScheduleSuffixes,
|
| 48 |
+
LayoutType,
|
| 49 |
+
OpcodeClassNames,
|
| 50 |
+
ShortDataTypeNames,
|
| 51 |
+
ShortLayoutTypeNames,
|
| 52 |
+
SplitKMode,
|
| 53 |
+
)
|
| 54 |
+
from cutlass_cppgen.shape import Conv2DProblemSize
|
| 55 |
+
from cutlass_cppgen.utils.datatypes import numpy_type, torch_type
|
| 56 |
+
|
| 57 |
+
from conv2d_problem_sizes import TestbedConv2dProblemSizes
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_name_conv2d(
|
| 61 |
+
arch,
|
| 62 |
+
conv_kind,
|
| 63 |
+
element,
|
| 64 |
+
element_accumulator,
|
| 65 |
+
element_output,
|
| 66 |
+
opclass,
|
| 67 |
+
threadblock_shape,
|
| 68 |
+
warp_count,
|
| 69 |
+
instruction_shape,
|
| 70 |
+
stages,
|
| 71 |
+
iterator_algorithm,
|
| 72 |
+
swizzle,
|
| 73 |
+
split_k_mode,
|
| 74 |
+
split_k_slices,
|
| 75 |
+
activation
|
| 76 |
+
):
|
| 77 |
+
"""
|
| 78 |
+
Generates a procedural name for a test case for conv2d
|
| 79 |
+
|
| 80 |
+
:param arch: compute capability of kernel being generated
|
| 81 |
+
:type arch: int
|
| 82 |
+
:param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad)
|
| 83 |
+
:type conv_kind: str
|
| 84 |
+
:param iterator_algorithm: the iterator algorithm applied
|
| 85 |
+
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
|
| 86 |
+
:param element_a: data type of operand A
|
| 87 |
+
:param element_b: data type of operand B
|
| 88 |
+
:param element_c: data type of operand C
|
| 89 |
+
:param element_accumulator: data type used in accumulation
|
| 90 |
+
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
|
| 91 |
+
:type opclass: cutlass_cppgen.OpcodeClass
|
| 92 |
+
:param threadblock_shape: indexable container of dimensions of threadblock tiles
|
| 93 |
+
:param stages: number of pipeline stages to use in the kernel
|
| 94 |
+
:type stages: int
|
| 95 |
+
:param stride_support: stride support of dgrad
|
| 96 |
+
:param alignment: int
|
| 97 |
+
:type alignment: int
|
| 98 |
+
|
| 99 |
+
:return: str
|
| 100 |
+
"""
|
| 101 |
+
if iterator_algorithm is None:
|
| 102 |
+
iterator_algorithm = "AUTO"
|
| 103 |
+
if swizzle is None:
|
| 104 |
+
swizzle = 1
|
| 105 |
+
name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}"
|
| 106 |
+
|
| 107 |
+
return SubstituteTemplate(
|
| 108 |
+
name_format,
|
| 109 |
+
{
|
| 110 |
+
"arch": str(arch),
|
| 111 |
+
"conv_kind": conv_kind,
|
| 112 |
+
"iter_alg": iterator_algorithm,
|
| 113 |
+
"eA": DataTypeNames[element],
|
| 114 |
+
"eB": DataTypeNames[element],
|
| 115 |
+
"eC": DataTypeNames[element_output],
|
| 116 |
+
"opclass": opclass,
|
| 117 |
+
"acc": DataTypeNames[element_accumulator],
|
| 118 |
+
"tbM": str(threadblock_shape[0]),
|
| 119 |
+
"tbN": str(threadblock_shape[1]),
|
| 120 |
+
"tbK": str(threadblock_shape[2]),
|
| 121 |
+
"wM": str(threadblock_shape[0] // warp_count[0]),
|
| 122 |
+
"wN": str(threadblock_shape[1] // warp_count[1]),
|
| 123 |
+
"wK": str(threadblock_shape[2] // warp_count[2]),
|
| 124 |
+
"IM": str(instruction_shape[0]),
|
| 125 |
+
"IN": str(instruction_shape[1]),
|
| 126 |
+
"IK": str(instruction_shape[2]),
|
| 127 |
+
"stages": str(stages),
|
| 128 |
+
"swizzle": str(swizzle),
|
| 129 |
+
"split_k_mode": split_k_mode,
|
| 130 |
+
"split_k_slices": str(split_k_slices),
|
| 131 |
+
"activation": activation
|
| 132 |
+
}
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def conv2d_few_channel_problemsizes(channels):
|
| 137 |
+
problem_sizes = [
|
| 138 |
+
Conv2DProblemSize(
|
| 139 |
+
1, 8, 8, channels,
|
| 140 |
+
16, 3, 3, channels,
|
| 141 |
+
1, 1,
|
| 142 |
+
2, 2,
|
| 143 |
+
1, 1,
|
| 144 |
+
ConvMode.CrossCorrelation,
|
| 145 |
+
1, 1
|
| 146 |
+
),
|
| 147 |
+
Conv2DProblemSize(
|
| 148 |
+
1, 16, 16, channels,
|
| 149 |
+
16, 3, 3, channels,
|
| 150 |
+
1, 1,
|
| 151 |
+
2, 2,
|
| 152 |
+
1, 1,
|
| 153 |
+
ConvMode.CrossCorrelation,
|
| 154 |
+
1, 1
|
| 155 |
+
),
|
| 156 |
+
Conv2DProblemSize(
|
| 157 |
+
1, 16, 16, channels,
|
| 158 |
+
16, 7, 7, channels,
|
| 159 |
+
1, 1,
|
| 160 |
+
1, 1,
|
| 161 |
+
1, 1,
|
| 162 |
+
ConvMode.CrossCorrelation,
|
| 163 |
+
1, 1
|
| 164 |
+
),
|
| 165 |
+
Conv2DProblemSize(
|
| 166 |
+
1, 224, 224, channels,
|
| 167 |
+
32, 7, 7, channels,
|
| 168 |
+
1, 1,
|
| 169 |
+
1, 1,
|
| 170 |
+
1, 1,
|
| 171 |
+
ConvMode.CrossCorrelation,
|
| 172 |
+
1, 1
|
| 173 |
+
),
|
| 174 |
+
Conv2DProblemSize(
|
| 175 |
+
1, 224, 224, channels,
|
| 176 |
+
64, 7, 7, channels,
|
| 177 |
+
1, 1,
|
| 178 |
+
2, 2,
|
| 179 |
+
1, 1,
|
| 180 |
+
ConvMode.CrossCorrelation,
|
| 181 |
+
1, 1
|
| 182 |
+
),
|
| 183 |
+
Conv2DProblemSize(
|
| 184 |
+
1, 224, 224, channels,
|
| 185 |
+
64, 5, 5, channels,
|
| 186 |
+
1, 1,
|
| 187 |
+
1, 1,
|
| 188 |
+
1, 1,
|
| 189 |
+
ConvMode.CrossCorrelation,
|
| 190 |
+
1, 1
|
| 191 |
+
),
|
| 192 |
+
Conv2DProblemSize(
|
| 193 |
+
1, 224, 224, channels,
|
| 194 |
+
64, 5, 5, channels,
|
| 195 |
+
1, 1,
|
| 196 |
+
2, 2,
|
| 197 |
+
1, 1,
|
| 198 |
+
ConvMode.CrossCorrelation,
|
| 199 |
+
1, 1
|
| 200 |
+
),
|
| 201 |
+
]
|
| 202 |
+
|
| 203 |
+
return problem_sizes
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def validate_problem_size(ps, conv_kind, split_k_slices):
|
| 207 |
+
P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1
|
| 208 |
+
Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1
|
| 209 |
+
if P != ps.P or Q != ps.Q:
|
| 210 |
+
return False
|
| 211 |
+
|
| 212 |
+
# Split-K (serial or parallel) is not supported for strided dgrad
|
| 213 |
+
if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1):
|
| 214 |
+
return False
|
| 215 |
+
return True
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class Conv2dLauncherFrontend:
|
| 219 |
+
def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"):
|
| 220 |
+
self.operation = plan
|
| 221 |
+
self.conv_kind = plan.conv_kind
|
| 222 |
+
self.seed = seed
|
| 223 |
+
self.backend = backend
|
| 224 |
+
|
| 225 |
+
self.dtype_A = plan._element_a
|
| 226 |
+
self.dtype_B = plan._element_b
|
| 227 |
+
self.dtype_C = plan._element_c
|
| 228 |
+
self.dtype_acc = plan._element_accumulator
|
| 229 |
+
self.layout_A = LayoutType.TensorNHWC
|
| 230 |
+
self.layout_B = LayoutType.TensorNHWC
|
| 231 |
+
self.layout_C = LayoutType.TensorNHWC
|
| 232 |
+
self.layout_D = LayoutType.TensorNHWC
|
| 233 |
+
|
| 234 |
+
self.element_compute = DataType.f32
|
| 235 |
+
|
| 236 |
+
if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]:
|
| 237 |
+
self.rand_max = 1
|
| 238 |
+
else:
|
| 239 |
+
self.rand_max = 4
|
| 240 |
+
self.activation = plan.activation
|
| 241 |
+
|
| 242 |
+
def uniform_init(self, size, dtype):
|
| 243 |
+
tensor = torch.ceil(
|
| 244 |
+
torch.empty(size=size, dtype=torch_type(dtype), device="cuda").uniform_(-self.rand_max - 0.5, self.rand_max - 0.5)
|
| 245 |
+
).to(memory_format=torch.channels_last)
|
| 246 |
+
return tensor
|
| 247 |
+
|
| 248 |
+
def reference(self, ps, A, B, C, alpha, beta, activation):
|
| 249 |
+
if self.conv_kind == ConvKind.Fprop:
|
| 250 |
+
torch_result = alpha * torch.ops.aten.conv2d(
|
| 251 |
+
A,
|
| 252 |
+
B,
|
| 253 |
+
stride=(ps.stride_h, ps.stride_w),
|
| 254 |
+
padding=(ps.pad_h, ps.pad_w),
|
| 255 |
+
dilation=(ps.dilation_h, ps.dilation_w)
|
| 256 |
+
) + beta * C
|
| 257 |
+
elif self.conv_kind == ConvKind.Dgrad:
|
| 258 |
+
torch_result = alpha * torch.nn.grad.conv2d_input(
|
| 259 |
+
(ps.N, ps.C, ps.H, ps.W),
|
| 260 |
+
B,
|
| 261 |
+
A,
|
| 262 |
+
padding=(ps.pad_h, ps.pad_w),
|
| 263 |
+
stride=(ps.stride_h, ps.stride_w)
|
| 264 |
+
) + beta * C
|
| 265 |
+
elif self.conv_kind == ConvKind.Wgrad:
|
| 266 |
+
torch_result = alpha * torch.nn.grad.conv2d_weight(
|
| 267 |
+
B,
|
| 268 |
+
(ps.K, ps.C, ps.R, ps.S),
|
| 269 |
+
A,
|
| 270 |
+
padding=(ps.pad_h, ps.pad_w),
|
| 271 |
+
stride=(ps.stride_h, ps.stride_w)
|
| 272 |
+
) + beta * C
|
| 273 |
+
else:
|
| 274 |
+
raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.")
|
| 275 |
+
|
| 276 |
+
if activation == cutlass_cppgen.backend.epilogue.relu:
|
| 277 |
+
torch_result = torch.nn.functional.relu(torch_result)
|
| 278 |
+
elif activation == cutlass_cppgen.backend.epilogue.leaky_relu:
|
| 279 |
+
torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5)
|
| 280 |
+
return torch_result
|
| 281 |
+
|
| 282 |
+
def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0):
|
| 283 |
+
if self.conv_kind == ConvKind.Fprop:
|
| 284 |
+
tensor_A_size = (ps.N, ps.C, ps.H, ps.W)
|
| 285 |
+
tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
|
| 286 |
+
tensor_C_size = (ps.N, ps.K, ps.P, ps.Q)
|
| 287 |
+
elif self.conv_kind == ConvKind.Dgrad:
|
| 288 |
+
tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
|
| 289 |
+
tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
|
| 290 |
+
tensor_C_size = (ps.N, ps.C, ps.H, ps.W)
|
| 291 |
+
elif self.conv_kind == ConvKind.Wgrad:
|
| 292 |
+
tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
|
| 293 |
+
tensor_B_size = (ps.N, ps.C, ps.H, ps.W)
|
| 294 |
+
tensor_C_size = (ps.K, ps.C, ps.R, ps.S)
|
| 295 |
+
else:
|
| 296 |
+
raise Exception(f"Conv kind {self.conv_kind} is not supported")
|
| 297 |
+
|
| 298 |
+
torch.manual_seed(self.seed)
|
| 299 |
+
|
| 300 |
+
tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A)
|
| 301 |
+
tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B)
|
| 302 |
+
tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C)
|
| 303 |
+
tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last)
|
| 304 |
+
args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
|
| 305 |
+
stride=(ps.stride_h, ps.stride_w),
|
| 306 |
+
padding=(ps.pad_h, ps.pad_w),
|
| 307 |
+
dilation=(ps.dilation_h, ps.dilation_w),
|
| 308 |
+
alpha=alpha, beta=beta,
|
| 309 |
+
split_k=(split_k_mode, split_k_slices))
|
| 310 |
+
|
| 311 |
+
args.sync()
|
| 312 |
+
|
| 313 |
+
tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation)
|
| 314 |
+
|
| 315 |
+
torch.cuda.synchronize()
|
| 316 |
+
passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06)
|
| 317 |
+
|
| 318 |
+
return passed
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def add_test(
|
| 322 |
+
cls,
|
| 323 |
+
cc,
|
| 324 |
+
conv_kind,
|
| 325 |
+
problem_sizes,
|
| 326 |
+
element,
|
| 327 |
+
element_accumulator,
|
| 328 |
+
element_output,
|
| 329 |
+
opclass,
|
| 330 |
+
threadblock_shape,
|
| 331 |
+
warp_count,
|
| 332 |
+
instruction_shape,
|
| 333 |
+
stages,
|
| 334 |
+
iterator_algorithm=None,
|
| 335 |
+
swizzle=None,
|
| 336 |
+
split_k_mode="serial",
|
| 337 |
+
split_k_slices=1,
|
| 338 |
+
activation = "identity"
|
| 339 |
+
):
|
| 340 |
+
"""Create a test-running function with the given specification"""
|
| 341 |
+
test_name = get_name_conv2d(
|
| 342 |
+
cc, conv_kind, element, element_accumulator,
|
| 343 |
+
element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages,
|
| 344 |
+
iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation)
|
| 345 |
+
|
| 346 |
+
def run(self):
|
| 347 |
+
# Create the plan
|
| 348 |
+
plan = cutlass_cppgen.Conv2d(
|
| 349 |
+
kind=conv_kind,
|
| 350 |
+
element=element,
|
| 351 |
+
element_accumulator=element_accumulator,
|
| 352 |
+
element_C=element_output,
|
| 353 |
+
element_D=element_output
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Set the opclass
|
| 357 |
+
plan.opclass = opclass
|
| 358 |
+
# Set the tile description
|
| 359 |
+
td = {
|
| 360 |
+
"threadblock_shape": threadblock_shape,
|
| 361 |
+
"warp_count": warp_count,
|
| 362 |
+
"stages": stages,
|
| 363 |
+
"instruction_shape": instruction_shape,
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
plan.tile_description = td
|
| 367 |
+
# Set iterator algorithm
|
| 368 |
+
if iterator_algorithm is not None:
|
| 369 |
+
plan.iterator_algorithm = iterator_algorithm
|
| 370 |
+
# Set swizzling functor
|
| 371 |
+
if swizzle is not None:
|
| 372 |
+
plan.swizzling_stride = swizzle
|
| 373 |
+
|
| 374 |
+
if activation != "identity":
|
| 375 |
+
if activation == "leaky_relu":
|
| 376 |
+
plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5)
|
| 377 |
+
else:
|
| 378 |
+
plan.activation = getattr(cutlass_cppgen.epilogue, activation)
|
| 379 |
+
|
| 380 |
+
conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch")
|
| 381 |
+
|
| 382 |
+
for ps in problem_sizes:
|
| 383 |
+
if not validate_problem_size(ps, conv_kind, split_k_slices):
|
| 384 |
+
continue
|
| 385 |
+
|
| 386 |
+
self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0))
|
| 387 |
+
|
| 388 |
+
setattr(cls, test_name, run)
|
| 389 |
+
|
| 390 |
+
return run
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def get_conv_problems():
|
| 394 |
+
# 64: minimum channel size
|
| 395 |
+
conv_problems = TestbedConv2dProblemSizes(64).all
|
| 396 |
+
|
| 397 |
+
# Insert alignment 4 & 2 tests
|
| 398 |
+
conv_problems += [
|
| 399 |
+
Conv2DProblemSize(
|
| 400 |
+
1, 4, 4, 12,
|
| 401 |
+
8, 3, 3, 12,
|
| 402 |
+
0, 0,
|
| 403 |
+
3, 3,
|
| 404 |
+
1, 1,
|
| 405 |
+
ConvMode.CrossCorrelation,
|
| 406 |
+
1, 1
|
| 407 |
+
),
|
| 408 |
+
Conv2DProblemSize(
|
| 409 |
+
1, 4, 4, 14,
|
| 410 |
+
8, 3, 3, 14,
|
| 411 |
+
0, 0,
|
| 412 |
+
3, 3,
|
| 413 |
+
1, 1,
|
| 414 |
+
ConvMode.CrossCorrelation,
|
| 415 |
+
1, 1
|
| 416 |
+
),
|
| 417 |
+
Conv2DProblemSize(
|
| 418 |
+
1, 23, 56, 98,
|
| 419 |
+
128, 3, 3, 98,
|
| 420 |
+
4, 5,
|
| 421 |
+
3, 3,
|
| 422 |
+
1, 1,
|
| 423 |
+
ConvMode.CrossCorrelation,
|
| 424 |
+
1, 1
|
| 425 |
+
),
|
| 426 |
+
]
|
| 427 |
+
|
| 428 |
+
return conv_problems
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import pathlib
|
| 34 |
+
import unittest
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
|
| 38 |
+
loader = unittest.TestLoader()
|
| 39 |
+
script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
|
| 40 |
+
tests = loader.discover(script_dir, 'conv2d_*.py')
|
| 41 |
+
testRunner = unittest.runner.TextTestRunner()
|
| 42 |
+
results = testRunner.run(tests)
|
| 43 |
+
if not results.wasSuccessful():
|
| 44 |
+
raise Exception('Test cases failed')
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Tests emitting a CUTLASS kernel to a PyTorch CUDA extension
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import random
|
| 38 |
+
import tempfile
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
from cutlass_library import ConvMode
|
| 42 |
+
|
| 43 |
+
import cutlass_cppgen
|
| 44 |
+
|
| 45 |
+
if cutlass_cppgen.utils.datatypes.is_torch_available():
|
| 46 |
+
import torch
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _initialize(dtype, M: int, N: int, K: int):
|
| 50 |
+
"""
|
| 51 |
+
Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
|
| 52 |
+
|
| 53 |
+
:param dtype: data type of tensors
|
| 54 |
+
:param M: M dimension of GEMM problem
|
| 55 |
+
:type M: int
|
| 56 |
+
:param N: N dimension of GEMM problem
|
| 57 |
+
:type N: int
|
| 58 |
+
:param K: N dimension of GEMM problem
|
| 59 |
+
:type K: int
|
| 60 |
+
|
| 61 |
+
:return: initialized tensors A, B, C, and D
|
| 62 |
+
:rtype: list
|
| 63 |
+
"""
|
| 64 |
+
sizes = [(M, K), (K, N), (M, N), (M, N)]
|
| 65 |
+
return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _generate_problems(dtype, num):
|
| 69 |
+
"""
|
| 70 |
+
Utility function to generate `num` GEMMs of random sizes
|
| 71 |
+
|
| 72 |
+
:param dtype: data type of tensors
|
| 73 |
+
:param num: number of GEMMs to generate
|
| 74 |
+
:type num: int
|
| 75 |
+
|
| 76 |
+
:return: lists of A, B, C, and D tensors
|
| 77 |
+
:rtype: list
|
| 78 |
+
"""
|
| 79 |
+
valid_sizes = [128, 256, 512, 1024]
|
| 80 |
+
As, Bs, Cs, Ds = [], [], [], []
|
| 81 |
+
for _ in range(num):
|
| 82 |
+
M, N, K = [random.choice(valid_sizes) for _ in range(3)]
|
| 83 |
+
A, B, C, D = _initialize(dtype, M, N, K)
|
| 84 |
+
As.append(A)
|
| 85 |
+
Bs.append(B)
|
| 86 |
+
Cs.append(C)
|
| 87 |
+
Ds.append(D)
|
| 88 |
+
return As, Bs, Cs, Ds
|
| 89 |
+
|
| 90 |
+
def _generate_conv2d_problem(conv_kind, dtype, ps):
|
| 91 |
+
"""
|
| 92 |
+
Utility function to generate conv2d inputs
|
| 93 |
+
|
| 94 |
+
:param conv_kind: kind of convolution
|
| 95 |
+
:type conv_kind: str
|
| 96 |
+
:param dtype: data type of tensors
|
| 97 |
+
:param problem_size: the conv2d problem size
|
| 98 |
+
:type problem_size: cutlass_cppgen.shape.Conv2DProblemSize
|
| 99 |
+
|
| 100 |
+
:return: initialized tensors A, B, C, and D
|
| 101 |
+
:rtype: list
|
| 102 |
+
"""
|
| 103 |
+
if conv_kind == "fprop":
|
| 104 |
+
tensor_A_size = (ps.N, ps.C, ps.H, ps.W)
|
| 105 |
+
tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
|
| 106 |
+
tensor_C_size = (ps.N, ps.K, ps.P, ps.Q)
|
| 107 |
+
elif conv_kind == "dgrad":
|
| 108 |
+
tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
|
| 109 |
+
tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
|
| 110 |
+
tensor_C_size = (ps.N, ps.C, ps.H, ps.W)
|
| 111 |
+
else:
|
| 112 |
+
tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
|
| 113 |
+
tensor_B_size = (ps.N, ps.C, ps.H, ps.W)
|
| 114 |
+
tensor_C_size = (ps.K, ps.C, ps.R, ps.S)
|
| 115 |
+
sizes = [tensor_A_size, tensor_B_size, tensor_C_size]
|
| 116 |
+
return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@unittest.skipIf(not cutlass_cppgen.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests')
|
| 120 |
+
class PyTorchExtensionTest(unittest.TestCase):
|
| 121 |
+
|
| 122 |
+
def test_gemm(self):
|
| 123 |
+
random.seed(2023)
|
| 124 |
+
|
| 125 |
+
dtype = torch.float16
|
| 126 |
+
plan = cutlass_cppgen.op.Gemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 127 |
+
op = plan.construct()
|
| 128 |
+
|
| 129 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 130 |
+
mod = cutlass_cppgen.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True)
|
| 131 |
+
|
| 132 |
+
A, B, C, _ = _initialize(dtype, 1024, 256, 512)
|
| 133 |
+
|
| 134 |
+
D_ref = A @ B
|
| 135 |
+
D = mod.run(A, B)
|
| 136 |
+
assert torch.allclose(D, D_ref)
|
| 137 |
+
|
| 138 |
+
D = mod.run(A, B, C)
|
| 139 |
+
assert torch.allclose(D, D_ref)
|
| 140 |
+
|
| 141 |
+
D = mod.run(A, B, C, 1.0)
|
| 142 |
+
assert torch.allclose(D, D_ref)
|
| 143 |
+
|
| 144 |
+
D = mod.run(A, B, C, 1.0, 0.0)
|
| 145 |
+
assert torch.allclose(D, D_ref)
|
| 146 |
+
|
| 147 |
+
alpha = 2.0
|
| 148 |
+
beta = -1.0
|
| 149 |
+
D_ref = (A @ B) * alpha + (beta * C)
|
| 150 |
+
D = mod.run(A, B, C, alpha, beta)
|
| 151 |
+
assert torch.allclose(D, D_ref)
|
| 152 |
+
|
| 153 |
+
def test_grouped_gemm(self):
|
| 154 |
+
random.seed(2023)
|
| 155 |
+
|
| 156 |
+
dtype = torch.float16
|
| 157 |
+
plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 158 |
+
op = plan.construct()
|
| 159 |
+
|
| 160 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 161 |
+
mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True)
|
| 162 |
+
|
| 163 |
+
As, Bs, Cs, _ = _generate_problems(dtype, 50)
|
| 164 |
+
|
| 165 |
+
def check_all(X, Y):
|
| 166 |
+
for x, y in zip(X, Y):
|
| 167 |
+
assert torch.allclose(x, y)
|
| 168 |
+
|
| 169 |
+
Ds_ref = [a @ b for a, b in zip(As, Bs)]
|
| 170 |
+
Ds = mod.run(As, Bs)
|
| 171 |
+
check_all(Ds, Ds_ref)
|
| 172 |
+
|
| 173 |
+
Ds = mod.run(As, Bs, Cs)
|
| 174 |
+
check_all(Ds, Ds_ref)
|
| 175 |
+
|
| 176 |
+
Ds = mod.run(As, Bs, Cs, 1.0)
|
| 177 |
+
check_all(Ds, Ds_ref)
|
| 178 |
+
|
| 179 |
+
Ds = mod.run(As, Bs, Cs, 1.0, 0.0)
|
| 180 |
+
check_all(Ds, Ds_ref)
|
| 181 |
+
|
| 182 |
+
alpha = 2.0
|
| 183 |
+
beta = -1.0
|
| 184 |
+
Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)]
|
| 185 |
+
Ds = mod.run(As, Bs, Cs, alpha, beta)
|
| 186 |
+
check_all(Ds, Ds_ref)
|
| 187 |
+
|
| 188 |
+
def test_conv2d_fprop(self):
|
| 189 |
+
torch.manual_seed(2023)
|
| 190 |
+
|
| 191 |
+
dtype = torch.float16
|
| 192 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32)
|
| 193 |
+
plan.activation = "relu"
|
| 194 |
+
|
| 195 |
+
op = plan.construct()
|
| 196 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 197 |
+
mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
|
| 198 |
+
|
| 199 |
+
problem_size = cutlass_cppgen.shape.Conv2DProblemSize(
|
| 200 |
+
1, 4, 4, 16,
|
| 201 |
+
8, 3, 3, 16,
|
| 202 |
+
0, 0,
|
| 203 |
+
3, 3,
|
| 204 |
+
1, 1
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size)
|
| 208 |
+
stride = (problem_size.stride_h, problem_size.stride_w)
|
| 209 |
+
padding = (problem_size.pad_h, problem_size.pad_w)
|
| 210 |
+
|
| 211 |
+
alpha = 1.0
|
| 212 |
+
beta = 0.5
|
| 213 |
+
|
| 214 |
+
D_ref = alpha * torch.ops.aten.conv2d(
|
| 215 |
+
A, B, stride=stride, padding=padding
|
| 216 |
+
) + beta * C
|
| 217 |
+
D_ref = torch.nn.functional.relu(D_ref)
|
| 218 |
+
D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta)
|
| 219 |
+
|
| 220 |
+
assert torch.allclose(D, D_ref)
|
| 221 |
+
|
| 222 |
+
# Test serial split-K
|
| 223 |
+
D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
|
| 224 |
+
assert torch.allclose(D, D_serial_split_k)
|
| 225 |
+
|
| 226 |
+
# Test parallel split-K
|
| 227 |
+
D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
|
| 228 |
+
assert torch.allclose(D, D_parallel_split_k)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def test_conv2d_dgrad(self):
|
| 232 |
+
torch.manual_seed(2023)
|
| 233 |
+
dtype = torch.float16
|
| 234 |
+
plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32)
|
| 235 |
+
|
| 236 |
+
op = plan.construct()
|
| 237 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 238 |
+
mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
|
| 239 |
+
|
| 240 |
+
problem_size = cutlass_cppgen.shape.Conv2DProblemSize(
|
| 241 |
+
1, 4, 4, 16,
|
| 242 |
+
8, 3, 3, 16,
|
| 243 |
+
0, 0,
|
| 244 |
+
3, 3,
|
| 245 |
+
1, 1,
|
| 246 |
+
ConvMode.CrossCorrelation,
|
| 247 |
+
1, 1
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size)
|
| 251 |
+
stride = (problem_size.stride_h, problem_size.stride_w)
|
| 252 |
+
padding = (problem_size.pad_h, problem_size.pad_w)
|
| 253 |
+
|
| 254 |
+
alpha = 1.0
|
| 255 |
+
beta = 0.5
|
| 256 |
+
input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W)
|
| 257 |
+
D_ref = alpha * torch.nn.grad.conv2d_input(
|
| 258 |
+
input_size, B, A,
|
| 259 |
+
stride=stride, padding=padding
|
| 260 |
+
) + beta * C
|
| 261 |
+
D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, )
|
| 262 |
+
|
| 263 |
+
assert torch.allclose(D, D_ref)
|
| 264 |
+
|
| 265 |
+
def test_conv2d_wgrad(self):
|
| 266 |
+
torch.manual_seed(2023)
|
| 267 |
+
dtype = torch.float16
|
| 268 |
+
plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32)
|
| 269 |
+
|
| 270 |
+
op = plan.construct()
|
| 271 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 272 |
+
mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
|
| 273 |
+
|
| 274 |
+
problem_size = cutlass_cppgen.shape.Conv2DProblemSize(
|
| 275 |
+
1, 4, 4, 16,
|
| 276 |
+
8, 3, 3, 16,
|
| 277 |
+
0, 0,
|
| 278 |
+
3, 3,
|
| 279 |
+
1, 1,
|
| 280 |
+
ConvMode.CrossCorrelation,
|
| 281 |
+
1, 1
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size)
|
| 285 |
+
stride = (problem_size.stride_h, problem_size.stride_w)
|
| 286 |
+
padding = (problem_size.pad_h, problem_size.pad_w)
|
| 287 |
+
|
| 288 |
+
alpha = 1.0
|
| 289 |
+
beta = 0.5
|
| 290 |
+
weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S)
|
| 291 |
+
D_ref = alpha * torch.nn.grad.conv2d_weight(
|
| 292 |
+
B, weight_size, A,
|
| 293 |
+
stride=stride, padding=padding
|
| 294 |
+
) + beta * C
|
| 295 |
+
D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta)
|
| 296 |
+
|
| 297 |
+
assert torch.allclose(D, D_ref)
|
| 298 |
+
|
| 299 |
+
# Test serial split-K
|
| 300 |
+
D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
|
| 301 |
+
assert torch.allclose(D, D_serial_split_k)
|
| 302 |
+
|
| 303 |
+
# Test parallel split-K
|
| 304 |
+
D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
|
| 305 |
+
assert torch.allclose(D, D_parallel_split_k)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
if __name__ == '__main__':
|
| 309 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
"""
|
| 33 |
+
Unit test for compute node in SM90
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import logging
|
| 37 |
+
import unittest
|
| 38 |
+
|
| 39 |
+
import cutlass_cppgen
|
| 40 |
+
from cutlass_cppgen.backend import *
|
| 41 |
+
from cutlass_cppgen.epilogue import *
|
| 42 |
+
from cutlass_cppgen import swizzle
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTCompute(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
def test_arith(self):
|
| 53 |
+
"""
|
| 54 |
+
Test Arithmatic op
|
| 55 |
+
"""
|
| 56 |
+
def evt_arith_compute(accum, C, alpha, beta, gamma):
|
| 57 |
+
D = ((accum + C) * alpha - gamma) / beta
|
| 58 |
+
return D
|
| 59 |
+
|
| 60 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 61 |
+
example_inputs = {
|
| 62 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 63 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 64 |
+
"alpha": 1.5,
|
| 65 |
+
"beta": 0.5,
|
| 66 |
+
"gamma": 2.5,
|
| 67 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
launcher = EVTTestBed(self.element, evt_arith_compute, example_inputs)
|
| 71 |
+
input_keys = ["C", "alpha", "beta", "gamma"]
|
| 72 |
+
result_keys = ["D"]
|
| 73 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 74 |
+
|
| 75 |
+
def test_func_call(self):
|
| 76 |
+
"""
|
| 77 |
+
Test Function call
|
| 78 |
+
"""
|
| 79 |
+
def evt_func_call(accum, C, alpha, beta, gamma):
|
| 80 |
+
D = multiply_add(relu(accum + alpha) + C, beta, gamma)
|
| 81 |
+
return D
|
| 82 |
+
|
| 83 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 84 |
+
example_inputs = {
|
| 85 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 86 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 87 |
+
"alpha": 1.5,
|
| 88 |
+
"beta": 0.5,
|
| 89 |
+
"gamma": 2.5,
|
| 90 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
launcher = EVTTestBed(self.element, evt_func_call, example_inputs)
|
| 94 |
+
input_keys = ["C", "alpha", "beta", "gamma"]
|
| 95 |
+
result_keys = ["D"]
|
| 96 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 97 |
+
|
| 98 |
+
def test_func_call2(self):
|
| 99 |
+
"""
|
| 100 |
+
Test Function call
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def evt_func_call2(accum, C, alpha, beta):
|
| 104 |
+
D = maximum(alpha * accum + beta * C, 0.0)
|
| 105 |
+
return D
|
| 106 |
+
|
| 107 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 108 |
+
example_inputs = {
|
| 109 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 110 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 111 |
+
"alpha": 1.5,
|
| 112 |
+
"beta": 0.5,
|
| 113 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
launcher = EVTTestBed(self.element, evt_func_call2, example_inputs)
|
| 117 |
+
input_keys = ["C", "alpha", "beta"]
|
| 118 |
+
result_keys = ["D"]
|
| 119 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 120 |
+
|
| 121 |
+
def test_tanh(self):
|
| 122 |
+
"""
|
| 123 |
+
Test Tanh op
|
| 124 |
+
"""
|
| 125 |
+
def evt_tanh(accum):
|
| 126 |
+
D = tanh(accum)
|
| 127 |
+
return D
|
| 128 |
+
|
| 129 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 130 |
+
example_inputs = {
|
| 131 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 132 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
launcher = EVTTestBed(self.element, evt_tanh, example_inputs)
|
| 136 |
+
input_keys = []
|
| 137 |
+
result_keys = ["D"]
|
| 138 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 139 |
+
|
| 140 |
+
def test_sigmoid(self):
|
| 141 |
+
"""
|
| 142 |
+
Test Sigmoid op
|
| 143 |
+
"""
|
| 144 |
+
def evt_sigmoid(accum):
|
| 145 |
+
D = sigmoid(accum)
|
| 146 |
+
return D
|
| 147 |
+
|
| 148 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 149 |
+
example_inputs = {
|
| 150 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 151 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs)
|
| 155 |
+
input_keys = []
|
| 156 |
+
result_keys = ["D"]
|
| 157 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 158 |
+
|
| 159 |
+
def test_gelu(self):
|
| 160 |
+
"""
|
| 161 |
+
Test GELU op
|
| 162 |
+
"""
|
| 163 |
+
def evt_gelu(accum):
|
| 164 |
+
D = gelu(accum)
|
| 165 |
+
return D
|
| 166 |
+
|
| 167 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 168 |
+
example_inputs = {
|
| 169 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 170 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
launcher = EVTTestBed(self.element, evt_gelu, example_inputs)
|
| 174 |
+
input_keys = []
|
| 175 |
+
result_keys = ["D"]
|
| 176 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 177 |
+
|
| 178 |
+
def test_exp(self):
|
| 179 |
+
"""
|
| 180 |
+
Test Exp op
|
| 181 |
+
"""
|
| 182 |
+
def evt_exp(accum):
|
| 183 |
+
D = exp(accum)
|
| 184 |
+
return D
|
| 185 |
+
|
| 186 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 187 |
+
example_inputs = {
|
| 188 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 189 |
+
"D": self.fake_tensor(self.element, (l, m, n))
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
launcher = EVTTestBed(self.element, evt_exp, example_inputs)
|
| 193 |
+
input_keys = []
|
| 194 |
+
result_keys = ["D"]
|
| 195 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 196 |
+
|
| 197 |
+
if __name__ == '__main__':
|
| 198 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit test for store nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTLayout(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
def test_permute_1(self):
|
| 53 |
+
"""
|
| 54 |
+
Returning a tensor with shape [m, n]
|
| 55 |
+
"""
|
| 56 |
+
def evt_permute(accum, alpha, C):
|
| 57 |
+
F = alpha * accum
|
| 58 |
+
F_permute = permute(F, indices=(0, 2, 1))
|
| 59 |
+
D_permute = F_permute + permute(C, indices=(0, 2, 1))
|
| 60 |
+
D = permute(D_permute, indices=(0, 2, 1))
|
| 61 |
+
return D, F
|
| 62 |
+
|
| 63 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 64 |
+
example_inputs = {
|
| 65 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 66 |
+
"alpha": 0.5,
|
| 67 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 68 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 69 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
launcher = EVTTestBed(self.element, evt_permute, example_inputs)
|
| 73 |
+
input_keys = ["C", "alpha"]
|
| 74 |
+
result_keys = ["D", "F"]
|
| 75 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 76 |
+
|
| 77 |
+
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
|
| 78 |
+
def test_permute_2(self):
|
| 79 |
+
"""
|
| 80 |
+
Returning a tensor with shape [m, n]
|
| 81 |
+
"""
|
| 82 |
+
def evt_permute(accum, alpha, C):
|
| 83 |
+
F = alpha * accum
|
| 84 |
+
F_permute = permute(F, indices=(0, 2, 1))
|
| 85 |
+
D = F_permute + C
|
| 86 |
+
return D, F
|
| 87 |
+
|
| 88 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 89 |
+
example_inputs = {
|
| 90 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 91 |
+
"alpha": 0.5,
|
| 92 |
+
"C": self.fake_tensor(self.element, (l, n, m)),
|
| 93 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 94 |
+
"D": self.fake_tensor(self.element, (l, n, m)),
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
launcher = EVTTestBed(self.element, evt_permute, example_inputs)
|
| 98 |
+
input_keys = ["C", "alpha"]
|
| 99 |
+
result_keys = ["D", "F"]
|
| 100 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 101 |
+
|
| 102 |
+
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
|
| 103 |
+
def test_permute_3(self):
|
| 104 |
+
"""
|
| 105 |
+
Returning a tensor with shape [m, n]
|
| 106 |
+
"""
|
| 107 |
+
def evt_permute(accum, alpha, C):
|
| 108 |
+
F = alpha * accum
|
| 109 |
+
F_permute = permute(F, indices=(1, 0, 2))
|
| 110 |
+
D = F_permute + C
|
| 111 |
+
return D, F
|
| 112 |
+
|
| 113 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 114 |
+
example_inputs = {
|
| 115 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 116 |
+
"alpha": 0.5,
|
| 117 |
+
"C": self.fake_tensor(self.element, (m, l, n)),
|
| 118 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 119 |
+
"D": self.fake_tensor(self.element, (m, l, n)),
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
launcher = EVTTestBed(self.element, evt_permute, example_inputs)
|
| 123 |
+
input_keys = ["C", "alpha"]
|
| 124 |
+
result_keys = ["D", "F"]
|
| 125 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 126 |
+
|
| 127 |
+
def test_reshape(self):
|
| 128 |
+
"""
|
| 129 |
+
Test reshape
|
| 130 |
+
"""
|
| 131 |
+
def evt_reshape(accum, alpha, TensorE):
|
| 132 |
+
F = alpha * accum
|
| 133 |
+
E_reshape = reshape(TensorE, new_shape=(512, 1))
|
| 134 |
+
D = F + E_reshape
|
| 135 |
+
return D
|
| 136 |
+
|
| 137 |
+
example_inputs = {
|
| 138 |
+
"accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
|
| 139 |
+
"alpha": 0.5,
|
| 140 |
+
"TensorE": self.fake_tensor(self.element, (16, 32)),
|
| 141 |
+
"D": self.fake_tensor(self.element, (self.l, self.m, self.n)),
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
|
| 145 |
+
input_keys = ["alpha", "TensorE"]
|
| 146 |
+
result_keys = ["D"]
|
| 147 |
+
launcher.verify(self.problem_size, input_keys, result_keys, self.l)
|
| 148 |
+
|
| 149 |
+
def test_reshape2(self):
|
| 150 |
+
"""
|
| 151 |
+
Test reshape
|
| 152 |
+
"""
|
| 153 |
+
def evt_reshape(accum, alpha, TensorE):
|
| 154 |
+
F = alpha * accum
|
| 155 |
+
F_reshape = reshape(F, new_shape=(2, 3, 512, 256))
|
| 156 |
+
D = F_reshape + TensorE
|
| 157 |
+
return D
|
| 158 |
+
|
| 159 |
+
example_inputs = {
|
| 160 |
+
"accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
|
| 161 |
+
"alpha": 0.5,
|
| 162 |
+
"TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)),
|
| 163 |
+
"D": self.fake_tensor(self.element, (2, 3, self.m, self.n)),
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
|
| 167 |
+
input_keys = ["alpha", "TensorE"]
|
| 168 |
+
result_keys = ["D"]
|
| 169 |
+
launcher.verify(self.problem_size, input_keys, result_keys, self.l)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
if __name__ == '__main__':
|
| 173 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit test for load nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTLoad(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
def test_tensor_load(self):
|
| 53 |
+
"""
|
| 54 |
+
Load extra tensor with shape [m, n]
|
| 55 |
+
"""
|
| 56 |
+
def evt_tensor_load(accum, C, aux, aux_batch):
|
| 57 |
+
D = accum + C + aux + aux_batch
|
| 58 |
+
return D
|
| 59 |
+
|
| 60 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 61 |
+
example_inputs = {
|
| 62 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 63 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 64 |
+
"aux": self.fake_tensor(self.element, (m, n)),
|
| 65 |
+
"aux_batch": self.fake_tensor(np.float32, (l, m, n)),
|
| 66 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs)
|
| 70 |
+
input_keys = ["C", "aux", "aux_batch"]
|
| 71 |
+
result_keys = ["D"]
|
| 72 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 73 |
+
|
| 74 |
+
def test_row_broadcast(self):
|
| 75 |
+
"""
|
| 76 |
+
Load extra tensor with shape [1, n]
|
| 77 |
+
"""
|
| 78 |
+
def evt_row_broadcast(accum, C, bias, bias_batch):
|
| 79 |
+
D = accum + C + bias + bias_batch
|
| 80 |
+
return D
|
| 81 |
+
|
| 82 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 83 |
+
example_inputs = {
|
| 84 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 85 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 86 |
+
"bias": self.fake_tensor(self.element, (n,)),
|
| 87 |
+
"bias_batch": self.fake_tensor(np.float32, (l, 1, n)),
|
| 88 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs)
|
| 92 |
+
input_keys = ["C", "bias", "bias_batch"]
|
| 93 |
+
result_keys = ["D"]
|
| 94 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 95 |
+
|
| 96 |
+
def test_column_broadcast(self):
|
| 97 |
+
"""
|
| 98 |
+
Load extra tensor with shape [m, 1]
|
| 99 |
+
"""
|
| 100 |
+
def evt_column_broadcast(accum, C, bias, bias_batch):
|
| 101 |
+
D = accum + C + bias + bias_batch
|
| 102 |
+
return D
|
| 103 |
+
|
| 104 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 105 |
+
example_inputs = {
|
| 106 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 107 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 108 |
+
"bias": self.fake_tensor(self.element, (m, 1)),
|
| 109 |
+
"bias_batch": self.fake_tensor(np.float32, (l, m, 1)),
|
| 110 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs)
|
| 114 |
+
input_keys = ["C", "bias", "bias_batch"]
|
| 115 |
+
result_keys = ["D"]
|
| 116 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 117 |
+
|
| 118 |
+
def test_scalar_broadcast(self):
|
| 119 |
+
"""
|
| 120 |
+
Load extra tensor with shape [1, 1]
|
| 121 |
+
"""
|
| 122 |
+
def evt_scalar_broadcast(accum, C, alpha, alpha_batch):
|
| 123 |
+
D = accum + C + alpha + alpha_batch
|
| 124 |
+
return D
|
| 125 |
+
|
| 126 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 127 |
+
example_inputs = {
|
| 128 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 129 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 130 |
+
"alpha": 0.5,
|
| 131 |
+
"alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)),
|
| 132 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs)
|
| 136 |
+
input_keys = ["C", "alpha", "alpha_batch"]
|
| 137 |
+
result_keys = ["D"]
|
| 138 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == '__main__':
|
| 142 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unittest for mixed types of nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK
|
| 44 |
+
|
| 45 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 51 |
+
class TestEVTMixed(EVTTestCaseBase):
|
| 52 |
+
|
| 53 |
+
def test_same_variable_used_multiple_times(self):
|
| 54 |
+
"""
|
| 55 |
+
The same variable z0 is used multiple times
|
| 56 |
+
"""
|
| 57 |
+
def evt_aux_store(accum):
|
| 58 |
+
z0 = relu(accum)
|
| 59 |
+
D = z0 + z0
|
| 60 |
+
return z0, D
|
| 61 |
+
|
| 62 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 63 |
+
example_inputs = {
|
| 64 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 65 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 66 |
+
"z0": self.fake_tensor(self.element, (l, m, n)),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
|
| 70 |
+
input_keys = ["accum"]
|
| 71 |
+
result_keys = ["z0", "D"]
|
| 72 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 73 |
+
|
| 74 |
+
def test_no_lca(self):
|
| 75 |
+
"""
|
| 76 |
+
The same variable z0 is used multiple times
|
| 77 |
+
"""
|
| 78 |
+
def evt_no_lca(accum, bias):
|
| 79 |
+
E = relu(accum)
|
| 80 |
+
F = E + bias
|
| 81 |
+
tmp_2 = E + 2
|
| 82 |
+
D = tmp_2 + E
|
| 83 |
+
return D
|
| 84 |
+
|
| 85 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 86 |
+
example_inputs = {
|
| 87 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 88 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 89 |
+
"bias": self.fake_tensor(self.element, (m,1), stride=(1,0)),
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
launcher = EVTTestBed(self.element, evt_no_lca, example_inputs)
|
| 93 |
+
input_keys = ["accum", "bias"]
|
| 94 |
+
result_keys = ["D"]
|
| 95 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 96 |
+
|
| 97 |
+
def test_mixed_dag(self):
|
| 98 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 99 |
+
F = alpha * accum + (beta * C + aux)
|
| 100 |
+
F_row_max = max(F, dim=[0, 1])
|
| 101 |
+
E = relu(F + 1) + cbias + rbias
|
| 102 |
+
E_col_max = max(E, dim=[0, 2])
|
| 103 |
+
D = E + F
|
| 104 |
+
return D, F, F_row_max, E_col_max
|
| 105 |
+
|
| 106 |
+
if device_cc() == 80:
|
| 107 |
+
alignments = [2, 4, 8]
|
| 108 |
+
else:
|
| 109 |
+
# Sm90 EVT currently only supports 128-bit alignment
|
| 110 |
+
alignments = [8,]
|
| 111 |
+
for align in alignments:
|
| 112 |
+
for m, n, k, l in self.get_problem_sizes(align):
|
| 113 |
+
example_inputs = {
|
| 114 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 115 |
+
"alpha": 1.0,
|
| 116 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 117 |
+
"beta": 1.0,
|
| 118 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 119 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 120 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 121 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 122 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 123 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 124 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs)
|
| 128 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 129 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 130 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 131 |
+
|
| 132 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 133 |
+
def test_mixed_dag_float(self):
|
| 134 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 135 |
+
F = alpha * accum + (beta * C + aux)
|
| 136 |
+
F_row_max = max(F, dim=[0, 1])
|
| 137 |
+
E = relu(F + 1) + cbias + rbias
|
| 138 |
+
E_col_max = max(E, dim=[0, 2])
|
| 139 |
+
D = E + F
|
| 140 |
+
return D, F, F_row_max, E_col_max
|
| 141 |
+
|
| 142 |
+
for align in [3, 2, 4]:
|
| 143 |
+
for m, n, k, l in self.get_problem_sizes(align):
|
| 144 |
+
example_inputs = {
|
| 145 |
+
"accum": self.fake_tensor(np.float32, (l, m, n)),
|
| 146 |
+
"alpha": 1.0,
|
| 147 |
+
"C": self.fake_tensor(np.float32, (l, m, n)),
|
| 148 |
+
"beta": 1.0,
|
| 149 |
+
"aux": self.fake_tensor(np.float32, (l, m, n)),
|
| 150 |
+
"cbias": self.fake_tensor(np.float32, (m, 1)),
|
| 151 |
+
"rbias": self.fake_tensor(np.float32, (n,)),
|
| 152 |
+
"D": self.fake_tensor(np.float32, (l, m, n)),
|
| 153 |
+
"F": self.fake_tensor(np.float32, (l, m, n)),
|
| 154 |
+
"F_row_max": self.fake_tensor(np.float32, (n,)),
|
| 155 |
+
"E_col_max": self.fake_tensor(np.float32, (m, 1))
|
| 156 |
+
}
|
| 157 |
+
launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs)
|
| 158 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 159 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 160 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 161 |
+
|
| 162 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 163 |
+
def test_mixed_dag_stage2(self):
|
| 164 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 165 |
+
F = alpha * accum + (beta * C + aux)
|
| 166 |
+
F_row_max = max(F, dim=[0, 1])
|
| 167 |
+
E = relu(F + 1) + cbias + rbias
|
| 168 |
+
E_col_max = max(E, dim=[0, 2])
|
| 169 |
+
D = E + F
|
| 170 |
+
return D, F, F_row_max, E_col_max
|
| 171 |
+
|
| 172 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 173 |
+
example_inputs = {
|
| 174 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 175 |
+
"alpha": 1.0,
|
| 176 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 177 |
+
"beta": 1.0,
|
| 178 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 179 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 180 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 181 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 182 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 183 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 184 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2)
|
| 188 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 189 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 190 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 191 |
+
|
| 192 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 193 |
+
def test_mixed_dag_partition_k(self):
|
| 194 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 195 |
+
F = alpha * accum + (beta * C + aux)
|
| 196 |
+
F_row_max = max(F, dim=[0, 1])
|
| 197 |
+
E = relu(F + 1) + cbias + rbias
|
| 198 |
+
E_col_max = max(E, dim=[0, 2])
|
| 199 |
+
D = E + F
|
| 200 |
+
return D, F, F_row_max, E_col_max
|
| 201 |
+
|
| 202 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 203 |
+
example_inputs = {
|
| 204 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 205 |
+
"alpha": 1.0,
|
| 206 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 207 |
+
"beta": 1.0,
|
| 208 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 209 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 210 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 211 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 212 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 213 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 214 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
tile_description = {
|
| 218 |
+
"threadblock_shape": [128, 128, 64],
|
| 219 |
+
"warp_count": [2, 2, 2]
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2)
|
| 223 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 224 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 225 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 226 |
+
|
| 227 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 228 |
+
def test_mixed_dag_stream_k(self):
|
| 229 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 230 |
+
F = alpha * accum + (beta * C + aux)
|
| 231 |
+
F_row_max = max(F, dim=[0, 1])
|
| 232 |
+
E = relu(F + 1) + cbias + rbias
|
| 233 |
+
E_col_max = max(E, dim=[0, 2])
|
| 234 |
+
D = E + F
|
| 235 |
+
return D, F, F_row_max, E_col_max
|
| 236 |
+
|
| 237 |
+
# High per-sm occupancy tile_description
|
| 238 |
+
tile_description = {
|
| 239 |
+
"threadblock_shape": [128, 128, 32],
|
| 240 |
+
"warp_count": [2, 2, 1],
|
| 241 |
+
"stages": 3
|
| 242 |
+
}
|
| 243 |
+
tds = [None, tile_description]
|
| 244 |
+
for td in tds:
|
| 245 |
+
for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]):
|
| 246 |
+
if l == 1:
|
| 247 |
+
example_inputs = {
|
| 248 |
+
"accum": self.fake_tensor(self.element, (m, n)),
|
| 249 |
+
"alpha": 1.0,
|
| 250 |
+
"C": self.fake_tensor(self.element, (m, n)),
|
| 251 |
+
"beta": 1.0,
|
| 252 |
+
"aux": self.fake_tensor(self.element, (m, n)),
|
| 253 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 254 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 255 |
+
"D": self.fake_tensor(self.element, (m, n)),
|
| 256 |
+
"F": self.fake_tensor(self.element, (m, n)),
|
| 257 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 258 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 259 |
+
}
|
| 260 |
+
else:
|
| 261 |
+
example_inputs = {
|
| 262 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 263 |
+
"alpha": 1.0,
|
| 264 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 265 |
+
"beta": 1.0,
|
| 266 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 267 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 268 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 269 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 270 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 271 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 272 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
if td is not None:
|
| 276 |
+
launcher = EVTTestBed(
|
| 277 |
+
self.element, evt_mixed_dag, example_inputs,
|
| 278 |
+
tile_description=td,
|
| 279 |
+
swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
|
| 280 |
+
else:
|
| 281 |
+
launcher = EVTTestBed(
|
| 282 |
+
self.element, evt_mixed_dag, example_inputs,
|
| 283 |
+
swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
|
| 284 |
+
|
| 285 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 286 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 287 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 288 |
+
|
| 289 |
+
def test_mixed_dag_no_batch(self):
|
| 290 |
+
def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias):
|
| 291 |
+
F = alpha * accum + (beta * C + aux)
|
| 292 |
+
F_row_max = max(F, dim=[0, 1])
|
| 293 |
+
E = relu(F + 1) + cbias + rbias
|
| 294 |
+
E_col_max = max(E, dim=[0, 2])
|
| 295 |
+
D = E + F
|
| 296 |
+
return D, F, F_row_max, E_col_max
|
| 297 |
+
|
| 298 |
+
for m, n, k, _ in self.get_problem_sizes(8):
|
| 299 |
+
example_inputs = {
|
| 300 |
+
"accum": self.fake_tensor(self.element, (m, n)),
|
| 301 |
+
"alpha": 1.0,
|
| 302 |
+
"C": self.fake_tensor(self.element, (m, n)),
|
| 303 |
+
"beta": 1.0,
|
| 304 |
+
"aux": self.fake_tensor(self.element, (m, n)),
|
| 305 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 306 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 307 |
+
"D": self.fake_tensor(self.element, (m, n)),
|
| 308 |
+
"F": self.fake_tensor(self.element, (m, n)),
|
| 309 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 310 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs)
|
| 314 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 315 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 316 |
+
launcher.verify((m, n, k), input_keys, result_keys, 1)
|
| 317 |
+
|
| 318 |
+
if __name__ == '__main__':
|
| 319 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit test for store nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTStore(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() != 90, "This test is only for CC 90")
|
| 53 |
+
def test_invalid_store(self):
|
| 54 |
+
"""
|
| 55 |
+
Test invalid store
|
| 56 |
+
"""
|
| 57 |
+
def evt_invalid_store(accum):
|
| 58 |
+
D = accum
|
| 59 |
+
F = D + 1 # D has users, which is not allowed on SM90 or higher
|
| 60 |
+
return D, F
|
| 61 |
+
|
| 62 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 63 |
+
example_inputs = {
|
| 64 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 65 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 66 |
+
"F": self.fake_tensor(self.element, (l, m, n))
|
| 67 |
+
}
|
| 68 |
+
with self.assertRaisesRegex(
|
| 69 |
+
RuntimeError,
|
| 70 |
+
r"On SM90 or higher, D is expected to be a output node with 0 users "
|
| 71 |
+
r"to enable smem reuse between C and D, but got 1"
|
| 72 |
+
):
|
| 73 |
+
launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs)
|
| 74 |
+
|
| 75 |
+
break # Only need to test once
|
| 76 |
+
|
| 77 |
+
def test_aux_store(self):
|
| 78 |
+
"""
|
| 79 |
+
Returning a tensor with shape [m, n]
|
| 80 |
+
"""
|
| 81 |
+
def evt_aux_store(accum, alpha, C):
|
| 82 |
+
F = alpha * accum
|
| 83 |
+
D = F + C
|
| 84 |
+
return D, F
|
| 85 |
+
|
| 86 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 87 |
+
example_inputs = {
|
| 88 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 89 |
+
"alpha": 0.5,
|
| 90 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 91 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 92 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
|
| 96 |
+
input_keys = ["C", "alpha"]
|
| 97 |
+
result_keys = ["D", "F"]
|
| 98 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 99 |
+
|
| 100 |
+
def test_col_reduce(self):
|
| 101 |
+
"""
|
| 102 |
+
Reduction [m, n] -> [m, 1]
|
| 103 |
+
"""
|
| 104 |
+
def evt_row_reduce(accum, alpha, C):
|
| 105 |
+
acc_row_max = max(accum, dim=[2,])
|
| 106 |
+
F = alpha * accum
|
| 107 |
+
F_row_max = max(F, dim=[0, 2])
|
| 108 |
+
D = F + C
|
| 109 |
+
return D, F_row_max, acc_row_max
|
| 110 |
+
|
| 111 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 112 |
+
example_inputs = {
|
| 113 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 114 |
+
"alpha": 2.0,
|
| 115 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 116 |
+
"F_row_max": self.fake_tensor(np.float32, (m, 1)),
|
| 117 |
+
"acc_row_max": self.fake_tensor(np.float32, (l, m, 1)),
|
| 118 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs)
|
| 122 |
+
input_keys = ["C", "alpha"]
|
| 123 |
+
result_keys = ["D", "F_row_max", "acc_row_max"]
|
| 124 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 125 |
+
|
| 126 |
+
def test_row_reduce(self):
|
| 127 |
+
"""
|
| 128 |
+
Reduction [m, n] -> [n]
|
| 129 |
+
"""
|
| 130 |
+
def evt_col_reduce(accum, alpha, C):
|
| 131 |
+
acc_col_max = max(accum, dim=[1,])
|
| 132 |
+
F = alpha * accum
|
| 133 |
+
F_col_max = max(F, dim=[0, 1])
|
| 134 |
+
D = F + C
|
| 135 |
+
return D, F_col_max, acc_col_max
|
| 136 |
+
|
| 137 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 138 |
+
example_inputs = {
|
| 139 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 140 |
+
"alpha": 2.0,
|
| 141 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 142 |
+
"F_col_max": self.fake_tensor(np.float32, (n,)),
|
| 143 |
+
"acc_col_max": self.fake_tensor(np.float32, (l, 1, n)),
|
| 144 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs)
|
| 148 |
+
input_keys = ["C", "alpha"]
|
| 149 |
+
result_keys = ["D", "F_col_max", "acc_col_max"]
|
| 150 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 151 |
+
|
| 152 |
+
def test_scalar_reduce(self):
|
| 153 |
+
"""
|
| 154 |
+
Reduction [m, n] -> [1,]
|
| 155 |
+
"""
|
| 156 |
+
def evt_scalar_reduce(accum, alpha, C):
|
| 157 |
+
acc_max = max(accum, dim=[1, 2])
|
| 158 |
+
F = alpha * accum
|
| 159 |
+
F_max = max(F, dim=[0, 1, 2])
|
| 160 |
+
D = F + C
|
| 161 |
+
return D, F_max, acc_max
|
| 162 |
+
|
| 163 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 164 |
+
example_inputs = {
|
| 165 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 166 |
+
"alpha": 2.0,
|
| 167 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 168 |
+
"acc_max": self.fake_tensor(np.float32, (l, 1, 1)),
|
| 169 |
+
"F_max": self.fake_tensor(np.float32, (1,)),
|
| 170 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs)
|
| 174 |
+
input_keys = ["C", "alpha"]
|
| 175 |
+
result_keys = ["D", "F_max", "acc_max"]
|
| 176 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if __name__ == '__main__':
|
| 180 |
+
unittest.main()
|
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import pathlib
|
| 34 |
+
import unittest
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
|
| 38 |
+
loader = unittest.TestLoader()
|
| 39 |
+
script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
|
| 40 |
+
tests = loader.discover(script_dir, 'evt_*.py')
|
| 41 |
+
testRunner = unittest.runner.TextTestRunner()
|
| 42 |
+
results = testRunner.run(tests)
|
| 43 |
+
if not results.wasSuccessful():
|
| 44 |
+
raise Exception('Test cases failed')
|