File size: 4,557 Bytes
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import torch
from verl.utils.kernel.fp8_kernel import scaled_fp8_blockwise
# Module-level logger; it is named after this file's path (``__file__``).
logger = logging.getLogger(__file__)
# Log level is read from the VERL_LOGGING_LEVEL environment variable and
# falls back to "INFO" when the variable is not set.
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
class FP8QuantizerHelper:
    """Pick weights by parameter name and quantize them to blockwise FP8.

    ``quant_config`` may be a dict or an arbitrary object; the only field
    read by this class is ``weight_block_size`` (dict key or attribute).
    """

    # Substrings that veto quantization. Matched against the lower-cased name.
    _EXCLUDE_PATTERNS = (
        "embed_tokens",  # Embedding layer
        "lm_head",  # Output layer
        "layernorm",  # LayerNorm
        "norm",  # Various Norm layers
        "ln_",  # LayerNorm variants
        "embeddings",  # Embeddings
        "mlp.gate.weight",  # MoE router
    )

    # Substrings that opt a weight in. Matched against the lower-cased name.
    _INCLUDE_PATTERNS = (
        "q_proj",  # Query projection
        "k_proj",  # Key projection
        "v_proj",  # Value projection
        "o_proj",  # Output projection
        "gate_proj",  # Gate projection (for MLP)
        "up_proj",  # Up projection (for MLP)
        "down_proj",  # Down projection (for MLP)
        "fc1",  # Fully connected 1
        "fc2",  # Fully connected 2
        "mlp",  # MLP layers
    )

    def __init__(self, quant_config):
        """Store the quantization config (dict or attribute-style object)."""
        self.quant_config = quant_config

    def should_quantize_param(self, param_name: str) -> bool:
        """Decide from the parameter name whether it should be quantized to FP8.

        Rules, applied in order:
        - The name must end with ``.weight`` (biases are never quantized).
        - Any exclude pattern (embeddings, norms, ``lm_head``, MoE router)
          rejects the parameter.
        - Any include pattern (attention/MLP projections) accepts it.
        - Everything else is left unquantized.

        Args:
            param_name: Fully qualified parameter name.

        Returns:
            True if the parameter should be quantized, False otherwise.
        """
        # Must be a weight parameter
        if not param_name.endswith(".weight"):
            return False

        # Pattern matching is case-insensitive.
        param_lower = param_name.lower()

        # Exclusions win over inclusions (e.g. "mlp.gate.weight" vs "mlp").
        if any(pattern in param_lower for pattern in self._EXCLUDE_PATTERNS):
            return False

        if any(pattern in param_lower for pattern in self._INCLUDE_PATTERNS):
            logger.debug("Will quantize FP8: %s", param_name)
            return True

        # Do not quantize by default
        logger.debug("Skip quantization: %s", param_name)
        return False

    @staticmethod
    def _is_rank_zero() -> bool:
        """Return True if this process should emit per-weight debug logs.

        ``torch.distributed.get_rank()`` raises when no default process group
        has been initialised, so it is only consulted when one exists; a
        process without a group is treated as rank 0.
        """
        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

    def quant_weights_by_name(self, weights, dtype=torch.bfloat16):
        """Quantize selected weights to blockwise FP8, streaming the results.

        This is a generator, so nothing (including the config validation)
        runs until the first item is requested.

        Args:
            weights: Generator or iterable of ``(name, tensor)`` pairs.
            dtype: Data type the tensor is cast to before quantization.

        Yields:
            ``(name, tensor)`` pairs. A weight that is not selected by
            ``should_quantize_param`` is yielded unchanged. A quantized weight
            is yielded as ``(name, fp8_tensor)`` followed by
            ``(name + "_scale_inv", scale)``. If quantization of a weight
            fails, the error is logged and the original pair is yielded.

        Raises:
            ValueError: If ``weight_block_size`` is missing from the config.
        """
        if isinstance(self.quant_config, dict):
            weight_block_size = self.quant_config.get("weight_block_size")
        else:
            weight_block_size = getattr(self.quant_config, "weight_block_size", None)

        if weight_block_size is None:
            raise ValueError("weight_block_size not found in quant_config")

        # The rank does not change while iterating, so resolve it once. Doing
        # it outside the try block also keeps a missing process group from
        # being mistaken for a quantization failure.
        log_progress = self._is_rank_zero()

        for k, v in weights:
            # Check if quantization is needed
            if not self.should_quantize_param(k):
                yield (k, v)
                continue

            if log_progress:
                logger.debug("Quantizing to FP8 blockwise: %s", k)

            # Only the quantization itself is guarded; the yields stay outside
            # so consumer-side exceptions are never swallowed here.
            try:
                param_lp, param_scale = scaled_fp8_blockwise(
                    v.to(dtype),
                    weight_block_size=weight_block_size,
                )
                param_scale = param_scale.squeeze(-1)
            except Exception as e:
                logger.error("Failed to quantize %s: %s", k, e)
                # Best-effort fallback: pass the original weight through.
                yield (k, v)
                continue

            # Yield the quantized weight and scale
            yield (k, param_lp)
            yield (k + "_scale_inv", param_scale)

            # Drop our references before the next weight is quantized.
            del param_lp, param_scale
|