from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
from .quantizers_utils import get_module_from_name


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class FbgemmFp8HfQuantizer(HfQuantizer):
    """
    FP8 quantization using fbgemm kernels.
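
    Example (minimal sketch; assumes the public `FbgemmFp8Config` config class, an example
    checkpoint name, and a CUDA device with compute capability >= 9.0):

    ```python
    from transformers import AutoModelForCausalLM, FbgemmFp8Config

    quantization_config = FbgemmFp8Config()
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",
        quantization_config=quantization_config,
        device_map="cuda",
    )
    ```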
    """

    requires_parameters_quantization = True
    requires_calibration = False

    required_packages = ["fbgemm-gpu", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_torch_available():
            raise ImportError(
                "Using fbgemm fp8 quantization requires torch >= 2.1.0. "
                "Please install the latest version of torch (`pip install --upgrade torch`)."
            )
        if not is_fbgemm_gpu_available():
            raise ImportError(
                "Using fbgemm fp8 quantization requires the fbgemm-gpu library. "
                "Please install the latest version of fbgemm-gpu by following: "
                "https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
            )

        if not is_accelerate_available("0.32.2"):
            raise ImportError(
                "Loading an FP8 quantized model requires accelerate > 0.32.1 (`pip install --upgrade accelerate`)."
            )

        if not torch.cuda.is_available():
            raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")

        compute_capability = torch.cuda.get_device_capability()
        major, minor = compute_capability
        if major < 9:
            raise ValueError(
                "FP8 quantized models are only supported on GPUs with compute capability >= 9.0 (e.g. H100)"
            )

        device_map = kwargs.get("device_map")
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP8 model on CPU and have a CUDA device available; make sure to move "
                "the model to a GPU device in order to run it. To remove this warning, pass device_map='cuda'."
            )
        elif (
            not self.pre_quantized
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device. "
                "This is not supported when the model is quantized on the fly. "
                "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
            )

    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
        if dtype is None:
            dtype = torch.bfloat16
            logger.info(
                "Overriding dtype=%s with `dtype=torch.bfloat16` due to "
                "requirements of `fbgemm-gpu` to enable model loading in fp8. "
                "Pass your own dtype to specify the dtype of the remaining non-linear layers, or pass "
                "dtype=torch.bfloat16 to remove this warning.",
                dtype,
            )
        elif dtype == torch.float16:
            raise ValueError(
                "You cannot use FP8 with dtype=torch.float16. We recommend passing dtype=torch.bfloat16 instead."
            )
        return dtype

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
            # Pre-quantized checkpoints and bias tensors are loaded as-is; everything else
            # targeting these modules is quantized on the fly.
            return not (self.pre_quantized or tensor_name == "bias")
        return False

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        **kwargs,
    ):
        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, FbgemmFp8Linear):
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
                    raise ValueError("Expect quantized weights but got an unquantized weight")
            else:
                if tensor_name == "weight_scale":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
        if isinstance(module, FbgemmFp8Llama4TextExperts):
            if not (self.pre_quantized or tensor_name == "bias"):
                if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")

        if isinstance(module, FbgemmFp8Llama4TextExperts):
            if tensor_name == "gate_up_proj":
                # Swap the last two dims, flatten to 2D, quantize each row with fbgemm's
                # per-row FP8 kernel, then restore the original 3D layout.
                transposed_param = param_value.transpose(1, 2)
                original_shape = transposed_param.shape
                flattened_param = transposed_param.reshape(-1, original_shape[-1])
                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
                new_value = new_value_flat.reshape(original_shape)
                new_value = new_value.transpose(1, 2)
                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
            elif tensor_name == "down_proj":
                # Same per-row quantization scheme; only the shape of the scales differs so
                # they broadcast over the last dimension of the down projection.
                transposed_param = param_value.transpose(1, 2)
                original_shape = transposed_param.shape
                flattened_param = transposed_param.reshape(-1, original_shape[-1])
                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
                new_value = new_value_flat.reshape(original_shape)
                new_value = new_value.transpose(1, 2)
                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)

            module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
        else:
            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
            module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
                weight_scale.view(weight_scale.shape[0], 1).to(target_device)
            )

        module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))

        del param_name

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        keep_in_fp32_modules: Optional[list[str]] = None,
        **kwargs,
    ):
        from ..integrations import replace_with_fbgemm_fp8_linear

        tp_plan = model._tp_plan
        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
        )

        config = model.config
        model = replace_with_fbgemm_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
            config=config,
            tp_plan=tp_plan,
        )

        model.config.quantization_config = self.quantization_config

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

        # Scale parameters are created by the quantizer at load time, so they should not be
        # reported as missing from the checkpoint.
        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, (FbgemmFp8Linear, FbgemmFp8Llama4TextExperts)):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def update_tp_plan(self, config):
        if "Llama4" in config.__class__.__name__:
            text_plan = {
                # Attention projections and their FP8 scales
                "layers.*.self_attn.q_proj.weight": "local_colwise",
                "layers.*.self_attn.q_proj.weight_scale": "local_colwise",
                "layers.*.self_attn.k_proj.weight": "local_colwise",
                "layers.*.self_attn.k_proj.weight_scale": "local_colwise",
                "layers.*.self_attn.v_proj.weight": "local_colwise",
                "layers.*.self_attn.v_proj.weight_scale": "local_colwise",
                "layers.*.self_attn.o_proj.weight": "local_rowwise",
                "layers.*.self_attn": "gather",
                # Layer norms
                "layers.*.input_layernorm.weight": "sequence_parallel",
                "layers.*.post_attention_layernorm.weight": "sequence_parallel",
                "norm.weight": "sequence_parallel",
                # Shared expert MLP
                "layers.*.feed_forward.shared_expert.gate_proj.weight": "local_colwise",
                "layers.*.feed_forward.shared_expert.gate_proj.weight_scale": "local_colwise",
                "layers.*.feed_forward.shared_expert.up_proj.weight": "local_colwise",
                "layers.*.feed_forward.shared_expert.up_proj.weight_scale": "local_colwise",
                "layers.*.feed_forward.shared_expert.down_proj.weight": "local_rowwise",
                "layers.*.feed_forward.experts": "local",
                "layers.*.feed_forward": "gather",
                # Routed experts, per-expert (unpacked) layout
                "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
                "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
                "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
                "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
                "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
                # Routed experts, packed gate_up layout
                "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",
                "layers.*.feed_forward.experts.gate_up_proj_scale": "local_packed_rowwise",
                "layers.*.feed_forward.experts.down_proj": "local_colwise",
            }
            if config.get_text_config() is not None:
                config.get_text_config().base_model_tp_plan = text_plan
            else:
                config.base_model_tp_plan = text_plan

        return config

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        return False