Spaces:
Sleeping
Sleeping
| # | |
| # \file generator.py | |
| # | |
| # \brief Generates the CUTLASS Library's instances | |
| # | |
| import re | |
| ################################################################################################### | |
| import enum | |
| # The following block implements enum.auto() for Python 3.5 variants that don't include it such | |
| # as the default 3.5.2 on Ubuntu 16.04. | |
| # | |
| # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility | |
try:
    from enum import auto as enum_auto
except ImportError:
    # Counter backing the fallback below. It is module-wide, so values handed
    # out are unique across every enum in this file rather than per class.
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """Return the next integer from the module-wide counter.

        Stand-in for ``enum.auto`` on interpreters that do not provide it.
        """
        global __cutlass_library_auto_enum
        value = __cutlass_library_auto_enum
        __cutlass_library_auto_enum = value + 1
        return value
| ################################################################################################### | |
| # | |
class GeneratorTarget(enum.Enum):
    """Output targets of the generator; only ``Library`` is defined."""

    Library = enum_auto()


#
# Lower-case name used for each generator target.
GeneratorTargetNames = {
    GeneratorTarget.Library: 'library',
}
| # | |
| ################################################################################################### | |
| # | |
class DataType(enum.Enum):
    """Element types handled by the generator, plus an ``invalid`` sentinel."""

    # 1-bit and unsigned integers
    b1 = enum_auto()
    u4 = enum_auto()
    u8 = enum_auto()
    u16 = enum_auto()
    u32 = enum_auto()
    u64 = enum_auto()
    # signed integers
    s4 = enum_auto()
    s8 = enum_auto()
    s16 = enum_auto()
    s32 = enum_auto()
    s64 = enum_auto()
    # floating point
    f16 = enum_auto()
    bf16 = enum_auto()
    f32 = enum_auto()
    tf32 = enum_auto()
    f64 = enum_auto()
    # complex floating point
    cf16 = enum_auto()
    cbf16 = enum_auto()
    cf32 = enum_auto()
    ctf32 = enum_auto()
    cf64 = enum_auto()
    # complex signed integers
    cs4 = enum_auto()
    cs8 = enum_auto()
    cs16 = enum_auto()
    cs32 = enum_auto()
    cs64 = enum_auto()
    # complex unsigned integers
    cu4 = enum_auto()
    cu8 = enum_auto()
    cu16 = enum_auto()
    cu32 = enum_auto()
    cu64 = enum_auto()
    # sentinel; it has no entry in any of the tables below
    invalid = enum_auto()


#
# Single-letter abbreviations, defined only for a handful of types.
ShortDataTypeNames = {
    DataType.s32: 'i',
    DataType.f16: 'h',
    DataType.f32: 's',
    DataType.f64: 'd',
    DataType.cf32: 'c',
    DataType.cf64: 'z',
}

#
# Textual name of each data type (identical to the member name).
DataTypeNames = {
    DataType.b1: "b1",
    DataType.u4: "u4",
    DataType.u8: "u8",
    DataType.u16: "u16",
    DataType.u32: "u32",
    DataType.u64: "u64",
    DataType.s4: "s4",
    DataType.s8: "s8",
    DataType.s16: "s16",
    DataType.s32: "s32",
    DataType.s64: "s64",
    DataType.f16: "f16",
    DataType.bf16: "bf16",
    DataType.f32: "f32",
    DataType.tf32: "tf32",
    DataType.f64: "f64",
    DataType.cf16: "cf16",
    DataType.cbf16: "cbf16",
    DataType.cf32: "cf32",
    DataType.ctf32: "ctf32",
    DataType.cf64: "cf64",
    DataType.cu4: "cu4",
    DataType.cu8: "cu8",
    DataType.cu16: "cu16",
    DataType.cu32: "cu32",
    DataType.cu64: "cu64",
    DataType.cs4: "cs4",
    DataType.cs8: "cs8",
    DataType.cs16: "cs16",
    DataType.cs32: "cs32",
    DataType.cs64: "cs64",
}

