File size: 19,225 Bytes
6f0b660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import (
    is_accelerate_available,
    is_kernels_available,
    is_torch_available,
    is_triton_available,
    logging,
)
from .quantizers_utils import get_module_from_name


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)
# NOTE(review): this module-level global is not read anywhere in this file — `Mxfp4HfQuantizer`
# caches the kernels hub on the instance (`self.triton_kernels_hub`) instead. Presumably kept for
# external importers; confirm before removing.
triton_kernels_hub = None


class Mxfp4HfQuantizer(HfQuantizer):
    """
    MXFP4 quantizer for the MoE experts of GPT-OSS models.

    Quantized matmuls use the Triton kernels fetched from the Hub through the `kernels` package
    (`kernels-community/triton_kernels`). When the environment cannot run them, a pre-quantized checkpoint is
    dequantized to bf16 instead (by setting `quantization_config.dequantize = True`).
    """

    requires_parameters_quantization = True
    requires_calibration = False

    required_packages = ["accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        # Filled lazily by `_lazy_import_kernels` so that constructing the quantizer does not need `kernels`.
        self.triton_kernels_hub = None

    def _lazy_import_kernels(self):
        """
        Return the Triton kernels hub, fetching it on first use and caching it on the instance.

        Raises:
            ImportError: if the `kernels` package (or the kernel it loads) cannot be imported.
        """
        if self.triton_kernels_hub is None:
            try:
                from kernels import get_kernel

                self.triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
            except ImportError as err:
                raise ImportError("kernels package is required for MXFP4 quantization") from err
        return self.triton_kernels_hub

    def validate_environment(self, *args, **kwargs):
        """
        Check that the runtime can load (or quantize on the fly) an MXFP4 model.

        For pre-quantized checkpoints, an unusable environment (no CUDA/XPU device, CUDA compute capability
        below 7.5, or missing Triton/kernels) falls back to dequantizing to bf16 by setting
        `quantization_config.dequantize = True`. When quantizing on the fly, the same conditions raise.

        Raises:
            ImportError: if torch or accelerate is not installed.
            RuntimeError: if quantizing on the fly without a CUDA/XPU device.
            ValueError: if quantizing on the fly on an unsupported GPU, without kernels, or with a `device_map`
                that places modules on CPU or disk.
        """
        if not is_torch_available():
            raise ImportError(
                "Using mxfp4 quantization requires torch. "
                "Please install the latest version of torch ( pip install --upgrade torch )"
            )

        if self.quantization_config.dequantize:
            return

        if not (torch.cuda.is_available() or torch.xpu.is_available()):
            if self.pre_quantized:
                logger.warning_once(
                    "Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
                )
                self.quantization_config.dequantize = True
                return
            raise RuntimeError("Quantizing a model using MXFP4 requires a GPU")

        if not is_accelerate_available():
            raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")

        if torch.xpu.is_available():
            # XPU needs a newer Triton release than CUDA.
            gpu_is_supported = True
            kernels_available = is_triton_available("3.5.0") and is_kernels_available()
        else:
            compute_capability = torch.cuda.get_device_capability()
            gpu_is_supported = compute_capability >= (7, 5)
            kernels_available = is_triton_available("3.4.0") and is_kernels_available()

        if self.pre_quantized:
            # On unsupported GPUs or without kernels, we will dequantize the model to bf16
            if not gpu_is_supported:
                logger.warning_once(
                    "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) "
                    "We will default to dequantizing the model to bf16."
                )
                self.quantization_config.dequantize = True
                return

            if not kernels_available:
                logger.warning_once(
                    "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16"
                )
                self.quantization_config.dequantize = True
                return
        elif not gpu_is_supported:
            # we can't quantize the model in this case so we raise an error
            raise ValueError(
                "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) "
            )
        elif not kernels_available:
            # we can't quantize the model in this case so we raise an error
            raise ValueError(
                "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0"
            )

        if not self.pre_quantized:
            # Fail early (during validation) if the kernels hub cannot be fetched.
            self._lazy_import_kernels()

        device_map = kwargs.get("device_map")
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP4 model on CPU and have a CUDA/XPU device available, make sure to set "
                "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or device_map = 'xpu'. "
            )
        elif (
            not self.pre_quantized
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device. "
                "This is not supported when the model is quantized on the fly. "
                "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
            )

    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
        """
        Return the dtype for the non-quantized weights, defaulting to `torch.bfloat16` when none was given.
        """
        if dtype is None:
            # Log before overriding so the message reports the value the user actually passed.
            logger.info(
                "Overriding dtype=%s with `dtype=torch.bfloat16` due to "
                "requirements of the MXFP4 Triton kernels to enable model loading in fp4. "
                "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
                " dtype=torch.bfloat16 to remove this warning.",
                dtype,
            )
            dtype = torch.bfloat16
        return dtype

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """
        Return whether `param_name` must be routed through `create_quantized_param`.

        Returns True for every non-bias tensor of an `Mxfp4GptOssExperts` module. When dequantizing, the same
        applies to plain `GptOssExperts` modules. Expert biases are loaded normally.
        """
        from ..integrations import Mxfp4GptOssExperts
        from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts

        # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
        # ("_blocks" and "_scales" have the same length, so one slice strips either suffix.)
        if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
            module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
        else:
            module, tensor_name = get_module_from_name(model, param_name)

        is_expert_module = isinstance(module, Mxfp4GptOssExperts) or (
            isinstance(module, GptOssExperts) and self.quantization_config.dequantize
        )
        return is_expert_module and tensor_name not in ["down_proj_bias", "gate_up_proj_bias"]

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        **kwargs,
    ):
        """
        Materialize one expert parameter on `target_device`.

        - Quantizing on the fly (`not pre_quantized`): quantize `param_value` to MXFP4 and swizzle it. Then store
          the Triton tensor plus its `PrecisionConfig` on the module, and delete the `*_blocks` / `*_scales`
          placeholders.
        - Pre-quantized checkpoint: the param is one of `{gate_up,down}_proj_{blocks,scales}`. It is either
          dequantized into the plain expert weight (`dequantize=True`) or loaded and swizzled for the kernels.
          Sharding-related kwargs (`empty_param`, `casting_dtype`, `to_contiguous`, `rank`, `device_mesh`) are
          forwarded to the integration helpers.
        """
        from ..integrations import (
            Mxfp4GptOssExperts,
            dequantize,
            load_and_swizzle_mxfp4,
            quantize_to_mxfp4,
            swizzle_mxfp4,
        )
        from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts

        if not self.pre_quantized:
            triton_kernels_hub = self._lazy_import_kernels()
            module, _ = get_module_from_name(model, param_name)
            with torch.device(target_device):
                if isinstance(module, Mxfp4GptOssExperts):
                    triton_weight_tensor, weight_scale = quantize_to_mxfp4(param_value, triton_kernels_hub)
                    PrecisionConfig, FlexCtx, InFlexData = (
                        triton_kernels_hub.matmul_ogs.PrecisionConfig,
                        triton_kernels_hub.matmul_ogs.FlexCtx,
                        triton_kernels_hub.matmul_ogs.InFlexData,
                    )
                    triton_weight_tensor, weight_scale = swizzle_mxfp4(
                        triton_weight_tensor, weight_scale, triton_kernels_hub
                    )

                    proj = "gate_up_proj" if "gate_up_proj" in param_name else "down_proj"
                    setattr(module, proj, triton_weight_tensor)
                    setattr(
                        module,
                        f"{proj}_precision_config",
                        PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
                    )

                    # The packed placeholders are superseded by the swizzled tensor + precision config.
                    delattr(module, f"{proj}_blocks")
                    delattr(module, f"{proj}_scales")

        # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
        else:
            #  This is when loading a quantized model (blocks and scales exist)
            empty_param = kwargs.get("empty_param")
            casting_dtype = kwargs.get("casting_dtype")
            to_contiguous = kwargs.get("to_contiguous")
            rank = kwargs.get("rank")
            device_mesh = kwargs.get("device_mesh")
            if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
                # blocks and scales have the same length that's why this works for both
                module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
            else:
                module, _ = get_module_from_name(model, param_name)

            shard_kwargs = {
                "empty_param": empty_param,
                "casting_dtype": casting_dtype,
                "to_contiguous": to_contiguous,
                "rank": rank,
                "device_mesh": device_mesh,
                "model": model,
            }

            if isinstance(module, Mxfp4GptOssExperts) or (
                isinstance(module, GptOssExperts) and self.quantization_config.dequantize
            ):
                if self.quantization_config.dequantize:
                    # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
                    # so we only have the original param name
                    dq_param_name = param_name[: -len("_blocks")]
                    dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
                else:
                    load_and_swizzle_mxfp4(
                        module,
                        param_name,
                        param_value,
                        target_device,
                        self._lazy_import_kernels(),
                        **shard_kwargs,
                    )

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """Drop the quantization config when dequantizing, then release cached accelerator memory."""
        # we are not really dequantizing, we are just removing everything related to quantization here
        if self.quantization_config.dequantize:
            self.remove_quantization_config(model)
        # clean cache due to triton ops
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.xpu.is_available():
            torch.xpu.empty_cache()

    def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
        """
        Rewrite the expected state-dict keys of the expert projections.

        A plain `...mlp.experts.{gate_up,down}_proj` key is expanded into its `_blocks` and `_scales` variants.
        When quantizing on the fly, `..._blocks` keys are mapped back to the plain projection name, and keys
        ending in "scales" are dropped. All other keys are kept unchanged, in order.
        """
        new_expected_keys = []
        for key in expected_keys:
            if key.endswith(".mlp.experts.gate_up_proj"):
                base = key[: -len("gate_up_proj")]
                new_expected_keys.append(base + "gate_up_proj_blocks")
                new_expected_keys.append(base + "gate_up_proj_scales")
            elif key.endswith(".mlp.experts.down_proj"):
                base = key[: -len("down_proj")]
                new_expected_keys.append(base + "down_proj_blocks")
                new_expected_keys.append(base + "down_proj_scales")
            elif not self.pre_quantized:
                # in this case, we are quantizing the model so we need to update the keys as we changed the layers
                if key.endswith(".mlp.experts.down_proj_blocks"):
                    base = key[: -len("down_proj_blocks")]
                    new_expected_keys.append(base + "down_proj")
                elif key.endswith(".mlp.experts.gate_up_proj_blocks"):
                    base = key[: -len("gate_up_proj_blocks")]
                    new_expected_keys.append(base + "gate_up_proj")
                elif key.endswith("scales"):
                    # drop the scales: a non-quantized checkpoint does not contain them
                    continue
                else:
                    new_expected_keys.append(key)
            else:
                new_expected_keys.append(key)
        return new_expected_keys

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        keep_in_fp32_modules: Optional[list[str]] = None,
        **kwargs,
    ):
        """
        Swap the model's expert layers for MXFP4 variants before weights are loaded.

        If `use_kernels=True` is passed, the model is forced into dequantize mode, since the full-precision
        kernels need a different forward pass.
        """
        from ..integrations import replace_with_mxfp4_linear

        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
        )

        use_kernels = kwargs.get("use_kernels", False)
        # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
        if use_kernels:
            logger.warning_once(
                "You are using full precision kernels, we will dequantize the model to bf16. "
                "To use the quantized model with quantization kernels, please set use_kernels=False"
            )
            self.quantization_config.dequantize = True

        config = model.config
        model = replace_with_mxfp4_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            config=config,
        )

        model.config.quantization_config = self.quantization_config

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        """
        Remove from `missing_keys` the non-`.weight`/`.bias` keys belonging to `Mxfp4GptOssExperts` modules.

        Those entries are matched by substring on the module name, with or without `prefix`. The order of the
        remaining keys is preserved.
        """
        from ..integrations import Mxfp4GptOssExperts

        # A set keeps the final filtering linear instead of quadratic in len(missing_keys).
        not_missing_keys = set()
        for name, module in model.named_modules():
            if not isinstance(module, Mxfp4GptOssExperts):
                continue
            for missing in missing_keys:
                if (
                    (name in missing or name in f"{prefix}.{missing}")
                    and not missing.endswith(".weight")
                    and not missing.endswith(".bias")
                ):
                    not_missing_keys.add(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    @staticmethod
    def _experts_grouped_gemm_plan() -> dict[str, str]:
        """Return a fresh mapping of the MXFP4 expert blocks/scales params to the `grouped_gemm` strategy."""
        return {
            "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
            "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
            "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
            "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
        }

    def update_tp_plan(self, config):
        """Register `grouped_gemm` tensor-parallel entries for the MXFP4 expert params of GPT-OSS configs."""
        if "GptOssConfig" in config.__class__.__name__:
            if getattr(config, "base_model_tp_plan", None) is not None:
                config.base_model_tp_plan.update(self._experts_grouped_gemm_plan())
        return config

    def update_ep_plan(self, config):
        """Register `grouped_gemm` expert-parallel entries for the MXFP4 expert params of GPT-OSS configs."""
        if "GptOssConfig" in config.__class__.__name__:
            if getattr(config, "base_model_ep_plan", None) is not None:
                config.base_model_ep_plan.update(self._experts_grouped_gemm_plan())
        return config

    def get_param_name(self, param_name: str) -> str:
        """
        Map a checkpoint param name to the name used on the model.

        When dequantizing, the `_blocks` / `_scales` suffix is removed. When quantizing on the fly,
        `gate_up_proj` / `down_proj` are mapped to their `_blocks` names. Otherwise the name is returned unchanged.
        """
        if self.quantization_config.dequantize:
            if "_blocks" in param_name:
                return param_name.replace("_blocks", "")
            elif "_scales" in param_name:
                return param_name.replace("_scales", "")
        elif not self.pre_quantized:
            if param_name.endswith("gate_up_proj"):
                return param_name.replace("gate_up_proj", "gate_up_proj_blocks")
            if param_name.endswith("down_proj"):
                return param_name.replace("down_proj", "down_proj_blocks")
        return param_name

    def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
        """
        Build a state dict in which the swizzled Triton expert tensors are stored as
        `{gate_up,down}_proj_{blocks,scales}` entries.

        The block shapes are derived from the model config (`num_local_experts`, `hidden_size`,
        `intermediate_size`). If an attribute is missing, the previously hard-coded value (32 / 2880 / 2880) is used.

        Returns:
            (state_dict, metadata) with an empty `metadata` dict.
        """
        from ..integrations import Mxfp4GptOssExperts

        state_dict = model.state_dict()

        config = getattr(model, "config", None)
        num_experts = getattr(config, "num_local_experts", 32)
        hidden_size = getattr(config, "hidden_size", 2880)
        intermediate_size = getattr(config, "intermediate_size", 2880)
        # Each MXFP4 block packs 32 values into 16 bytes.
        hidden_blocks = hidden_size // 32
        intermediate_blocks = intermediate_size // 32

        for name, module in model.named_modules():
            if (
                isinstance(module, Mxfp4GptOssExperts)
                and hasattr(module, "gate_up_proj")
                and hasattr(module, "down_proj")
            ):
                state_dict[f"{name}.gate_up_proj_blocks"] = (
                    module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
                    .transpose(-1, -2)
                    .reshape(num_experts, -1, hidden_blocks, 16)
                )
                state_dict[f"{name}.gate_up_proj_scales"] = (
                    module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
                        module.gate_up_proj_precision_config.weight_scale.storage.data
                    ).transpose(-1, -2)
                )
                state_dict[f"{name}.down_proj_blocks"] = (
                    module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
                    .transpose(-1, -2)
                    .reshape(num_experts, hidden_size, intermediate_blocks, -1)
                )
                state_dict[f"{name}.down_proj_scales"] = (
                    module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
                        module.down_proj_precision_config.weight_scale.storage.data
                    ).transpose(-1, -2)
                )

        metadata = {}
        return state_dict, metadata

    def is_serializable(self, safe_serialization=None):
        """MXFP4 models can always be saved."""
        return True

    @property
    def is_trainable(self) -> bool:
        """MXFP4-quantized models cannot be trained; warns once and returns False."""
        logger.warning_once(
            "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
        )
        return False