File size: 11,905 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntFlag, auto
from typing import Optional
from strenum import StrEnum
from .._utils import BaseEnumMeta
class QuantAlgo(StrEnum, metaclass=BaseEnumMeta):
    """Names of the supported quantization algorithms.

    Member names follow the convention ``W<weight-bits>A<activation-bits>``
    plus an optional scheme suffix (AWQ, GPTQ, SQ = SmoothQuant, ``_PLUGIN``
    variants). Values are produced by ``auto()`` under ``strenum.StrEnum`` —
    presumably the member name itself; confirm against the strenum version
    pinned by the project before relying on the exact string.
    """
    # Weight-only schemes (activations stay in higher precision).
    W8A16 = auto()
    W4A16 = auto()
    # Group-wise INT4 weight-only schemes (mapped with per_group=True below).
    W4A16_AWQ = auto()
    W4A8_AWQ = auto()
    W4A16_GPTQ = auto()
    # SmoothQuant (W8A8) variants; the suffix names the scaling granularity.
    W8A8_SQ_PER_CHANNEL = auto()
    W8A8_SQ_PER_TENSOR_PLUGIN = auto()
    W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = auto()
    W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = auto()
    W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = auto()
    # FP8 schemes.
    FP8 = auto()
    FP8_PER_CHANNEL_PER_TOKEN = auto()
    # INT8 here is only accepted for KV-cache quantization (it is excluded
    # from QUANT_ALGO_LIST and listed in KV_CACHE_QUANT_ALGO_LIST).
    INT8 = auto()
# Algorithms accepted for weight/activation quantization. INT8 alone is only
# valid for the KV cache, hence its exclusion. A comprehension (rather than
# list(set(...) - {...})) keeps the list in deterministic declaration order:
# Enum members hash by their (hash-randomized) name, so set ordering would
# vary from process to process.
QUANT_ALGO_LIST = [algo for algo in QuantAlgo if algo is not QuantAlgo.INT8]

# Algorithms accepted for KV-cache quantization.
KV_CACHE_QUANT_ALGO_LIST = [QuantAlgo.FP8, QuantAlgo.INT8]

