# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from tqdm import tqdm

from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, get_model_arch,
                       is_moe_model, load_dataset, prepare_model_template, save_checkpoint, to_device)
from swift.utils import find_layers, get_logger, get_model_parameter_info

logger = get_logger()


class QuantEngine(ProcessorMixin):
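    """Post-training quantization engine dispatching to AWQ, GPTQ, or BnB.

    Loads a model/template pair from ``ExportArguments``, builds a template-encoded
    calibration dataset where the backend needs one (awq/gptq), runs the
    backend-specific quantization, and saves the quantized checkpoint.
    """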

    def __init__(self, args: ExportArguments):
        self.args = args
        kwargs = {}
        if args.quant_method == 'awq':
            from awq import AutoAWQForCausalLM
            kwargs['automodel_class'] = AutoAWQForCausalLM
        self.model, self.template = prepare_model_template(args, **kwargs)
        self.template.set_mode('train')
        self.model.config.use_cache = False
        HfConfigFactory.set_model_config_attr(self.model, 'use_cache', False)
        self.processor = self.template.processor
        args.save_args()

    def quantize(self):
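        """Quantize ``self.model`` with ``args.quant_method`` and save the result.

        - 'awq': quantize via AutoAWQ and save with ``save_quantized``.
        - 'gptq': quantize via optimum's ``GPTQQuantizer`` and save through it.
        - 'bnb': the model was already loaded quantized; it is saved as-is.
        """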
        args = self.args
        if args.quant_bits is None:
            raise ValueError(f'Please set quant_bits (got args.quant_bits: {args.quant_bits}).')
        if args.quant_method == 'awq':
            self.template.model = self.model.model
            self.awq_model_quantize()
            self.model.save_quantized(
                args.output_dir, safetensors=args.safe_serialization, shard_size=args.max_shard_size)
        elif args.quant_method == 'gptq':
            self.template.model = self.model
            gptq_quantizer = self.gptq_model_quantize()
            gptq_quantizer.save(
                self.model,
                args.output_dir,
                safe_serialization=args.safe_serialization,
                max_shard_size=args.max_shard_size)
        elif args.quant_method == 'bnb':
            self.model.save_pretrained(
                args.output_dir, safe_serialization=args.safe_serialization, max_shard_size=args.max_shard_size)
        else:
            raise ValueError(f'Unsupported quant_method: {args.quant_method}. Expected one of: awq, gptq, bnb.')

        logger.info(f'model: {self.model}')
        logger.info(f'model_parameter_info: {get_model_parameter_info(self.model)}')
        save_checkpoint(
            None,
            self.processor,
            args.output_dir,
            model_dirs=[args.model_dir],
            additional_saved_files=self.model.model_meta.additional_saved_files)
        logger.info(f'Successfully quantized the model and saved in {args.output_dir}.')

    @torch.inference_mode()
    def _prepare_gptq_dataset(self, examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, *args, **kwargs):
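        """Collate pre-encoded calibration samples into batches for optimum's GPTQ.

        Each batch is collated with the template's data collator on the model
        device (running the multimodal pre-forward hook when needed), then moved
        back to CPU.
        """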
        res = []
        for start in tqdm(range(0, len(examples), batch_size)):
            batched_inputs = examples[start:start + batch_size]
            inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
            if self.model.model_meta.is_multimodal:
                _, inputs = self.template.pre_forward_hook(self.model, None, inputs)
            res.append(to_device(inputs, 'cpu'))
        return res

    @torch.inference_mode()
    def _get_quant_dataset(self, *args, **kwargs):
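        """Build the calibration dataset for awq/gptq quantization.

        Encodes up to ``quant_n_samples`` examples with the template, skipping
        over-length ones. For multimodal gptq, the encoded features are kept
        per-sample (labels removed); otherwise all token ids are concatenated
        and split into blocks of ``max_length``: dicts of ids for gptq, 2-D
        tensors for awq.
        """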
        args = self.args
        assert args.quant_method in {'awq', 'gptq'}
        template = self.template
        n_samples = args.quant_n_samples
        block_size = args.max_length

        # split_dataset_ratio=0: keep everything in the train split; only the train split is used
        dataset = load_dataset(
            args.dataset, split_dataset_ratio=0, shuffle=args.dataset_shuffle, **args.get_dataset_kwargs())[0]
        logger.info(f'quant_dataset: {dataset}')
        dataset = dataset.shuffle()

        samples = []
        i = 0
        prog_bar = tqdm(total=n_samples, dynamic_ncols=True)
        is_multimodal = self.model.model_meta.is_multimodal
        for data in dataset:
            try:
                inputs = template.encode(data)
            except MaxLengthError:
                continue
            if is_multimodal and args.quant_method == 'gptq':
                inputs.pop('labels', None)
                samples.append(inputs)
            else:
                input_ids = inputs['input_ids']
                samples += input_ids
            i += 1
            prog_bar.update()
            if i == n_samples:
                break
        if is_multimodal and args.quant_method == 'gptq':
            return samples
        # samples already holds the concatenated token ids; split them into blocks of block_size
        n_split = len(samples) // block_size
        logger.info(f'Split into {n_split} blocks')
        res = []
        for i in range(n_split):
            input_ids = samples[i * block_size:(i + 1) * block_size]
            if args.quant_method == 'gptq':
                res.append({'input_ids': input_ids})
            else:
                res.append(torch.tensor(input_ids)[None])
        return res

    @staticmethod
    @contextmanager
    def _patch_awq_move_embed(awq_model):
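        """Make AutoAWQ's ``move_embed`` a no-op for accelerate-dispatched models.

        If the model carries an ``_hf_hook`` and the target device is not CPU,
        the embeddings are left where accelerate placed them.
        """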
        _origin_move_embed = awq_model.move_embed

        def _move_embed(model, device: str):
            if hasattr(model, '_hf_hook') and device != 'cpu':
                return
            _origin_move_embed(model, device)

        awq_model.move_embed = _move_embed
        try:
            yield
        finally:
            awq_model.move_embed = _origin_move_embed

    def get_awq_modules_to_not_convert(self):
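        """Return names of MoE router/gate linears (``out_features == num_experts``)
        that AWQ should leave unquantized."""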
        block_name = self.get_block_name_to_quantize(self.model)
        block = deep_getattr(self.model, block_name)[-1]
        prefix, experts = self._get_experts(block)
        num_experts = len(experts)

        def cond(name, module):
            return isinstance(module, nn.Linear) and module.out_features == num_experts

        return find_layers(self.model, cond, min_name_len=2)  # min_name_len: fix Qwen3-MoE

    def awq_model_quantize(self) -> None:
        from awq.quantize import quantizer
        from transformers import AwqConfig

        args = self.args
        logger.info(f'Quantization dataset: {args.dataset}')
        _origin_get_calib_dataset = quantizer.get_calib_dataset
        quantizer.get_calib_dataset = self._get_quant_dataset
        quant_config = {
            'zero_point': True,
            'q_group_size': args.group_size,
            'w_bit': args.quant_bits,
            'version': 'GEMM'
        }
        if is_moe_model(self.model):
            quant_config['modules_to_not_convert'] = self.get_awq_modules_to_not_convert()
        logger.info(f'quant_config: {quant_config}')
        logger.info('Start quantizing the model...')
        try:
            with self._patch_awq_move_embed(self.model):
                self.model.quantize(
                    self.tokenizer, quant_config=quant_config, n_parallel_calib_samples=args.quant_batch_size)
        finally:
            quantizer.get_calib_dataset = _origin_get_calib_dataset  # restore, even if quantization fails
        if self.model.quant_config.modules_to_not_convert:
            model_arch = get_model_arch(args.model_meta.model_arch)
            lm_head_key = model_arch.lm_head or 'lm_head'
            self.model.quant_config.modules_to_not_convert.append(lm_head_key)

    @contextmanager
    def _patch_gptq(self):
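        """Temporarily replace optimum's calibration helpers (``get_dataset`` and
        ``prepare_dataset``) with the template-aware implementations above."""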
        from optimum.gptq import quantizer
        _get_dataset_origin = quantizer.get_dataset
        _prepare_dataset_origin = quantizer.prepare_dataset
        quantizer.get_dataset = self._get_quant_dataset
        quantizer.prepare_dataset = self._prepare_gptq_dataset
        try:
            yield
        finally:
            quantizer.get_dataset = _get_dataset_origin
            quantizer.prepare_dataset = _prepare_dataset_origin

    @staticmethod
    def get_block_name_to_quantize(model: nn.Module) -> Optional[str]:
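        """Heuristically locate the transformer block list to quantize.

        Descends into the language model of a multimodal architecture when
        present, then returns the qualified name of the largest
        ``ModuleList``/``Sequential`` with at least 10 children that are not MLP
        modules (to avoid matching MoE expert lists), e.g. ``'model.layers'``.
        Returns ``None`` if no candidate is found.
        """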
        model_arch = get_model_arch(model.model_meta.model_arch)
        prefix = ''
        if hasattr(model_arch, 'language_model'):
            assert len(model_arch.language_model) == 1, f'model_arch.language_model: {model_arch.language_model}'
            prefix = model_arch.language_model[0]
            model = deep_getattr(model, prefix)

        module_lists = []
        for n, m in model.named_modules():
            if (isinstance(m, (nn.ModuleList, nn.Sequential)) and len(m) >= 10
                    and 'mlp' not in m[0].__class__.__name__.lower()):  # skip MoE expert lists (their children are MLP modules)
                module_lists.append((n, m))
        if module_lists:
            module_list = max(module_lists, key=lambda x: len(x[1]))
            return f'{prefix}.{module_list[0]}'.strip('.')

    @staticmethod
    def _get_experts(block):
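        """Return ``(name, module)`` of the first ``ModuleList``/``Sequential`` in
        ``block``, i.e. the experts container of an MoE block; ``None`` if absent."""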
        for n, m in block.named_modules():
            if isinstance(m, (nn.ModuleList, nn.Sequential)):
                return n, m

    @staticmethod
    def get_modules_in_block_to_quantize(model, block_name: str):
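        """Group the linear layers of one MoE block for GPTQ; ``None`` for dense models.

        Expert projections sharing a suffix (e.g. ``down_proj``) are grouped so the
        same projection is quantized across all experts at once; the router/gate
        (``out_features == num_experts``) is excluded from quantization.
        """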
        if not is_moe_model(model):
            return
        from optimum.gptq.utils import get_layers
        # Do not quantize the gate part.
        block = deep_getattr(model, block_name)[-1]
        prefix, experts = QuantEngine._get_experts(block)
        num_experts = len(experts)

        layers = get_layers(block)
        res = []
        expert_groups = defaultdict(list)
        experts_idx = None
        for name, layer in layers.items():
            if name.startswith(prefix):
                suffix = name.rsplit('.', 1)[-1]
                expert_groups[suffix].append(name)
                experts_idx = len(res)
            elif layer.out_features not in {1, num_experts}:
                res.append([name])
        res[experts_idx:experts_idx] = expert_groups.values()
        return res

    def gptq_model_quantize(self):
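        """Quantize ``self.model`` with optimum's GPTQ and return the quantizer.

        ``_patch_gptq`` redirects optimum's calibration helpers to this engine's
        template-encoded dataset, so the dataset name passed to ``GPTQQuantizer``
        is not actually loaded by name.
        """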
        from optimum.gptq import GPTQQuantizer
        args = self.args
        logger.info(f'Quantization dataset: {args.dataset}')
        block_name_to_quantize = self.get_block_name_to_quantize(self.model)
        modules_in_block_to_quantize = self.get_modules_in_block_to_quantize(self.model, block_name_to_quantize)
        logger.info(f'block_name_to_quantize: {block_name_to_quantize}')
        logger.info(f'modules_in_block_to_quantize: {modules_in_block_to_quantize}')
        with self._patch_gptq():
            gptq_quantizer = GPTQQuantizer(
                bits=args.quant_bits,
                group_size=args.group_size,
                dataset=','.join(args.dataset),
                batch_size=args.quant_batch_size,
                block_name_to_quantize=block_name_to_quantize,
                modules_in_block_to_quantize=modules_in_block_to_quantize)
            gptq_quantizer.serialization_keys.append('block_name_to_quantize')
            logger.info('Start quantizing the model...')
            logger.warning('The process of packing the model takes a long time and there is no progress bar. '
                           'Please be patient and wait...')
            gptq_quantizer.quantize_model(self.model, self.tokenizer)
            self.model.config.quantization_config.pop('dataset', None)
        return gptq_quantizer


def quantize_model(args: ExportArguments):
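    """Quantize the model described by ``args`` and save it to ``args.output_dir``."""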
    QuantEngine(args).quantize()
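

# A minimal usage sketch (the model id and dataset below are placeholders, not
# recommendations; argument names follow ExportArguments):
#   args = ExportArguments(
#       model='Qwen/Qwen2.5-7B-Instruct',
#       quant_method='gptq',
#       quant_bits=4,
#       dataset=['AI-ModelScope/alpaca-gpt4-data-zh'],
#       output_dir='./qwen2_5-7b-gptq-int4')
#   quantize_model(args)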