# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from typing import Optional, Tuple, Union

import torch
from compressed_tensors.modeling import (
    IMPL_ATTR,
    KV_CACHE_ATTR,
    QuantizedAttentionImpl,
    QuantizedKVCache,
)
from compressed_tensors.quantization import (
    ActivationOrdering,
    DynamicType,
    QuantizationArgs,
    QuantizationMetadata,
    QuantizationScheme,
    QuantizationStatus,
    QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.forward import (
    wrap_module_forward_quantized,
)
from compressed_tensors.quantization.utils import strategy_cdiv
from compressed_tensors.utils import (
    disable_hf_hook,
    get_execution_device,
    get_head_dim,
    get_num_attn_heads,
    get_num_kv_heads,
    register_offload_parameter,
)
from torch.nn import Module, Parameter


__all__ = [
    "initialize_module_for_quantization",
    "is_attention_module",
    "initialize_qparams",
    "initialize_attn_qparams",
]


_LOGGER = logging.getLogger(__name__)


def initialize_module_for_quantization(
    module: Module,
    scheme: Optional[QuantizationScheme] = None,
    force_zero_point: bool = True,
):
    """
    Attaches appropriate scales, zero points, and observers to a layer
    given its target quantization scheme.

    Previously initialized scales and zero points will be removed from the
    module if they no longer apply to the scheme.

    :param module: module to set for calibration
    :param scheme: scheme to use for quantization. If None is provided, the
        scheme stored in the module under `quantization_scheme` is used; if
        neither is available, the layer is skipped
    :param force_zero_point: whether to force initialization of a zero point for
        symmetric quantization
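
    Example (a sketch; exact ``QuantizationArgs`` fields may vary by version):

    >>> import torch
    >>> linear = torch.nn.Linear(64, 64)
    >>> scheme = QuantizationScheme(
    ...     targets=["Linear"],
    ...     weights=QuantizationArgs(num_bits=8, symmetric=True),
    ... )
    >>> initialize_module_for_quantization(linear, scheme)
    >>> hasattr(linear, "weight_scale") and hasattr(linear, "weight_zero_point")
    True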
    """
    scheme = scheme or getattr(module, "quantization_scheme", None)
    if scheme is None:
        return

    QuantizationMetadata.clear_all_qparams(module)

    if is_attention_module(module):
        # attention modules get q/k/v scales rather than weight qparams
        initialize_attn_qparams(module, scheme, force_zero_point)

    else:
        if not isinstance(module, torch.nn.Linear):
            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")

        # use weight to determine observed shapes and dtype
        if hasattr(module, "weight"):
            weight = module.weight
            assert isinstance(weight, torch.Tensor)
        else:
            # Note that a weight is required for both weight and activation
            # quantization in order to know the dtype of activation scales
            _LOGGER.warning(
                f"module type {type(module)} targeted for quantization but "
                f"has no attribute weight, skipping quantization for {type(module)}"
            )
            return

        if scheme.input_activations is not None:
            initialize_qparams(
                module,
                "input",
                scheme.input_activations,
                observed_shape=weight.shape[-1:],
                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        if scheme.weights is not None:
            initialize_qparams(
                module,
                "weight",
                scheme.weights,
                observed_shape=weight.shape,
                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        if scheme.output_activations is not None:
            initialize_qparams(
                module,
                "output",
                scheme.output_activations,
                observed_shape=weight.shape[:-1],
                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        with disable_hf_hook(module):
            # wrap forward call of module to perform
            # quantized actions based on calltime status
            wrap_module_forward_quantized(module, scheme)

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED


def is_attention_module(module: Module):
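    """
    Heuristic check for self-attention modules: the class name must contain
    "attention" and the module must expose one of the usual projection
    submodules (``k_proj``, ``v_proj``, or ``qkv_proj``).
    """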
    return "attention" in module.__class__.__name__.lower() and (
        hasattr(module, "k_proj")
        or hasattr(module, "v_proj")
        or hasattr(module, "qkv_proj")
    )


def initialize_qparams(
    module: Module,
    base_name: str,
    quantization_args: QuantizationArgs,
    observed_shape: Tuple[Union[int, None]],
    observed_dtype: torch.dtype,
    force_zero_point: bool = True,
):
    """
    Initialize quantization parameters for a given base name according to the passed
    quantization args. The shape and dtype of the observed weight/activation must also
    be provided.

    Scales are initialized unless quantization is fully dynamic. Global scales are
    only initialized for tensor-group quantization. Zero points are initialized if
    the args are not symmetric or if `force_zero_point` is True.

    :param module: module to register qparams to
    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
    :param quantization_args: arguments for quantization
    :param observed_shape: last (right-most) known dimensions of the observed
        weight/activation
    :param observed_dtype: dtype of the observed weight/activation
    :param force_zero_point: force the zero_point parameter to be initialized
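
    Example (a sketch of the shape inference below): a ``(256, 128)`` weight
    quantized with ``strategy=QuantizationStrategy.GROUP`` and ``group_size=64``
    is given a ``weight_scale`` of shape ``(256, 2)``, since each row of 128
    columns splits into two groups of 64.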
    """
    strategy = quantization_args.strategy
    dynamic = quantization_args.dynamic
    actorder = quantization_args.actorder
    device = get_execution_device(module)  # avoid performing initialization ops on cpu

    # Skip all initialization for fully dynamic quantization
    if dynamic is True:
        return

    # 0. Create global scale for tensor-group quantization
    if strategy == QuantizationStrategy.TENSOR_GROUP:
        init_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32, device=device),
            requires_grad=False,
        )
        register_offload_parameter(
            module, f"{base_name}_global_scale", init_global_scale
        )

    # Skip scale/zp initialization for locally dynamic quantization
    if dynamic == DynamicType.LOCAL:
        return

    # 1. Infer expected scale/zp shape
    if strategy == QuantizationStrategy.TENSOR:
        expected_shape = (1,)

    elif strategy == QuantizationStrategy.TOKEN:
        raise ValueError("Cannot perform static token quantization")

    elif strategy == QuantizationStrategy.CHANNEL:
        if len(observed_shape) < 2:
            raise ValueError("Channel quant requires at least 2 observed dimensions")

        expected_shape = (observed_shape[-2], 1)

    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        assert quantization_args.group_size is not None
        if len(observed_shape) < 1:
            raise ValueError("Group quant requires at least 1 observed dimension")

        group_size = quantization_args.group_size
        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
        expected_shape = (*observed_shape[:-1], num_groups)
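        # e.g. a (512, 256) weight with group_size=128 -> scale shape (512, 2)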

        # initialize activation ordering if applicable
        if actorder == ActivationOrdering.GROUP:
            init_g_idx = Parameter(
                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
                requires_grad=False,
            )
            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)

    elif strategy == QuantizationStrategy.BLOCK:
        assert quantization_args.block_structure is not None
        if len(observed_shape) < 2:
            raise ValueError("Block quant requires at least 2 observed dimensions")

        block_structure = quantization_args.block_structure
        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
        expected_shape = (num_rows, num_cols)
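        # e.g. a (256, 512) weight with block_structure=[128, 128] -> (2, 4)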

    elif strategy == QuantizationStrategy.ATTN_HEAD:
        # (batch_size, num_attention_heads, seq_len, head_dim)
        if len(observed_shape) < 3:
            raise ValueError("Attention quant requires at least 3 observed dimensions")

        expected_shape = (observed_shape[-3], 1, 1)
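        # one scale per attention head, broadcast over seq_len and head_dim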

    else:
        raise ValueError(f"Unknown strategy {strategy}")

    # 2. Identify quantization scale and zp dtype
    scale_dtype = observed_dtype
    if scale_dtype not in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ]:
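        # e.g. float8 or integer weights cannot represent scale values directly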
        scale_dtype = torch.float16

    # 3. Initializes scale/zp for the module
    init_scale = Parameter(
        torch.empty(expected_shape, dtype=scale_dtype, device=device),
        requires_grad=False,
    )
    register_offload_parameter(module, f"{base_name}_scale", init_scale)

    if force_zero_point or not quantization_args.symmetric:
        init_zero_point = Parameter(
            torch.zeros(
                expected_shape, device=device, dtype=quantization_args.zp_dtype
            ),
            requires_grad=False,
        )
        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)


def initialize_attn_qparams(
    module: Module, scheme: QuantizationScheme, force_zero_point: bool
):
    """Initlaize k_scale, v_scale for self_attn"""

    impl: Optional[QuantizedAttentionImpl] = getattr(module, IMPL_ATTR, None)
    kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)

    if impl is None and kv_cache is None:
        raise ValueError(
            f"Attention module has quantization scheme but no {IMPL_ATTR} "
            f"or {KV_CACHE_ATTR} attributes. Please ensure that these "
            "attributes are initialized using `apply_quantization_config`."
        )

    _validate_attention_scheme(scheme)

    # extract shapes from the model config (a kv_cache is expected to be
    # attached whenever attention quantization is applied)
    config = kv_cache.config
    num_attn_heads = get_num_attn_heads(config)
    num_kv_heads = get_num_kv_heads(config)
    head_dim = get_head_dim(config)

    # observed as (num_heads, seq_len, head_dim), with the batch dim omitted;
    # None marks the dynamic seq_len dimension
    q_observed_shape = (num_attn_heads, None, head_dim)
    kv_observed_shape = (num_kv_heads, None, head_dim)
    observed_dtype = next(module.parameters()).dtype

    if impl is not None:
        initialize_qparams(
            module,
            "q",
            scheme.input_activations,
            observed_shape=q_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )

    if kv_cache is not None:
        initialize_qparams(
            module,
            "k",
            scheme.input_activations,
            observed_shape=kv_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )
        initialize_qparams(
            module,
            "v",
            scheme.input_activations,
            observed_shape=kv_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )


def _validate_attention_scheme(scheme: QuantizationScheme):
    if scheme.weights is not None:
        raise ValueError(
            "Cannot apply weight quantization to attention. "
            "Instead, target the (q|k|v)_proj submodule layers of attention"
        )

    if scheme.input_activations is None:
        raise ValueError(
            "Cannot apply attention quantization without specifying input activations"
        )

    if scheme.output_activations is not None:
        raise ValueError("Cannot apply output quantization to attention")