# C++ type spelling emitted for each data type.
DataTypeTag = {
    DataType.b1: "cutlass::uint1b_t",
    DataType.u4: "cutlass::uint4b_t",
    DataType.u8: "uint8_t",
    DataType.u16: "uint16_t",
    DataType.u32: "uint32_t",
    DataType.u64: "uint64_t",
    DataType.s4: "cutlass::int4b_t",
    DataType.s8: "int8_t",
    DataType.s16: "int16_t",
    DataType.s32: "int32_t",
    DataType.s64: "int64_t",
    DataType.f16: "cutlass::half_t",
    DataType.bf16: "cutlass::bfloat16_t",
    DataType.f32: "float",
    DataType.tf32: "cutlass::tfloat32_t",
    DataType.f64: "double",
    DataType.cf16: "cutlass::complex<cutlass::half_t>",
    DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
    DataType.cf32: "cutlass::complex<float>",
    DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
    DataType.cf64: "cutlass::complex<double>",
    DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
    DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
    DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
    DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
    DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
    DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
    DataType.cs8: "cutlass::complex<cutlass::int8_t>",
    DataType.cs16: "cutlass::complex<cutlass::int16_t>",
    DataType.cs32: "cutlass::complex<cutlass::int32_t>",
    DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}

# Width of each data type in bits.
DataTypeSize = {
    DataType.b1: 1,
    DataType.u4: 4,
    DataType.u8: 8,
    DataType.u16: 16,
    DataType.u32: 32,
    DataType.u64: 64,
    DataType.s4: 4,
    DataType.s8: 8,
    DataType.s16: 16,
    DataType.s32: 32,
    DataType.s64: 64,
    DataType.f16: 16,
    DataType.bf16: 16,
    DataType.f32: 32,
    DataType.tf32: 32,
    DataType.f64: 64,
    DataType.cf16: 32,
    DataType.cbf16: 32,
    DataType.cf32: 64,
    # NOTE(review): every other complex entry is twice its real counterpart,
    # but ctf32 equals tf32 (32) -- confirm whether this is intentional.
    DataType.ctf32: 32,
    DataType.cf64: 128,
    DataType.cu4: 8,
    DataType.cu8: 16,
    DataType.cu16: 32,
    DataType.cu32: 64,
    DataType.cu64: 128,
    DataType.cs4: 8,
    DataType.cs8: 16,
    DataType.cs16: 32,
    DataType.cs32: 64,
    DataType.cs64: 128,
}
| ################################################################################################### | |
| # | |
class BlasMode(enum.Enum):
    """BLAS mode selector: symmetric or hermitian."""

    symmetric = enum_auto()
    hermitian = enum_auto()


#
# C++ enumerator emitted for each BLAS mode.
BlasModeTag = {
    BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
    BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}
| # | |
class ComplexTransform(enum.Enum):
    """Transform applied to a complex operand: none or conjugate."""

    none = enum_auto()
    conj = enum_auto()


#
# C++ enumerator emitted for each complex transform.
ComplexTransformTag = {
    ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
    ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}
| # | |
# (real, complex) pairs consulted by is_complex(), get_complex_from_real() and
# get_real_from_complex(). Only f16, f32 and f64 are paired.
RealComplexBijection = [
    (DataType.f16, DataType.cf16),
    (DataType.f32, DataType.cf32),
    (DataType.f64, DataType.cf64),
]
| # | |
def is_complex(data_type):
    """Return True if *data_type* is a complex type in RealComplexBijection.

    Only the complex halves of that table are recognised; any other value
    (including complex DataType members not paired there) yields False.
    """
    return any(data_type == complex_type for _, complex_type in RealComplexBijection)
| # | |
def get_complex_from_real(real_type):
    """Return the complex type paired with *real_type*.

    Falls back to ``DataType.invalid`` when *real_type* is not the real half
    of any pair in RealComplexBijection.
    """
    matches = (paired for real, paired in RealComplexBijection if real_type == real)
    return next(matches, DataType.invalid)
| # | |
def get_real_from_complex(complex_type):
    """Return the real type paired with *complex_type*.

    Falls back to ``DataType.invalid`` when *complex_type* is not the complex
    half of any pair in RealComplexBijection.
    """
    matches = (real for real, paired in RealComplexBijection if complex_type == paired)
    return next(matches, DataType.invalid)
