|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from enum import IntFlag, auto |
|
|
from typing import Optional |
|
|
|
|
|
from strenum import StrEnum |
|
|
|
|
|
from .._utils import BaseEnumMeta |
|
|
|
|
|
|
|
|
class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
    """Identifiers for the supported quantization algorithms.

    Member values are produced by ``auto()``. With ``strenum.StrEnum`` these
    are presumably the member names themselves (TODO confirm against the
    strenum version in use). ``QuantMode.from_quant_algo`` translates each
    algorithm into the corresponding ``QuantMode`` flag set.
    """

    # Int8 weight-only quantization.
    W8A16 = auto()

    # Int4 weight-only quantization.
    W4A16 = auto()

    # Int4 weight-only quantization with per-group scaling (AWQ).
    W4A16_AWQ = auto()

    # from_quant_algo maps this to the same int4 per-group weight-only flags
    # as W4A16_AWQ.
    W4A8_AWQ = auto()

    # Int4 weight-only quantization with per-group scaling (GPTQ).
    W4A16_GPTQ = auto()

    # Smooth quant (int8 weights + activations) with per-channel scaling.
    W8A8_SQ_PER_CHANNEL = auto()

    # Smooth quant with neither per-channel nor per-token scaling.
    W8A8_SQ_PER_TENSOR_PLUGIN = auto()

    # Smooth quant with per-channel and per-token scaling.
    W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = auto()

    # Smooth quant with per-channel scaling and no per-token scaling.
    W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = auto()

    # Smooth quant with per-token scaling and no per-channel scaling.
    W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = auto()

    # FP8 quantize/dequantize. Also accepted as a KV-cache algorithm.
    FP8 = auto()

    # FP8 row-wise (per-channel + per-token) quantization.
    FP8_PER_CHANNEL_PER_TOKEN = auto()

    # Int8 KV cache. Only valid as a KV-cache algorithm: it is excluded from
    # QUANT_ALGO_LIST and listed in KV_CACHE_QUANT_ALGO_LIST.
    INT8 = auto()
|
|
|
|
|
|
|
|
# Algorithms accepted as the main (weight/activation) quantization algorithm.
# INT8 is excluded because it is only meaningful as a KV-cache algorithm.
# A comprehension keeps the enum's declaration order. The previous
# ``list(set(...))`` produced an order that depended on per-process string
# hash randomization, so the list order differed between runs.
QUANT_ALGO_LIST = [algo for algo in QuantAlgo if algo != QuantAlgo.INT8]

# Algorithms accepted for KV-cache quantization.
KV_CACHE_QUANT_ALGO_LIST = [QuantAlgo.FP8, QuantAlgo.INT8]

