# Origin (web-page header captured during extraction; not part of the module):
#   path:   arithmetic-grpo/verl/utils/fp8_utils.py
#   avatar: LeTue09
#   commit: 1faccd4 ("initial clean commit")
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import torch
from verl.utils.kernel.fp8_kernel import scaled_fp8_blockwise
# Module-level logger, named after this file's path. Its level is read from the
# VERL_LOGGING_LEVEL environment variable and defaults to "INFO" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
class FP8QuantizerHelper:
    """Select parameters by name and quantize the selected ones blockwise to FP8.

    Args:
        quant_config: Either a dict or an arbitrary object exposing a
            ``weight_block_size`` entry/attribute. It is stored as-is and only
            read when ``quant_weights_by_name`` starts iterating.
    """

    # Substrings (matched against the lower-cased parameter name) that always
    # veto quantization. Checked before the include list, so they win.
    _EXCLUDE_PATTERNS = (
        "embed_tokens",  # Embedding layer
        "lm_head",  # Output layer
        "layernorm",  # LayerNorm
        "norm",  # Various Norm layers
        "ln_",  # LayerNorm variants
        "embeddings",  # Embeddings
        "mlp.gate.weight",  # MoE router
    )

    # Substrings that mark a parameter as a quantizable linear-layer weight.
    _INCLUDE_PATTERNS = (
        "q_proj",  # Query projection
        "k_proj",  # Key projection
        "v_proj",  # Value projection
        "o_proj",  # Output projection
        "gate_proj",  # Gate projection (for MLP)
        "up_proj",  # Up projection (for MLP)
        "down_proj",  # Down projection (for MLP)
        "fc1",  # Fully connected 1
        "fc2",  # Fully connected 2
        "mlp",  # MLP layers
    )

    def __init__(self, quant_config):
        self.quant_config = quant_config

    @staticmethod
    def _is_log_rank() -> bool:
        """Return True when this process should emit per-parameter debug logs.

        With an initialised default process group only rank 0 logs. Without one
        (single-process use, unit tests) the process is treated as rank 0
        instead of letting ``get_rank()`` raise.
        """
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True

    def should_quantize_param(self, param_name: str) -> bool:
        """Decide from the parameter name alone whether it is quantized to FP8.

        Rules, applied in order on the lower-cased name (the ``.weight`` suffix
        check itself is case-sensitive):
            1. The name must end with ``.weight`` (this excludes biases).
            2. Any exclude substring (embeddings, norms, ``lm_head``, MoE
               router) rejects the parameter.
            3. Any include substring (attention/MLP projections) accepts it.
            4. Everything else is left unquantized.

        Args:
            param_name: Fully qualified parameter name.

        Returns:
            True if the parameter should be quantized, False otherwise.
        """
        # Must be a weight parameter
        if not param_name.endswith(".weight"):
            return False

        param_lower = param_name.lower()

        # Exclusions take precedence over inclusions.
        if any(pattern in param_lower for pattern in self._EXCLUDE_PATTERNS):
            return False

        if any(pattern in param_lower for pattern in self._INCLUDE_PATTERNS):
            logger.debug("Will quantize FP8: %s", param_name)
            return True

        # Do not quantize by default
        logger.debug("Skip quantization: %s", param_name)
        return False

    def quant_weights_by_name(self, weights, dtype=torch.bfloat16):
        """Lazily quantize selected weights to FP8, one parameter at a time.

        This is a generator, so nothing (including the config check) runs until
        the first item is requested.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
            dtype: Dtype the tensor is cast to before being handed to
                ``scaled_fp8_blockwise``.

        Yields:
            ``(name, tensor)`` pairs. A parameter that is not selected by
            ``should_quantize_param`` is yielded unchanged. A selected
            parameter yields two pairs: ``(name, quantized_weight)`` followed
            by ``(name + "_scale_inv", scale)``. If quantization raises, the
            error is logged and the original ``(name, tensor)`` is yielded
            instead (best-effort fallback).

        Raises:
            ValueError: If ``quant_config`` provides no ``weight_block_size``.
        """
        if isinstance(self.quant_config, dict):
            weight_block_size = self.quant_config.get("weight_block_size")
        else:
            weight_block_size = getattr(self.quant_config, "weight_block_size", None)
        if weight_block_size is None:
            raise ValueError("weight_block_size not found in quant_config")

        for name, tensor in weights:
            # Check if quantization is needed
            if not self.should_quantize_param(name):
                yield (name, tensor)
                continue

            # Logging stays outside the try block: a missing process group must
            # not be mistaken for a quantization failure.
            if self._is_log_rank():
                logger.debug("Quantizing to FP8 blockwise: %s", name)

            # Only the kernel call and scale reshape are guarded; the yields
            # below are outside so an exception thrown into the generator
            # cannot trigger a duplicate fallback yield.
            try:
                param_lp, param_scale = scaled_fp8_blockwise(
                    tensor.to(dtype),
                    weight_block_size=weight_block_size,
                )
                # Drop a trailing size-1 dimension from the scale, if any.
                param_scale = param_scale.squeeze(-1)
            except Exception as e:
                logger.error("Failed to quantize %s: %s", name, e)
                # If quantization fails, use original weights
                yield (name, tensor)
                continue

            # Yield the quantized weight and scale
            yield (name, param_lp)
            yield (name + "_scale_inv", param_scale)
            # Explicitly drop the local references to help GC
            del param_lp, param_scale