| # | |
class ComplexMultiplyOp(enum.Enum):
    """Complex multiply variants: ``multiply_add`` and ``gaussian``."""

    multiply_add = enum_auto()
    gaussian = enum_auto()
| ################################################################################################### | |
| # | |
class MathOperation(enum.Enum):
    """Math operations a math instruction can be tagged with."""

    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    xor_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_fast_f32 = enum_auto()
    multiply_add_complex_fast_f32 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()


#
# C++ operator tag type emitted for each math operation.
MathOperationTag = {
    MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
    MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
    MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
    MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
    MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
    MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
    MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
    MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
    MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
}
| ################################################################################################### | |
| # | |
class LayoutType(enum.Enum):
    """Matrix and tensor layouts known to the generator."""

    # matrix layouts
    ColumnMajor = enum_auto()
    RowMajor = enum_auto()
    ColumnMajorInterleaved2 = enum_auto()
    RowMajorInterleaved2 = enum_auto()
    ColumnMajorInterleaved32 = enum_auto()
    RowMajorInterleaved32 = enum_auto()
    ColumnMajorInterleaved64 = enum_auto()
    RowMajorInterleaved64 = enum_auto()
    # tensor layouts
    TensorNHWC = enum_auto()
    TensorNDHWC = enum_auto()
    TensorNCHW = enum_auto()
    TensorNGHWC = enum_auto()
    TensorNC32HW32 = enum_auto()
    TensorNC64HW64 = enum_auto()
    TensorC32RSK32 = enum_auto()
    TensorC64RSK64 = enum_auto()


#
# C++ layout type emitted for each layout.
LayoutTag = {
    LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
    LayoutType.RowMajor: 'cutlass::layout::RowMajor',
    LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
    LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
    LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
    LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
    LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
    LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
    LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
    LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
    LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
    LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
    LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
    LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
    LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
    LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
}

#
# Layout obtained by transposing each entry. Row/column-major variants swap;
# TensorNHWC maps to itself; the remaining tensor layouts have no entry.
TransposedLayout = {
    LayoutType.ColumnMajor: LayoutType.RowMajor,
    LayoutType.RowMajor: LayoutType.ColumnMajor,
    LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
    LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
    LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
    LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
    LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
    LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
    LayoutType.TensorNHWC: LayoutType.TensorNHWC,
}

