File size: 2,631 Bytes
9823a7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from enum import IntEnum


class OptimizeStrategy(IntEnum):
    """Integer-valued selector naming the available optimization strategies."""

    SingleBatchDecodeOnly = 0
    # NOTE: the "Contigous" spelling is part of the public member/method names
    # and is therefore kept exactly as-is.
    ContigousBatching = 1

    def is_single_batch_decode_only(self) -> bool:
        """Return True when this value equals ``SingleBatchDecodeOnly``."""
        return OptimizeStrategy.SingleBatchDecodeOnly == self

    def is_contigous_batching(self) -> bool:
        """Return True when this value equals ``ContigousBatching``."""
        return OptimizeStrategy.ContigousBatching == self


class TransformKind(IntEnum):
    """Integer-valued tag identifying which kind of transform is selected."""

    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2
    LDMatrixTransform = 3

    def is_non_transform(self) -> bool:
        """Return True when this value equals ``NonTransform``."""
        return TransformKind.NonTransform == self

    def is_inter_warp_transform(self) -> bool:
        """Return True when this value equals ``InterWarpTransform``."""
        return TransformKind.InterWarpTransform == self

    def is_intra_warp_transform(self) -> bool:
        """Return True when this value equals ``IntraWarpTransform``."""
        return TransformKind.IntraWarpTransform == self

    def is_ld_matrix_transform(self) -> bool:
        """Return True when this value equals ``LDMatrixTransform``."""
        return TransformKind.LDMatrixTransform == self


class BackendKind(IntEnum):
    """Integer-valued tag identifying the selected backend."""

    TIR = 0
    TileLang = 1

    def is_tir_backend(self) -> bool:
        """Return True when this value equals ``TIR``."""
        return BackendKind.TIR == self

    def is_tilelang_backend(self) -> bool:
        """Return True when this value equals ``TileLang``."""
        return BackendKind.TileLang == self


class QuantizationMemoryStage(IntEnum):
    """Stage at which the dequantize operation is performed."""

    # Three dequantize schedules are distinguished:
    #
    # 1. Devices without async copy: a simple dequantize schedule with no
    #    shared memory prefetch.
    #
    #        quantized weight
    #            |
    #            V
    #        dequantized in register
    #            |
    #            V
    #        save into shared memory
    #            |
    #            V
    #        compute
    #
    # 2. A100-like devices: an async shared memory prefetch is required to
    #    achieve optimal performance; dequantization writes into shared memory.
    #
    #        quantized weight
    #            |
    #            V
    #        shared memory prefetch (with async copy)
    #            |
    #            V
    #        dequantized into shared memory
    #            |
    #            V
    #        compute
    #
    # 3. A100-like devices, again with the async shared memory prefetch, but
    #    the data goes through LDMatrix into warp memory before dequantization.
    #
    #        quantized weight
    #            |
    #            V
    #        shared memory prefetch (with async copy)
    #            |
    #            V
    #        LDMatrix into warp memory
    #            |
    #            V
    #        Dequantize
    #            |
    #            V
    #        Compute
    #
    # NOTE(review): which member below selects which schedule is not spelled
    # out here -- confirm against the code that consumes this enum.
    Local = 0
    Shared = 1
    Global = 2

    def is_quant_memory_in_local(self) -> bool:
        """Return True when this value equals ``Local``."""
        return QuantizationMemoryStage.Local == self

    def is_quant_memory_in_shared(self) -> bool:
        """Return True when this value equals ``Shared``."""
        return QuantizationMemoryStage.Shared == self

    def is_quant_memory_in_global(self) -> bool:
        """Return True when this value equals ``Global``."""
        return QuantizationMemoryStage.Global == self