# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from enum import IntEnum
class OptimizeStrategy(IntEnum):
    """Serving-scenario the kernel should be specialized for.

    ``ContigousBatching`` is a historical misspelling kept for backward
    compatibility; ``ContiguousBatching`` is a correctly spelled enum
    alias for the same value (equal-valued members alias in ``enum``).
    """

    SingleBatchDecodeOnly = 0
    ContigousBatching = 1
    # Correctly spelled alias of ContigousBatching (same value => alias).
    ContiguousBatching = 1

    def is_single_batch_decode_only(self) -> bool:
        """Return True for the single-batch decode-only strategy."""
        return self == OptimizeStrategy.SingleBatchDecodeOnly

    def is_contigous_batching(self) -> bool:
        """Return True for the contiguous-batching strategy (legacy spelling)."""
        return self == OptimizeStrategy.ContigousBatching

    def is_contiguous_batching(self) -> bool:
        """Correctly spelled alias of :meth:`is_contigous_batching`."""
        return self == OptimizeStrategy.ContigousBatching
class TransformKind(IntEnum):
    """Kind of layout transform applied to a matmul operand."""

    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2
    LDMatrixTransform = 3

    def is_non_transform(self) -> bool:
        """True when no layout transform is applied."""
        return self == TransformKind.NonTransform

    def is_inter_warp_transform(self) -> bool:
        """True for the inter-warp layout transform."""
        return self == TransformKind.InterWarpTransform

    def is_intra_warp_transform(self) -> bool:
        """True for the intra-warp layout transform."""
        return self == TransformKind.IntraWarpTransform

    def is_ld_matrix_transform(self) -> bool:
        """True for the ldmatrix-oriented layout transform."""
        return self == TransformKind.LDMatrixTransform
class BackendKind(IntEnum):
    """Code-generation backend used to build the kernel."""

    TIR = 0
    TileLang = 1

    def is_tir_backend(self) -> bool:
        """True when the TIR backend is selected."""
        return self == BackendKind.TIR

    def is_tilelang_backend(self) -> bool:
        """True when the TileLang backend is selected."""
        return self == BackendKind.TileLang
class QuantizationMemoryStage(IntEnum):
    """Memory stage at which the dequantize operation is performed.

    Three pipelines are distinguished:

    1. ``Local`` — for devices without async copy, a simple schedule with
       no shared-memory prefetch::

           quantized weight
               -> dequantize in register
               -> save into shared memory
               -> compute

    2. ``Shared`` — for A100-like devices, shared-memory prefetch (async
       copy) is required for optimal performance::

           quantized weight
               -> shared memory prefetch (async copy)
               -> dequantize into shared memory
               -> compute

    3. ``Global`` — prefetch into shared memory, then move through warp
       memory before dequantizing::

           quantized weight
               -> shared memory prefetch (async copy)
               -> ldmatrix into warp memory
               -> dequantize
               -> compute
    """

    Local = 0
    Shared = 1
    Global = 2

    def is_quant_memory_in_local(self) -> bool:
        """True when dequantization happens at the local (register) stage."""
        return self == QuantizationMemoryStage.Local

    def is_quant_memory_in_shared(self) -> bool:
        """True when dequantization happens at the shared-memory stage."""
        return self == QuantizationMemoryStage.Shared

    def is_quant_memory_in_global(self) -> bool:
        """True when dequantization happens at the global-memory stage."""
        return self == QuantizationMemoryStage.Global