import fnmatch

from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
                      LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import QuantConfig
from ..parameter import Parameter
from .layers import (FP8Linear, FP8RowLinear, Fp8RowwiseGatedMLP, Fp8RowwiseMLP,
                     Fp8RowwiseRmsNorm, Int8SmoothQuantLinear,
                     Int8SmoothQuantRowLinear, SmoothQuantAttention,
                     SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
                     SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
                     WeightOnlyGroupwiseQuantRowLinear,
                     WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
                     WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode

def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
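    """Replace each quantizable layer of `model` with its quantized counterpart.

    `quant_map` maps plain layer classes (e.g. ColumnLinear) to their quantized
    replacements; `preprocess_init_params`, if given, may adjust the constructor
    arguments before each quantized layer is built. Modules whose names match
    `quant_config.exclude_modules` (fnmatch-style patterns) are left unquantized.
    Returns the model with `quant_mode` set on it.
    """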
    exclude_modules = quant_config.exclude_modules or [
        '*lm_head',
        '*router',
        '*vocab_embedding',
        '*position_embedding',
        '*block_embedding',
        '*shared_expert_gate',
    ]

    for name, module, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]
        is_excluded = False
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                # MoE modules are quantized at initialization time, so an
                # excluded MoE layer must be rebuilt without quantization.
                if isinstance(module, MixtureOfExperts):
                    init_params = get_init_params(module, MixtureOfExperts)
                    init_params["quant_mode"] = QuantMode(0)
                    original_layer = MixtureOfExperts(**init_params)
                    if parent is not None:
                        setattr(parent, module_name, original_layer)
                    else:
                        model = original_layer
                break

        if not is_excluded:
            quant_cls = None
            for cls in quant_map:
                if isinstance(module, cls):
                    quant_cls = quant_map[cls]
                    break

            if quant_cls is None:
                continue

            init_params = get_init_params(module, quant_cls)
            if "bias" in init_params:
                init_params["bias"] = init_params["bias"] is not None
            # Tensor-parallel linear layers store per-rank shard sizes; restore
            # the full dimensions so the quantized layer re-shards correctly.
            if isinstance(module, ColumnLinear):
                init_params[
                    "out_features"] = module.out_features * module.tp_size
            elif isinstance(module, RowLinear):
                init_params["in_features"] = module.in_features * module.tp_size
            if preprocess_init_params is not None:
                preprocess_init_params(init_params, name, module)
            quant_layer = quant_cls(**init_params)
            if parent is not None:
                setattr(parent, module_name, quant_layer)
            else:
                model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


def weight_only_quantize(model, quant_config: QuantConfig):
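    """Apply weight-only quantization to the model's linear and embedding
    layers, keeping activations in high precision."""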
    assert quant_config.quant_mode.is_weight_only()

    quant_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            module_name = name.rsplit('.', 1)[-1]
            init_params["transb"] = module_name == "lm_head"

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def weight_only_groupwise_quantize(model, quant_config: QuantConfig):
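    """Apply groupwise weight-only quantization (AWQ/GPTQ-style, with per-group
    scales and optional zero points) to the model's linear layers."""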
    assert quant_config.quant_mode.is_weight_only()

    quant_map = {
        ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
        RowLinear: WeightOnlyGroupwiseQuantRowLinear,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["group_size"] = quant_config.group_size
        init_params["pre_quant_scale"] = quant_config.pre_quant_scale
        init_params["zero"] = quant_config.has_zero_point
        init_params[
            "use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
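    """Apply SmoothQuant without plugins, using out-of-the-box INT8 linear
    layers only."""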
    quant_map = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def smooth_quantize_plugin(model, quant_mode):
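    """Apply SmoothQuant via plugin-backed layers, replacing whole norm, MLP,
    and attention modules rather than individual linear layers."""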
    quant_map = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        # The top-level norms (ln_f, ln_embed) are left unquantized.
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        # Attention stores the per-rank head count; restore the full count.
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_config: QuantConfig):
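    """Dispatch SmoothQuant to the plugin path or the OOTB path, depending on
    whether the configured algorithm is in W8A8_SQ_PLUGIN_LIST."""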
    assert quant_config.quant_mode.has_act_and_weight_quant()
    if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
        return smooth_quantize_plugin(model, quant_config.quant_mode)
    else:
        return smooth_quantize_ootb(model, quant_config)


def fp8_quantize(model, quant_config: QuantConfig):
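    """Replace linear layers with FP8 quantize-dequantize variants."""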
    assert quant_config.quant_mode.has_fp8_qdq()

    quant_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def fp8_rowwise_quantize(model, quant_config: QuantConfig):
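    """Apply FP8 rowwise quantization to RMSNorm and MLP layers; the top-level
    norms and all input layernorms are skipped."""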
    assert quant_config.quant_mode.has_fp8_rowwise()

    quant_map = {
        RmsNorm: Fp8RowwiseRmsNorm,
        GatedMLP: Fp8RowwiseGatedMLP,
        MLP: Fp8RowwiseMLP,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        # Skip the top-level norms and the attention input layernorms; they
        # stay in high precision.
        if layer_name in ['ln_f', 'ln_embed'] or "input_layernorm" in name:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_config.quant_mode
        quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


def kv_cache_quantize(model, quant_config: QuantConfig):
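    """Attach a KV-cache scaling factor to every attention layer so the KV
    cache can be stored in a quantized format."""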
    assert quant_config.quant_mode.has_kv_cache_quant()
    for name, module in model.named_modules():
        if isinstance(module, (Attention, SmoothQuantAttention)):
            module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
    # The caller assigns the result, so the model must be returned.
    return model


def quantize(model, quant_config: QuantConfig):
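    """Quantize `model` according to `quant_config.quant_mode`: dispatch to the
    FP8, FP8-rowwise, SmoothQuant, or weight-only path, then optionally add
    KV-cache quantization. Returns the quantized model."""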
    quant_mode = quant_config.quant_mode

    if quant_mode.has_fp8_qdq():
        model = fp8_quantize(model, quant_config)
    elif quant_mode.has_fp8_rowwise():
        model = fp8_rowwise_quantize(model, quant_config)
    elif quant_mode.has_act_and_weight_quant():
        model = smooth_quantize(model, quant_config)
    elif quant_mode.is_weight_only():
        if quant_mode.has_per_group_scaling():
            model = weight_only_groupwise_quantize(model, quant_config)
        else:
            model = weight_only_quantize(model, quant_config)

    if quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model, quant_config)

    return model
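

# Example usage (a minimal sketch, not part of this module's API): how these
# helpers are typically driven from a QuantConfig. The `model` object and the
# exact QuantAlgo member are illustrative assumptions; any algorithm whose
# QuantMode reports per-group weight-only scaling takes the same path.
#
#   from tensorrt_llm.models.modeling_utils import QuantConfig
#   from tensorrt_llm.quantization.mode import QuantAlgo
#
#   quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ, group_size=128)
#   model = quantize(model, quant_config)  # dispatches to
#                                          # weight_only_groupwise_quantize()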