| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import os |
|
|
| import torch |
|
|
| from verl.utils.kernel.fp8_kernel import scaled_fp8_blockwise |
|
|
# Module logger keyed by this file's path; its level is taken from the
# VERL_LOGGING_LEVEL environment variable, defaulting to INFO.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "INFO"))
|
|
|
|
class FP8QuantizerHelper:
    """Select model weights by name and quantize them to FP8 blockwise.

    Args:
        quant_config: Quantization config, either a ``dict`` or an object with
            attributes. Only ``weight_block_size`` is read, by
            ``quant_weights_by_name``, which passes it through to
            ``scaled_fp8_blockwise``.
    """

    # Substrings matched against the lower-cased parameter name. Any match
    # disqualifies the parameter, and this check runs before the include list.
    # Note that "norm" already covers "layernorm"; both are kept for clarity.
    # "mlp.gate.weight" presumably targets an MoE router gate — TODO confirm.
    _EXCLUDE_PATTERNS = (
        "embed_tokens",
        "lm_head",
        "layernorm",
        "norm",
        "ln_",
        "embeddings",
        "mlp.gate.weight",
    )

    # A parameter that survives the exclusions is quantized only if its
    # lower-cased name contains one of these substrings.
    _INCLUDE_PATTERNS = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "fc1",
        "fc2",
        "mlp",
    )

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def should_quantize_param(self, param_name):
        """Return whether the parameter named *param_name* should become FP8.

        Rules, applied in order:
          - The name must end with ``.weight`` (biases etc. are skipped).
          - Names containing any exclusion pattern (embeddings, norms,
            ``lm_head``, ``mlp.gate.weight``) are skipped.
          - Otherwise the name must contain one of the include patterns
            (attention/MLP projections, ``fc1``/``fc2``, ``mlp``).
        Pattern matching is substring-based on the lower-cased name.
        """
        if not param_name.endswith(".weight"):
            return False

        param_lower = param_name.lower()
        if any(pattern in param_lower for pattern in self._EXCLUDE_PATTERNS):
            return False

        if any(pattern in param_lower for pattern in self._INCLUDE_PATTERNS):
            logger.debug("Will quantize FP8: %s", param_name)
            return True

        logger.debug("Skip quantization: %s", param_name)
        return False

    def quant_weights_by_name(self, weights, dtype=torch.bfloat16):
        """Lazily FP8-quantize eligible weights from a ``(name, tensor)`` stream.

        Parameters rejected by ``should_quantize_param`` are yielded unchanged.
        Each eligible parameter is cast to *dtype*, quantized with
        ``scaled_fp8_blockwise``, and yielded as two entries:
        ``(name, fp8_tensor)`` followed by ``(name + "_scale_inv", scale)``.
        If quantization of a parameter fails, the error is logged and the
        original tensor is yielded instead (best-effort).

        Args:
            weights: Generator or iterable of ``(name, tensor)`` pairs.
            dtype: Dtype each tensor is cast to before quantization.

        Yields:
            ``(name, tensor)`` tuples as described above.

        Raises:
            ValueError: if ``weight_block_size`` is missing from the config.
                Because this is a generator, the error surfaces on the first
                iteration, not when the method is called.
        """
        if isinstance(self.quant_config, dict):
            weight_block_size = self.quant_config.get("weight_block_size")
        else:
            weight_block_size = getattr(self.quant_config, "weight_block_size", None)

        if weight_block_size is None:
            raise ValueError("weight_block_size not found in quant_config")

        # get_rank() raises when no default process group exists. Previously
        # that exception was swallowed by the per-weight fallback, so nothing
        # was quantized in non-distributed runs. A single process is treated
        # as rank 0. This is evaluated once, not per weight.
        is_rank0 = (
            not torch.distributed.is_available()
            or not torch.distributed.is_initialized()
            or torch.distributed.get_rank() == 0
        )

        for name, tensor in weights:
            if not self.should_quantize_param(name):
                yield (name, tensor)
                continue

            if is_rank0:
                logger.debug("Quantizing to FP8 blockwise: %s", name)

            # Only the quantization itself is guarded. The yields stay outside
            # the try so exceptions thrown into the generator are not swallowed
            # and a key is never emitted both quantized and unquantized.
            try:
                param_lp, param_scale = scaled_fp8_blockwise(
                    tensor.to(dtype),
                    weight_block_size=weight_block_size,
                )
                param_scale = param_scale.squeeze(-1)
            except Exception as e:
                logger.exception("Failed to quantize %s: %s", name, e)
                yield (name, tensor)
                continue

            yield (name, param_lp)
            yield (name + "_scale_inv", param_scale)

            # Drop this frame's references so the quantized tensors can be
            # freed as soon as the consumer releases them.
            del param_lp, param_scale
|
|