# Smooth-quant variants whose names carry the ``_PLUGIN`` suffix.
W8A8_SQ_PLUGIN_LIST = [
    QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN,
]
|
|
|
|
|
|
|
|
class QuantMode(IntFlag):
    """Bit-flag description of a model's quantization scheme.

    Each flag records one aspect of the scheme: weight precision, activation
    quantization, scaling granularity, KV-cache precision, or FP8 modes.
    Flags are combined with ``|`` and queried through the ``is_*`` and
    ``has_*`` helpers.
    """

    # Weights quantized to 4-bit integers.
    INT4_WEIGHTS = auto()
    # Weights quantized to 8-bit integers.
    INT8_WEIGHTS = auto()
    # Activations quantized (together with INT8_WEIGHTS this is smooth quant,
    # see has_act_and_weight_quant / use_smooth_quant).
    ACTIVATIONS = auto()
    # Per-channel scaling.
    PER_CHANNEL = auto()
    # Per-token (dynamic) scaling.
    PER_TOKEN = auto()
    # Per-group scaling (used with int4 weight-only modes).
    PER_GROUP = auto()
    # Int8 KV cache.
    INT8_KV_CACHE = auto()
    # FP8 KV cache.
    FP8_KV_CACHE = auto()
    # FP8 quantize/dequantize.
    FP8_QDQ = auto()
    # FP8 row-wise. The set_fp8_rowwise / from_description paths always set
    # PER_TOKEN and PER_CHANNEL alongside it.
    FP8_ROWWISE = auto()

    # Sentinel bit above every real flag. It must stay the last auto() member
    # so that VALID_FLAGS (= COUNT - 1) has every real flag bit set.
    COUNT = auto()

    # Composite masks (multi-bit aliases, not independent flags).
    WEIGHTS_AND_ACTIVATIONS = INT4_WEIGHTS | INT8_WEIGHTS | ACTIVATIONS
    VALID_FLAGS = COUNT - 1

    def _all(self, bits, mask=VALID_FLAGS):
        """Return True if the bits of ``self`` selected by ``mask`` equal ``bits``.

        Restricting the comparison to ``mask`` lets callers ask "exactly these
        weight/activation flags" while ignoring unrelated flags such as the
        KV-cache ones.
        """
        return (self & mask) == bits

    def _any(self, bits):
        """Return True if ``self`` has at least one bit of ``bits`` set."""
        return (self & bits) != 0

    def is_int8_weight_only(self):
        """Return True if INT8_WEIGHTS is the only weight/activation flag set."""
        return self._all(self.INT8_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_int4_weight_only(self):
        """Return True if INT4_WEIGHTS is the only weight/activation flag set."""
        return self._all(self.INT4_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_weight_only(self):
        """Return True for int4 or int8 weight-only quantization."""
        return self.is_int4_weight_only() or self.is_int8_weight_only()

    def is_int4_weight_only_per_group(self):
        """Return True for int4 weight-only quantization with per-group scaling."""
        return self.is_int4_weight_only() and self._any(self.PER_GROUP)

    def has_act_and_weight_quant(self):
        """Return True if exactly INT8_WEIGHTS and ACTIVATIONS are set (smooth quant)."""
        return self._all(self.INT8_WEIGHTS | self.ACTIVATIONS,
                         self.WEIGHTS_AND_ACTIVATIONS)

    def has_act_or_weight_quant(self):
        """Return True if any weight or activation quantization flag is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS)

    def has_per_token_dynamic_scaling(self):
        """Return True if PER_TOKEN is set."""
        return self._any(self.PER_TOKEN)

    def has_act_static_scaling(self):
        """Return True if neither per-token dynamic scaling nor FP8 row-wise is set.

        Note that this is also True for a mode with no quantization at all.
        """
        return not self.has_per_token_dynamic_scaling(
        ) and not self.has_fp8_rowwise()

    def has_per_channel_scaling(self):
        """Return True if PER_CHANNEL is set."""
        return self._any(self.PER_CHANNEL)

    def has_per_group_scaling(self):
        """Return True if PER_GROUP is set."""
        return self._any(self.PER_GROUP)

    def has_int8_kv_cache(self):
        """Return True if INT8_KV_CACHE is set."""
        return self._any(self.INT8_KV_CACHE)

    def has_fp8_kv_cache(self):
        """Return True if FP8_KV_CACHE is set."""
        return self._any(self.FP8_KV_CACHE)

    def has_kv_cache_quant(self):
        """Return True if the KV cache is quantized (int8 or fp8)."""
        return self.has_int8_kv_cache() or self.has_fp8_kv_cache()

    def has_fp8_qdq(self):
        """Return True if FP8_QDQ is set."""
        return self._any(self.FP8_QDQ)

    def has_fp8_rowwise(self):
        """Return True if FP8_ROWWISE is set."""
        return self._any(self.FP8_ROWWISE)

    def has_any_quant(self):
        """Return True if any quantization flag (weights, activations, KV cache, FP8) is set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS
                         | self.INT8_KV_CACHE | self.FP8_KV_CACHE
                         | self.FP8_QDQ | self.FP8_ROWWISE)

    def set_int8_kv_cache(self):
        """Return a copy of this mode with INT8_KV_CACHE added."""
        return self | self.INT8_KV_CACHE

    def set_fp8_kv_cache(self):
        """Return a copy of this mode with FP8_KV_CACHE added."""
        return self | self.FP8_KV_CACHE

    def set_fp8_qdq(self):
        """Return a copy of this mode with FP8_QDQ added."""
        return self | self.FP8_QDQ

    def set_fp8_rowwise(self):
        """Return a copy of this mode with FP8_ROWWISE, PER_TOKEN and PER_CHANNEL added."""
        return self | self.FP8_ROWWISE | self.PER_TOKEN | self.PER_CHANNEL

    @staticmethod
    def from_description(quantize_weights=False,
                         quantize_activations=False,
                         per_token=False,
                         per_channel=False,
                         per_group=False,
                         use_int4_weights=False,
                         use_int8_kv_cache=False,
                         use_fp8_kv_cache=False,
                         use_fp8_qdq=False,
                         use_fp8_rowwise=False):
        """Build a QuantMode from boolean feature switches.

        Args:
            quantize_weights: Quantize weights (int8, or int4 when
                ``use_int4_weights`` is also set).
            quantize_activations: Quantize activations. Requires
                ``quantize_weights``.
            per_token: Add PER_TOKEN scaling. Requires weights and
                activations to be quantized.
            per_channel: Add PER_CHANNEL scaling. Requires weights and
                activations to be quantized.
            per_group: Add PER_GROUP scaling.
            use_int4_weights: Use INT4_WEIGHTS instead of INT8_WEIGHTS. Only
                has an effect when ``quantize_weights`` is set.
            use_int8_kv_cache: Add INT8_KV_CACHE.
            use_fp8_kv_cache: Add FP8_KV_CACHE.
            use_fp8_qdq: Add FP8_QDQ.
            use_fp8_rowwise: Add FP8_ROWWISE together with PER_TOKEN and
                PER_CHANNEL. This path bypasses the per_token/per_channel
                validation above.

        Returns:
            The combined QuantMode.

        Raises:
            ValueError: For unsupported combinations of the arguments.
        """

        def raise_error():
            raise ValueError(f"Unsupported combination of QuantMode args: "
                             f"{quantize_weights=}, "
                             f"{quantize_activations=}, "
                             f"{per_token=}, "
                             f"{per_channel=}, "
                             f"{per_group=}, "
                             f"{use_int4_weights=}, "
                             f"{use_int8_kv_cache=}, "
                             f"{use_fp8_kv_cache=}, "
                             f"{use_fp8_qdq=}, "
                             f"{use_fp8_rowwise=}")

        # Activation quantization is only supported together with weight
        # quantization.
        if quantize_activations and not quantize_weights:
            raise_error()

        # Per-token / per-channel scaling is only meaningful when both weights
        # and activations are quantized.
        if (per_token or per_channel) and not (quantize_weights
                                               and quantize_activations):
            raise_error()

        mode = QuantMode(0)

        # Weight precision: int4 if requested, otherwise int8.
        if quantize_weights and use_int4_weights:
            mode |= QuantMode.INT4_WEIGHTS
        elif quantize_weights:
            mode |= QuantMode.INT8_WEIGHTS

        if quantize_activations:
            mode |= QuantMode.ACTIVATIONS

        # Scaling granularity.
        if per_channel:
            mode |= QuantMode.PER_CHANNEL
        if per_token:
            mode |= QuantMode.PER_TOKEN
        if per_group:
            mode |= QuantMode.PER_GROUP

        # KV-cache precision.
        if use_int8_kv_cache:
            mode |= QuantMode.INT8_KV_CACHE
        if use_fp8_kv_cache:
            mode |= QuantMode.FP8_KV_CACHE

        # FP8 modes.
        if use_fp8_qdq:
            mode |= QuantMode.FP8_QDQ
        if use_fp8_rowwise:
            mode |= (QuantMode.FP8_ROWWISE | QuantMode.PER_TOKEN
                     | QuantMode.PER_CHANNEL)

        return mode

    @staticmethod
    def use_smooth_quant(per_token=False, per_channel=False):
        """Return a smooth-quant mode (int8 weights + activations) with optional scaling flags."""
        return QuantMode.from_description(True, True, per_token, per_channel)

    @staticmethod
    def use_weight_only(use_int4_weights=False, per_group=False):
        """Return a weight-only mode (int8, or int4 when requested), optionally per-group."""
        return QuantMode.from_description(quantize_weights=True,
                                          quantize_activations=False,
                                          per_token=False,
                                          per_channel=False,
                                          per_group=per_group,
                                          use_int4_weights=use_int4_weights)

    @staticmethod
    def from_quant_algo(
        quant_algo: Optional[QuantAlgo],
        kv_cache_quant_algo: Optional[QuantAlgo] = None,
    ) -> "QuantMode":
        """Translate QuantAlgo identifiers into a QuantMode.

        Args:
            quant_algo: Main quantization algorithm. Must be None or a member
                of QUANT_ALGO_LIST. None yields no weight/activation flags.
            kv_cache_quant_algo: KV-cache algorithm. Must be None or a member
                of KV_CACHE_QUANT_ALGO_LIST.

        Returns:
            The corresponding QuantMode.

        Raises:
            AssertionError: If either algorithm is not in its allowed list.
        """
        assert quant_algo is None or quant_algo in QUANT_ALGO_LIST, (
            f"Unsupported quant_algo: {quant_algo}")
        assert (kv_cache_quant_algo is None
                or kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST), (
                    f"Unsupported kv_cache_quant_algo: {kv_cache_quant_algo}")

        if quant_algo == QuantAlgo.W8A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
        elif quant_algo == QuantAlgo.W4A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True)
        elif quant_algo in (QuantAlgo.W4A16_AWQ, QuantAlgo.W4A8_AWQ,
                            QuantAlgo.W4A16_GPTQ):
            # These algorithms share the int4 per-group weight-only flags.
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo in (QuantAlgo.W8A8_SQ_PER_CHANNEL,
                            QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN):
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.FP8:
            quant_mode = QuantMode.from_description(use_fp8_qdq=True)
        elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN:
            quant_mode = QuantMode.from_description(use_fp8_rowwise=True)
        else:
            quant_mode = QuantMode(0)

        if kv_cache_quant_algo == QuantAlgo.INT8:
            quant_mode = quant_mode.set_int8_kv_cache()
        elif kv_cache_quant_algo == QuantAlgo.FP8:
            quant_mode = quant_mode.set_fp8_kv_cache()

        return quant_mode

    def to_dict(self):
        """Return the mode as a dict of boolean feature switches.

        ``weight_only_precision`` is 'int8' only for int8 weight-only modes.
        It reports 'int4' in every other case, including modes that are not
        weight-only at all.
        """
        return {
            'use_smooth_quant':
            self.has_act_and_weight_quant(),
            'per_channel':
            self.has_per_channel_scaling(),
            'per_token':
            self.has_per_token_dynamic_scaling(),
            'per_group':
            self.has_per_group_scaling(),
            'int8_kv_cache':
            self.has_int8_kv_cache(),
            'enable_fp8':
            self.has_fp8_qdq(),
            'enable_fp8_rowwise':
            self.has_fp8_rowwise(),
            'fp8_kv_cache':
            self.has_fp8_kv_cache(),
            'use_weight_only':
            self.is_weight_only(),
            'weight_only_precision':
            'int8' if self.is_int8_weight_only() else 'int4',
        }
|
|
|