#
# Short abbreviation of each layout ('n' column-major, 't' row-major, ...).
ShortLayoutTypeNames = {
    LayoutType.ColumnMajor: 'n',
    LayoutType.ColumnMajorInterleaved2: 'n2',
    LayoutType.ColumnMajorInterleaved32: 'n32',
    LayoutType.ColumnMajorInterleaved64: 'n64',
    LayoutType.RowMajor: 't',
    LayoutType.RowMajorInterleaved2: 't2',
    LayoutType.RowMajorInterleaved32: 't32',
    LayoutType.RowMajorInterleaved64: 't64',
    LayoutType.TensorNHWC: 'nhwc',
    LayoutType.TensorNDHWC: 'ndhwc',
    LayoutType.TensorNCHW: 'nchw',
    LayoutType.TensorNGHWC: 'nghwc',
    LayoutType.TensorNC32HW32: 'nc32hw32',
    LayoutType.TensorNC64HW64: 'nc64hw64',
    LayoutType.TensorC32RSK32: 'c32rsk32',
    LayoutType.TensorC64RSK64: 'c64rsk64',
}
| # | |
# One-letter code for each (layout, complex transform) pair; defined only for
# plain column-major and row-major layouts.
ShortComplexLayoutNames = {
    (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
    (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
    (LayoutType.RowMajor, ComplexTransform.none): 't',
    (LayoutType.RowMajor, ComplexTransform.conj): 'h',
}
| ################################################################################################### | |
| # | |
class SideMode(enum.Enum):
    """Side selector: ``Left`` or ``Right``."""

    Left = enum_auto()
    Right = enum_auto()


#
# C++ enumerator emitted for each side mode.
SideModeTag = {
    SideMode.Left: 'cutlass::SideMode::kLeft',
    SideMode.Right: 'cutlass::SideMode::kRight',
}

#
# Short abbreviation of each side mode.
ShortSideModeNames = {
    SideMode.Left: 'ls',
    SideMode.Right: 'rs',
}
| ################################################################################################### | |
| # | |
class FillMode(enum.Enum):
    """Fill selector: ``Lower`` or ``Upper``."""

    Lower = enum_auto()
    Upper = enum_auto()


#
# C++ enumerator emitted for each fill mode.
FillModeTag = {
    FillMode.Lower: 'cutlass::FillMode::kLower',
    FillMode.Upper: 'cutlass::FillMode::kUpper',
}

#
# Short abbreviation of each fill mode.
ShortFillModeNames = {
    FillMode.Lower: 'l',
    FillMode.Upper: 'u',
}
| ################################################################################################### | |
| # | |
class DiagType(enum.Enum):
    """Diagonal selector: ``NonUnit`` or ``Unit``."""

    NonUnit = enum_auto()
    Unit = enum_auto()


#
# C++ enumerator emitted for each diagonal type.
DiagTypeTag = {
    DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
    DiagType.Unit: 'cutlass::DiagType::kUnit',
}

#
# Short abbreviation of each diagonal type.
ShortDiagTypeNames = {
    DiagType.NonUnit: 'nu',
    DiagType.Unit: 'un',
}
| ################################################################################################### | |
| # | |
class OpcodeClass(enum.Enum):
    """Opcode classes a math instruction can belong to."""

    Simt = enum_auto()
    TensorOp = enum_auto()
    WmmaTensorOp = enum_auto()
    SparseTensorOp = enum_auto()


# NOTE: OpcodeClass.SparseTensorOp has no entry in either table below, so
# looking it up raises KeyError.

# Lower-case name of each opcode class.
OpcodeClassNames = {
    OpcodeClass.Simt: 'simt',
    OpcodeClass.TensorOp: 'tensorop',
    OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
}

# C++ tag type emitted for each opcode class.
OpcodeClassTag = {
    OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
    OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
    OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
}
| ################################################################################################### | |
| # | |
class OperationKind(enum.Enum):
    """Top-level kinds of operation the generator distinguishes."""

    Gemm = enum_auto()
    RankK = enum_auto()
    Rank2K = enum_auto()
    Trmm = enum_auto()
    Symm = enum_auto()
    Conv2d = enum_auto()
    Conv3d = enum_auto()


#
# Lower-case name of each operation kind.
OperationKindNames = {
    OperationKind.Gemm: 'gemm',
    OperationKind.RankK: 'rank_k',
    OperationKind.Rank2K: 'rank_2k',
    OperationKind.Trmm: 'trmm',
    OperationKind.Symm: 'symm',
    OperationKind.Conv2d: 'conv2d',
    OperationKind.Conv3d: 'conv3d',
}
| # | |
class Target(enum.Enum):
    """Single-member enumeration holding the ``library`` target."""

    library = enum_auto()
| # | |
# Architecture name keyed by compute capability.
ArchitectureNames = {
    50: 'maxwell',
    60: 'pascal',
    61: 'pascal',
    70: 'volta',
    75: 'turing',
    80: 'ampere',
}

#
# Shared memory (KB) keyed by compute capability.
SharedMemPerCC = {
    70: 96,   # 96KB of SMEM
    72: 96,   # 96KB of SMEM
    75: 64,   # 64KB of SMEM
    80: 163,  # 163KB of SMEM - 1KB reserved for the driver
    86: 99,   # 99KB of SMEM - 1KB reserved for the driver
    87: 163,  # 163KB of SMEM - 1KB reserved for the driver
    89: 99,   # 99KB of SMEM - 1KB reserved for the driver
    90: 227,  # 227KB of SMEM - 1KB reserved for the driver
}
| ################################################################################################### | |
| # | |
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in *template* using *values*.

    Passes over *values* are repeated until one leaves the text unchanged, so
    a replacement may itself contain placeholders that get expanded later.

    Args:
        template: Text containing ``${key}`` placeholders.
        values: Mapping from placeholder name to its replacement string.

    Returns:
        The expanded text. Placeholders whose name is not a key of *values*
        are left untouched.

    Note:
        A value that reintroduces its own placeholder (directly or through a
        cycle) never reaches a fixed point, so the loop would not terminate.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            # Plain string replacement: the key is matched literally (no regex
            # metacharacters) and the value is inserted verbatim, so
            # backslashes such as "\n" inside emitted C++ survive intact.
            newtext = text.replace("${%s}" % key, value)
            if newtext != text:
                changed = True
            text = newtext
    return text
| ################################################################################################### | |
| # | |
class GemmKind(enum.Enum):
    """Flavours of GEMM operation."""

    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    Grouped = enum_auto()


#
# Name of each GEMM kind; Gemm and Universal share the name "gemm".
GemmKindNames = {
    GemmKind.Gemm: "gemm",
    GemmKind.Sparse: "spgemm",
    GemmKind.Universal: "gemm",
    GemmKind.PlanarComplex: "gemm_planar_complex",
    GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
    GemmKind.Grouped: "gemm_grouped",
}
| # | |
class RankKKind(enum.Enum):
    """Flavours of rank-k operation; only ``Universal`` is defined."""

    Universal = enum_auto()


#
# Name of each rank-k kind.
RankKKindNames = {
    RankKKind.Universal: "rank_k",
}
| # | |
class TrmmKind(enum.Enum):
    """Flavours of TRMM operation; only ``Universal`` is defined."""

    Universal = enum_auto()


#
# Name of each TRMM kind.
TrmmKindNames = {
    TrmmKind.Universal: "trmm",
}
| # | |
class SymmKind(enum.Enum):
    """Flavours of SYMM operation; only ``Universal`` is defined."""

    Universal = enum_auto()


#
# Name of each SYMM kind.
SymmKindNames = {
    SymmKind.Universal: "symm",
}
| # | |
class EpilogueFunctor(enum.Enum):
    """Epilogue functors selectable by the generator."""

    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()


#
# C++ functor template emitted for each epilogue functor.
EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
    EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}
| # | |
class SwizzlingFunctor(enum.Enum):
    """Threadblock swizzling functors selectable by the generator."""

    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    Horizontal = enum_auto()
    StridedDgradIdentity1 = enum_auto()
    StridedDgradIdentity4 = enum_auto()
    StridedDgradHorizontal = enum_auto()
    StreamK = enum_auto()


#
# C++ swizzle type emitted for each swizzling functor.
SwizzlingFunctorTag = {
    SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
    SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
    SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
    SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
    SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}
| # | |
class GroupScheduleMode(enum.Enum):
    """Scheduling modes for grouped kernels: ``Device`` or ``Host``."""

    # No trailing comma here: "enum_auto()," would make the member's value a
    # one-element tuple instead of the auto-generated integer.
    Device = enum_auto()
    Host = enum_auto()


#
# C++ enumerator emitted for each group schedule mode.
GroupScheduleModeTag = {
    GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
    GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute',
}

#
# Short name of each group schedule mode.
ShortGroupScheduleModeNames = {
    GroupScheduleMode.Device: 'Device',
    GroupScheduleMode.Host: 'Host',
}
| ################################################################################################### | |
| # | |
class ConvKind(enum.Enum):
    """Convolution operators: ``Fprop``, ``Dgrad`` and ``Wgrad``."""

    Fprop = enum_auto()
    Dgrad = enum_auto()
    Wgrad = enum_auto()


#
# C++ enumerator emitted for each convolution kind.
ConvKindTag = {
    ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
    ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
    ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad',
}

# Lower-case name of each convolution kind.
ConvKindNames = {
    ConvKind.Fprop: 'fprop',
    ConvKind.Dgrad: 'dgrad',
    ConvKind.Wgrad: 'wgrad',
}
| # | |
class IteratorAlgorithm(enum.Enum):
    """Convolution iterator algorithms selectable by the generator."""

    Analytic = enum_auto()
    Optimized = enum_auto()
    FixedChannels = enum_auto()
    FewChannels = enum_auto()
    FixedStrideDilation = enum_auto()


#
# C++ enumerator emitted for each iterator algorithm.
IteratorAlgorithmTag = {
    IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
    IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
    IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
    IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
    IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation',
}

# Lower-case name of each iterator algorithm.
IteratorAlgorithmNames = {
    IteratorAlgorithm.Analytic: 'analytic',
    IteratorAlgorithm.Optimized: 'optimized',
    IteratorAlgorithm.FixedChannels: 'fixed_channels',
    IteratorAlgorithm.FewChannels: 'few_channels',
    IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation',
}
| # | |
class StrideSupport(enum.Enum):
    """Stride support modes: ``Strided``, ``Unity`` and ``Fixed``."""

    Strided = enum_auto()
    Unity = enum_auto()
    Fixed = enum_auto()


#
# C++ enumerator emitted for each stride support mode.
StrideSupportTag = {
    StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
    StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
    StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed',
}

# Name of each stride support mode; Strided maps to the empty string.
StrideSupportNames = {
    StrideSupport.Strided: '',
    StrideSupport.Unity: 'unity_stride',
    StrideSupport.Fixed: 'fixed_stride',
}
| # | |
class GroupMode(enum.Enum):
    """Grouping modes for convolution."""

    NoneGroup = enum_auto()      # dense conv (G=1)
    SingleGroup = enum_auto()    # grouped convolution (single group per CTA)
    MultipleGroup = enum_auto()  # grouped convolution (multiple groups per CTA)
    Depthwise = enum_auto()      # depthwise convolution (C=K=G)


#
# C++ enumerator emitted for each group mode.
GroupModeTag = {
    GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
    GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
    GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
    GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}

# Name of each group mode; NoneGroup maps to the empty string.
GroupModeNames = {
    GroupMode.NoneGroup: '',
    GroupMode.SingleGroup: 'single_group',
    GroupMode.MultipleGroup: 'multiple_group',
    GroupMode.Depthwise: 'depthwise',
}
| ################################################################################################### | |
| # | |
class MathInstruction:
    """Plain record describing a math instruction.

    Stores its constructor arguments unchanged as same-named attributes.
    """

    def __init__(self, instruction_shape, element_a, element_b,
                 element_accumulator, opcode_class,
                 math_operation=MathOperation.multiply_add):
        self.instruction_shape = instruction_shape
        # Element types of the A and B operands and of the accumulator.
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation
| # | |
class TileDescription:
    """Record describing a threadblock tile and the instruction it uses."""

    def __init__(self, threadblock_shape, stages, warp_count, math_instruction,
                 min_compute, max_compute):
        self.threadblock_shape = threadblock_shape
        self.stages = stages
        self.warp_count = warp_count
        self.math_instruction = math_instruction
        # The compute-capability bounds are stored under longer names.
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return ``"<tb0>x<tb1>_<tb2>x<stages>"`` built from the tile shape."""
        shape = self.threadblock_shape
        fields = (shape[0], shape[1], shape[2], self.stages)
        return "%dx%d_%dx%d" % fields
| # | |
class Direct2dConvFixedStrideDilationTileDescription:
    """Tile record for direct 2-D convolution with optional fixed stride/dilation.

    ``threadblock_shape`` is derived as
    ``[out[0] * out[1] * out[2], out[3], filter[0] * filter[1]]`` where ``out``
    is *threadblock_output_shape* and ``filter`` is *filter_shape*.
    """

    def __init__(self, threadblock_output_shape, filter_shape, stages, stride,
                 dilation, warp_count, math_instruction, min_compute, max_compute):
        out = threadblock_output_shape
        self.threadblock_shape = [
            out[0] * out[1] * out[2],
            out[3],
            filter_shape[0] * filter_shape[1],
        ]
        self.threadblock_output_shape = threadblock_output_shape
        self.filter_shape = filter_shape
        self.stages = stages
        self.warp_count = warp_count
        self.stride = stride
        self.dilation = dilation
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return the tile's name fragment.

        The base name encodes the derived threadblock shape, the output shape,
        the stage count and the filter shape. A ``_stride..._dilation...``
        suffix is appended only when neither ``stride`` nor ``dilation``
        equals ``[-1, -1]``.
        """
        tb = self.threadblock_shape
        out = self.threadblock_output_shape
        flt = self.filter_shape
        str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (
            tb[0], tb[1], tb[2],
            out[0], out[1], out[2], out[3],
            self.stages,
            flt[0], flt[1],
        )
        # [-1, -1] marks a stride/dilation that is not fixed.
        has_fixed_params = self.stride != [-1, -1] and self.dilation != [-1, -1]
        if has_fixed_params:
            suffix_fields = (self.stride[0], self.stride[1],
                             self.dilation[0], self.dilation[1])
            str_name += "_stride%dx%d_dilation%dx%d" % suffix_fields
        return str_name
| # | |
class TensorDescription:
    """Record describing a tensor operand.

    Stores its constructor arguments unchanged as same-named attributes;
    *alignment* defaults to 1 and *complex_transform* to no transform.
    """

    def __init__(self, element, layout, alignment=1,
                 complex_transform=ComplexTransform.none):
        self.element = element
        self.layout = layout
        self.alignment = alignment
        self.complex_transform = complex_transform
| # | |
class SymmetricTensorDescription:
    """Record describing a symmetric tensor operand.

    Stores its constructor arguments unchanged as same-named attributes;
    *alignment* defaults to 1, *complex_transform* to no transform and
    *side_mode* to ``SideMode.Left``.
    """

    def __init__(self, element, layout, fill_mode, alignment=1,
                 complex_transform=ComplexTransform.none,
                 side_mode=SideMode.Left):
        self.element = element
        self.layout = layout
        self.fill_mode = fill_mode
        self.alignment = alignment
        self.complex_transform = complex_transform
        self.side_mode = side_mode
| # | |
class TriangularTensorDescription:
    """Record describing a triangular tensor operand.

    Stores its constructor arguments unchanged as same-named attributes;
    *alignment* defaults to 1 and *complex_transform* to no transform.
    """

    def __init__(self, element, layout, side_mode, fill_mode, diag_type,
                 alignment=1, complex_transform=ComplexTransform.none):
        self.element = element
        self.layout = layout
        self.side_mode = side_mode
        self.fill_mode = fill_mode
        self.diag_type = diag_type
        self.alignment = alignment
        self.complex_transform = complex_transform
| ################################################################################################### | |
| # | |
def CalculateSmemUsage(operation):
    """Estimate the shared memory used by *operation*, in whole KiB.

    A per-stage byte count is computed from the threadblock shape and the
    element bit widths in ``DataTypeSize``, multiplied by the stage count and
    shifted right by 10 (truncating to KiB).

    Args:
        operation: Object exposing ``tile_description`` (with
            ``threadblock_shape`` and ``stages``), ``operation_kind`` and
            ``A.element``; sparse GEMMs must also expose ``gemm_kind`` and
            ``B.element``.

    Returns:
        Shared-memory usage in KiB, rounded down.
    """
    tile = operation.tile_description
    cta_shape = tile.threadblock_shape
    stages = tile.stages
    cta_m, cta_n, cta_k = cta_shape[0], cta_shape[1], cta_shape[2]

    is_sparse_gemm = (operation.operation_kind == OperationKind.Gemm
                      and operation.gemm_kind == GemmKind.Sparse)
    if is_sparse_gemm:
        a_bits = DataTypeSize[operation.A.element]
        # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or
        # 1:2 sparsity): 2 for 32-bit A, 8 for 4-bit A, otherwise 4.
        elements_per_8b_md = {32: 2, 4: 8}.get(a_bits, 4)
        half_k = cta_k // 2
        a_bytes = a_bits * cta_m * half_k // 8
        b_bytes = DataTypeSize[operation.B.element] * cta_n * cta_k // 8
        metadata_bytes = cta_m * half_k // elements_per_8b_md
        smem_per_stage = a_bytes + b_bytes + metadata_bytes
    else:
        # Few BLAS3 operations only have A tensor, so A's element width is
        # used for both operand tiles.
        a_bits = DataTypeSize[operation.A.element]
        a_tile_bytes = a_bits * cta_m * cta_k // 8
        b_tile_bytes = a_bits * cta_n * cta_k // 8
        smem_per_stage = a_tile_bytes + b_tile_bytes

    smem_usage = smem_per_stage * stages
    return smem_usage >> 10
| ################################################################################################### | |