# SmoothQuant (W8A8) variants that are implemented via plugins.
W8A8_SQ_PLUGIN_LIST = [
    QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
    QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN,
]
class QuantMode(IntFlag):
    """Bitmask of quantization features enabled for a model.

    Each member is a single bit; `WEIGHTS_AND_ACTIVATIONS` and `VALID_FLAGS`
    are composite masks built from them. Instances are usually constructed via
    :meth:`from_description` or :meth:`from_quant_algo`.
    """

    # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/common/quantization.h

    # The weights are quantized to 4 bits.
    INT4_WEIGHTS = auto()
    # The weights are quantized to 8 bits.
    INT8_WEIGHTS = auto()
    # The activations are quantized.
    ACTIVATIONS = auto()
    # The method uses one scaling factor per channel. It's pre-computed
    # (static) from the weights.
    PER_CHANNEL = auto()
    # The method uses one scaling factor per token. It's computed on-the-fly.
    PER_TOKEN = auto()
    # The method uses one scaling factor per group. It's pre-computed (static)
    # from the weights.
    PER_GROUP = auto()
    # The KV cache is quantized in INT8.
    INT8_KV_CACHE = auto()
    # The KV cache is quantized in FP8.
    FP8_KV_CACHE = auto()
    # FP8 quantize-dequantize (static activation scaling).
    FP8_QDQ = auto()
    # FP8 rowwise; always set together with PER_TOKEN and PER_CHANNEL
    # (see set_fp8_rowwise / from_description).
    FP8_ROWWISE = auto()

    # The smallest power-of-two that is not used by a flag. Do not call auto()
    # after this line.
    COUNT = auto()
    # Bitmask to detect if weights, activations or both are quantized.
    WEIGHTS_AND_ACTIVATIONS = INT4_WEIGHTS | INT8_WEIGHTS | ACTIVATIONS
    # The mask of all valid flags.
    VALID_FLAGS = COUNT - 1

    def _all(self, bits, mask=VALID_FLAGS):
        """Return True iff, restricted to `mask`, exactly `bits` are set."""
        return (self & mask) == bits

    def _any(self, bits):
        """Return True iff at least one bit of `bits` is set."""
        return (self & bits) != 0

    def is_int8_weight_only(self):
        """INT8 weights with no activation quantization."""
        return self._all(self.INT8_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_int4_weight_only(self):
        """INT4 weights with no activation quantization."""
        return self._all(self.INT4_WEIGHTS, self.WEIGHTS_AND_ACTIVATIONS)

    def is_weight_only(self):
        """Any weight-only scheme (INT4 or INT8)."""
        return self.is_int4_weight_only() or self.is_int8_weight_only()

    def is_int4_weight_only_per_group(self):
        """INT4 weight-only with group-wise scaling (AWQ/GPTQ style)."""
        return self.is_int4_weight_only() and self._any(self.PER_GROUP)

    def has_act_and_weight_quant(self):
        """INT8 weights AND activations, i.e. SmoothQuant-style W8A8."""
        return self._all(self.INT8_WEIGHTS | self.ACTIVATIONS,
                         self.WEIGHTS_AND_ACTIVATIONS)

    def has_act_or_weight_quant(self):
        """Any weight or activation quantization bit set."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS)

    def has_per_token_dynamic_scaling(self):
        return self._any(self.PER_TOKEN)

    def has_act_static_scaling(self):
        """Static activation scaling is the complement of on-the-fly scaling
        (per-token or FP8 rowwise)."""
        return not (self.has_per_token_dynamic_scaling()
                    or self.has_fp8_rowwise())

    def has_per_channel_scaling(self):
        return self._any(self.PER_CHANNEL)

    def has_per_group_scaling(self):
        return self._any(self.PER_GROUP)

    def has_int8_kv_cache(self):
        return self._any(self.INT8_KV_CACHE)

    def has_fp8_kv_cache(self):
        return self._any(self.FP8_KV_CACHE)

    def has_kv_cache_quant(self):
        """KV cache quantized in either INT8 or FP8."""
        return self.has_int8_kv_cache() or self.has_fp8_kv_cache()

    def has_fp8_qdq(self):
        return self._any(self.FP8_QDQ)

    def has_fp8_rowwise(self):
        return self._any(self.FP8_ROWWISE)

    def has_any_quant(self):
        """True if any quantization feature at all is enabled."""
        return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS
                         | self.ACTIVATIONS
                         | self.INT8_KV_CACHE | self.FP8_KV_CACHE
                         | self.FP8_QDQ | self.FP8_ROWWISE)

    # The set_* helpers return a NEW QuantMode (IntFlag is immutable); callers
    # must rebind: mode = mode.set_int8_kv_cache().
    def set_int8_kv_cache(self):
        return self | self.INT8_KV_CACHE

    def set_fp8_kv_cache(self):
        return self | self.FP8_KV_CACHE

    def set_fp8_qdq(self):
        return self | self.FP8_QDQ

    def set_fp8_rowwise(self):
        # Rowwise implies dynamic per-token and per-channel scaling.
        return self | self.FP8_ROWWISE | self.PER_TOKEN | self.PER_CHANNEL

    @staticmethod
    def from_description(quantize_weights=False,
                         quantize_activations=False,
                         per_token=False,
                         per_channel=False,
                         use_int4_weights=False,
                         per_group=False,
                         use_int8_kv_cache=False,
                         use_fp8_kv_cache=False,
                         use_fp8_qdq=False,
                         use_fp8_rowwise=False):
        """Build a QuantMode from individual feature switches.

        Raises:
            ValueError: if the combination is inconsistent (activations
                quantized without weights, or per-token/per-channel scaling
                without both weights and activations quantized).
        """

        def raise_error():
            raise ValueError(f"Unsupported combination of QuantMode args: "
                             f"{quantize_weights=}, "
                             f"{quantize_activations=}, "
                             f"{per_token=}, "
                             f"{per_channel=}, "
                             f"{per_group=}, "
                             f"{use_int4_weights=}, "
                             f"{use_int8_kv_cache=}, "
                             f"{use_fp8_kv_cache=}, "
                             f"{use_fp8_qdq=}, "
                             f"{use_fp8_rowwise=}")

        # We must quantize weights when we quantize activations.
        if quantize_activations and not quantize_weights:
            raise_error()
        # If we set per_token or per_channel, we must quantize both weights
        # and activations.
        if (per_token or per_channel) and not (quantize_weights
                                               and quantize_activations):
            raise_error()

        mode = QuantMode(0)
        # Do we quantize the weights - if so, do we use INT4 or INT8?
        if quantize_weights and use_int4_weights:
            mode = mode | QuantMode.INT4_WEIGHTS
        elif quantize_weights:
            mode = mode | QuantMode.INT8_WEIGHTS
        # Do we quantize the activations?
        if quantize_activations:
            mode = mode | QuantMode.ACTIVATIONS
        # Per-channel/per-token/per-group additional flags.
        if per_channel:
            mode = mode | QuantMode.PER_CHANNEL
        if per_token:
            mode = mode | QuantMode.PER_TOKEN
        if per_group:
            mode = mode | QuantMode.PER_GROUP
        # Int8 KV cache
        if use_int8_kv_cache:
            mode = mode | QuantMode.INT8_KV_CACHE
        # FP8 KV cache
        if use_fp8_kv_cache:
            mode = mode | QuantMode.FP8_KV_CACHE
        if use_fp8_qdq:
            mode = mode | QuantMode.FP8_QDQ
        # FP8 rowwise implies dynamic per-token + per-channel scaling.
        if use_fp8_rowwise:
            mode = mode | QuantMode.FP8_ROWWISE | QuantMode.PER_TOKEN | QuantMode.PER_CHANNEL
        return mode

    @staticmethod
    def use_smooth_quant(per_token=False, per_channel=False):
        """Shorthand for SmoothQuant (W8A8) with the given scaling flags."""
        return QuantMode.from_description(True, True, per_token, per_channel)

    @staticmethod
    def use_weight_only(use_int4_weights=False, per_group=False):
        """Shorthand for weight-only quantization (INT8, or INT4 if asked)."""
        return QuantMode.from_description(quantize_weights=True,
                                          quantize_activations=False,
                                          per_token=False,
                                          per_channel=False,
                                          per_group=per_group,
                                          use_int4_weights=use_int4_weights)

    @staticmethod
    def from_quant_algo(
        quant_algo: Optional["QuantAlgo"],
        kv_cache_quant_algo: Optional["QuantAlgo"] = None,
    ) -> "QuantMode":
        """Map a (quant_algo, kv_cache_quant_algo) pair to a QuantMode.

        Args:
            quant_algo: weight/activation algorithm, or None for none.
            kv_cache_quant_algo: KV-cache algorithm (FP8 or INT8), or None.

        Returns:
            The equivalent QuantMode bitmask (QuantMode(0) for unknown/None).
        """
        # Annotations are quoted so the signature is evaluated lazily, like
        # the "QuantMode" return annotation.
        assert quant_algo is None or quant_algo in QUANT_ALGO_LIST
        assert kv_cache_quant_algo is None or kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST

        if quant_algo == QuantAlgo.W8A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=False)
        elif quant_algo == QuantAlgo.W4A16:
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True)
        elif quant_algo in (QuantAlgo.W4A16_AWQ, QuantAlgo.W4A8_AWQ,
                            QuantAlgo.W4A16_GPTQ):
            # All group-wise INT4 weight-only schemes map to the same mode.
            quant_mode = QuantMode.use_weight_only(use_int4_weights=True,
                                                   per_group=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=False,
                                                    per_channel=True)
        elif quant_algo == QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN:
            quant_mode = QuantMode.use_smooth_quant(per_token=True,
                                                    per_channel=False)
        elif quant_algo == QuantAlgo.FP8:
            quant_mode = QuantMode.from_description(use_fp8_qdq=True)
        elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN:
            quant_mode = QuantMode.from_description(use_fp8_rowwise=True)
        else:
            quant_mode = QuantMode(0)

        if kv_cache_quant_algo == QuantAlgo.INT8:
            quant_mode = quant_mode.set_int8_kv_cache()
        elif kv_cache_quant_algo == QuantAlgo.FP8:
            quant_mode = quant_mode.set_fp8_kv_cache()

        return quant_mode

    def to_dict(self):
        """Serialize the mode to the legacy flag-dict representation."""
        return {
            'use_smooth_quant': self.has_act_and_weight_quant(),
            'per_channel': self.has_per_channel_scaling(),
            'per_token': self.has_per_token_dynamic_scaling(),
            'per_group': self.has_per_group_scaling(),
            'int8_kv_cache': self.has_int8_kv_cache(),
            'enable_fp8': self.has_fp8_qdq(),
            'enable_fp8_rowwise': self.has_fp8_rowwise(),
            'fp8_kv_cache': self.has_fp8_kv_cache(),
            'use_weight_only': self.is_weight_only(),
            # NOTE: defaults to 'int4' even when no weight-only quantization
            # is enabled; only meaningful when 'use_weight_only' is True.
            'weight_only_precision':
            'int8' if self.is_int8_weight_only() else 'int4',
        }
|