File size: 4,212 Bytes
14e1dca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from swift.llm import HfConfigFactory
@dataclass
class QuantizeArguments:
    """
    QuantizeArguments is a dataclass that holds the configuration for model quantization.

    Args:
        quant_method (Literal['bnb', 'hqq', 'eetq', 'quanto']): The quantization method to be used.
            When it is None, 'awq' or 'gptq', `get_quantization_config` returns None.
        quant_bits (Literal[1, 2, 3, 4, 8, 'float8']): The number of bits to use for quantization.
            Must be set whenever a quantization config is actually built.
        hqq_axis (Optional[int]): The axis for hqq quantization.
        bnb_4bit_compute_dtype (Literal['float16', 'bfloat16', 'float32', None]):
            The compute dtype for bnb 4-bit quantization. When None, it is derived from
            `self.torch_dtype` in `__post_init__`.
        bnb_4bit_quant_type (Literal['fp4', 'nf4']): The quantization type for bnb 4-bit quantization.
        bnb_4bit_use_double_quant (bool): Whether to use double quantization for bnb 4-bit quantization.
        bnb_4bit_quant_storage (Optional[str]): This sets the storage type to pack the quantized 4-bit params.
    """
    # awq, gptq, and aqlm need to be pre-quantized models.
    # It can be detected automatically, without the need to pass in.
    # while bnb, hqq, eetq, and quanto can be quantized during SFT using the original models.
    quant_method: Literal['bnb', 'hqq', 'eetq', 'quanto'] = None
    # bnb: 4,8; hqq: 1,2,3,4,8; eetq: 8; quanto: 2,4,8,float8
    # awq: 4; gptq: 2,3,4,8
    quant_bits: Literal[1, 2, 3, 4, 8, 'float8'] = None
    # hqq
    hqq_axis: Optional[int] = None
    # bnb
    bnb_4bit_compute_dtype: Literal['float16', 'bfloat16', 'float32', None] = None
    bnb_4bit_quant_type: Literal['fp4', 'nf4'] = 'nf4'
    bnb_4bit_use_double_quant: bool = True
    bnb_4bit_quant_storage: Optional[str] = None

    def get_quantization_config(self):
        """Build the transformers quantization config described by these arguments.

        Returns:
            None when `quant_method` is None, 'awq' or 'gptq'; otherwise a
            `BitsAndBytesConfig`, `HqqConfig`, `QuantoConfig` or `EetqConfig` instance.

        Raises:
            ValueError: If `quant_method` is not a supported method, if `quant_bits` is not set,
                or if `quant_bits` is not supported by bnb or quanto.
        """
        if self.quant_method is None or self.quant_method in {'awq', 'gptq'}:
            return None
        # Raise explicitly instead of using `assert`: assertions are stripped under `python -O`,
        # which would let an unknown method silently fall through to the eetq branch below.
        if self.quant_method not in {'bnb', 'hqq', 'eetq', 'quanto'}:
            raise ValueError(f'Unsupported quant_method: {self.quant_method}')
        if self.quant_bits is None:
            raise ValueError(f'Please set the quant_bits. args.quant_bits: {self.quant_bits}')

        if self.quant_method == 'bnb':
            if self.quant_bits == 4:
                load_in_4bit, load_in_8bit = True, False
            elif self.quant_bits == 8:
                load_in_4bit, load_in_8bit = False, True
            else:
                raise ValueError(f'bnb not support quant_bits: {self.quant_bits}')

            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=load_in_4bit,
                load_in_8bit=load_in_8bit,
                bnb_4bit_compute_dtype=self.bnb_4bit_compute_dtype,
                bnb_4bit_quant_type=self.bnb_4bit_quant_type,
                bnb_4bit_use_double_quant=self.bnb_4bit_use_double_quant,
                bnb_4bit_quant_storage=self.bnb_4bit_quant_storage)
        elif self.quant_method == 'hqq':
            from transformers import HqqConfig
            quantization_config = HqqConfig(nbits=self.quant_bits, axis=self.hqq_axis)
        elif self.quant_method == 'quanto':
            # Validate before importing so that a bad value fails fast, as in the bnb branch.
            quanto_weights = {8: 'int8', 'float8': 'float8', 4: 'int4', 2: 'int2'}
            weights = quanto_weights.get(self.quant_bits)
            if weights is None:
                raise ValueError('quanto quantization only support quant bits 2/4/8/float8')

            from transformers import QuantoConfig
            quantization_config = QuantoConfig(weights=weights)
        else:  # 'eetq'
            from transformers import EetqConfig
            quantization_config = EetqConfig(f'int{self.quant_bits}')

        return quantization_config

    def __post_init__(self):
        """Fill in a default bnb compute dtype and normalize it via `HfConfigFactory.to_torch_dtype`."""
        # NOTE(review): `torch_dtype` is not declared in this dataclass; presumably it is supplied by
        # the arguments class this one is combined with -- TODO confirm.
        if self.bnb_4bit_compute_dtype is None:
            if self.torch_dtype in {torch.float16, torch.float32}:
                self.bnb_4bit_compute_dtype = torch.float32
            elif self.torch_dtype == torch.bfloat16:
                self.bnb_4bit_compute_dtype = torch.bfloat16

        self.bnb_4bit_compute_dtype: torch.dtype = HfConfigFactory.to_torch_dtype(self.bnb_4bit_compute_dtype)
|