File size: 4,557 Bytes
1faccd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import torch

from verl.utils.kernel.fp8_kernel import scaled_fp8_blockwise

# Module-level logger, registered under this file's path. Its verbosity is
# taken from the VERL_LOGGING_LEVEL environment variable, defaulting to "INFO".
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "INFO"))


class FP8QuantizerHelper:
    """Pick weights by parameter name and quantize them blockwise to FP8.

    ``quant_config`` may be a dict or an attribute-style object. The only
    entry read by this class is ``weight_block_size``, which is forwarded
    unchanged to ``scaled_fp8_blockwise``.
    """

    # Lower-cased substrings; a parameter whose name contains any of these is
    # never quantized. Exclusions are checked before inclusions.
    _EXCLUDE_PATTERNS = (
        "embed_tokens",  # Embedding layer
        "lm_head",  # Output layer
        "layernorm",  # LayerNorm
        "norm",  # Various Norm layers
        "ln_",  # LayerNorm variants
        "embeddings",  # Embeddings
        "mlp.gate.weight",  # MoE router
    )

    # Lower-cased substrings identifying the (Linear) layers to quantize.
    _INCLUDE_PATTERNS = (
        "q_proj",  # Query projection
        "k_proj",  # Key projection
        "v_proj",  # Value projection
        "o_proj",  # Output projection
        "gate_proj",  # Gate projection (for MLP)
        "up_proj",  # Up projection (for MLP)
        "down_proj",  # Down projection (for MLP)
        "fc1",  # Fully connected 1
        "fc2",  # Fully connected 2
        "mlp",  # MLP layers
    )

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def should_quantize_param(self, param_name: str) -> bool:
        """Decide from the parameter name whether it is quantized to FP8.

        Rules, applied in order:
        - The name must end with ``.weight`` (case-sensitive), so biases and
          other tensors are skipped.
        - The lower-cased name must not contain any exclude pattern
          (embeddings, norm layers, ``lm_head``, MoE router gate).
        - The lower-cased name must contain at least one include pattern
          (attention/MLP projections, ``fc1``/``fc2``, ``mlp``).

        Args:
            param_name: Fully qualified parameter name.

        Returns:
            True if the parameter should be quantized, False otherwise.
        """
        # Must be a weight parameter
        if not param_name.endswith(".weight"):
            return False

        param_lower = param_name.lower()

        # Exclusions win over inclusions (e.g. "mlp.gate.weight" contains "mlp").
        if any(pattern in param_lower for pattern in self._EXCLUDE_PATTERNS):
            return False

        if any(pattern in param_lower for pattern in self._INCLUDE_PATTERNS):
            logger.debug("Will quantize FP8: %s", param_name)
            return True

        # Do not quantize by default
        logger.debug("Skip quantization: %s", param_name)
        return False

    @staticmethod
    def _is_rank_zero() -> bool:
        """Return True on rank 0, or when no process group is initialized."""
        dist = torch.distributed
        # get_rank() raises without an initialized default process group, so
        # treat the non-distributed case as the single logging rank.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

    def quant_weights_by_name(self, weights, dtype: torch.dtype = torch.bfloat16):
        """Lazily FP8-quantize weights selected by parameter name.

        This is a generator: nothing runs (including the configuration check)
        until the result is iterated, and weights are processed one at a time.

        Args:
            weights: Generator or iterable of (name, tensor) pairs.
            dtype: Data type each selected tensor is cast to before being
                passed to ``scaled_fp8_blockwise``.

        Yields:
            ``(name, tensor)`` unchanged for every weight that is not selected
            by ``should_quantize_param`` or whose quantization raised.
            For a successfully quantized weight, ``(name, quantized)`` followed
            by ``(name + "_scale_inv", scale)``, where ``scale`` is the scale
            returned by ``scaled_fp8_blockwise`` after ``squeeze(-1)``.

        Raises:
            ValueError: On first iteration, if ``weight_block_size`` is
                missing from ``quant_config``.
        """
        if isinstance(self.quant_config, dict):
            weight_block_size = self.quant_config.get("weight_block_size")
        else:
            weight_block_size = getattr(self.quant_config, "weight_block_size", None)

        if weight_block_size is None:
            raise ValueError("weight_block_size not found in quant_config")

        # Only used to gate a debug message; evaluated outside the try below
        # so a missing process group can never disable quantization.
        log_on_this_rank = self._is_rank_zero()

        for k, v in weights:
            # Check if quantization is needed
            if not self.should_quantize_param(k):
                yield (k, v)
                continue

            if log_on_this_rank:
                logger.debug("Quantizing to FP8 blockwise: %s", k)

            # Keep the try limited to the quantization itself: yielding inside
            # it would swallow exceptions thrown into the generator and emit
            # the same name twice.
            try:
                param_lp, param_scale = scaled_fp8_blockwise(
                    v.to(dtype),
                    weight_block_size=weight_block_size,
                )
                param_scale = param_scale.squeeze(-1)
            except Exception as e:
                # Best effort: if quantization fails, use original weights
                logger.error("Failed to quantize %s: %s", k, e)
                yield (k, v)
                continue

            # Yield the quantized weight and scale
            yield (k, param_lp)
            yield (k + "_scale_inv", param_scale)

            # Explicitly delete to help GC
            del param_lp, param_scale