# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.quantization.quant_args import (
    BFLOAT16_DATA,
    FP4_E2M1_DATA,
    QuantizationArgs,
)


__all__ = [
    "maybe_convert_from_mxfp4_exp",
    "generate_mxfp4_scales",
    "round_to_power_2",
    "should_generatre_mxfp4_scales",
]

# Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py # noqa: E501


def should_generatre_mxfp4_scales(args: QuantizationArgs):
    """
    Check whether the given quantization args describe the mxfp4 scheme
    handled by this module: 4-bit, float type, group size of 32.

    NOTE: the function name is misspelled ("generatre"); the spelling is
    kept because the name is exported through ``__all__``.

    :param args: quantization args to inspect
    :returns: True if all three conditions hold, False otherwise
    """
    return args.num_bits == 4 and args.type == "float" and args.group_size == 32


def maybe_convert_from_mxfp4_exp(
    args: QuantizationArgs, scale: torch.Tensor
) -> torch.Tensor:
    """
    Converts mxfp4 scales. Scales are powers of 2, stored as biased
    exponents (bias of 127). When ``args`` describes mxfp4, the exponents
    are expanded to ``2 ** (scale - 127)`` so that they can be applied to
    the weights and activations during QDQ. For any other scheme the
    scale is returned untouched.

    NOTE(review): the result is cast back to the dtype ``scale`` arrived
    in. If that dtype is an integer type such as uint8, fractional powers
    of two would be truncated by the cast - confirm that callers pass the
    exponents in a dense floating-point dtype.

    :param args: quantization args used to decide whether to convert
    :param scale: biased exponent scale
    :returns: converted scale in the dtype of the input ``scale``, or the
        input ``scale`` itself when ``args`` is not mxfp4
    """
    original_dtype = scale.dtype

    if should_generatre_mxfp4_scales(args):
        # Remove the exponent bias in int32 so the subtraction cannot
        # wrap around in a narrow unsigned dtype
        scale_exp = scale.to(torch.int32) - 127
        scale = 2.00 ** (scale_exp.to(torch.float))
        return scale.to(original_dtype)

    return scale


def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
    """
    Round values to a power of 2, working directly on the bfloat16 bit
    pattern. A small constant is first added to the raw bits so that
    values with a large enough mantissa carry into the next exponent;
    the result is then masked with BFLOAT16_SIGN_EXPONENT_MASK, which
    removes the mantissa and keeps the sign and exponent, i.e. a power
    of 2 for the input value.

    E.g:
         0.0825 = 1.32 (mantissa) x 2**-4 (exponent)
         0.0825 ==> -4 (exponent) + 127 = 123 = 01111011 (8 bits for bfloat16)
         0.0825 ==> 0.32 (mantissa) = 0101001 (7 bits for bfloat16)
         0.0825 == 0b01111011_0101001 (bfloat16)
         0b01111011_0101001 & 111111111_0000000 == 0b01111011_0000000
         Keep the exponent + sign bit to give you the closest power of 2, 0.0625
         (here the added constant does not carry into the exponent)

    :param x: bfloat16 tensor to round to a power of 2
    :returns: bfloat16 tensor of the same shape holding powers of 2
    :raises AssertionError: if ``x`` is not a bfloat16 tensor
    """
    # Explicit raise rather than a bare `assert`: asserts are stripped
    # under `python -O`, and the bit reinterpretation below would then
    # silently produce garbage for any other dtype. AssertionError is
    # kept so callers observe the same exception type as before.
    if x.dtype != torch.bfloat16:
        raise AssertionError(
            f"round_to_power_2 expects a torch.bfloat16 tensor, got {x.dtype}"
        )

    # Reinterpret the raw 16 bits as integers; widen to int32 so that the
    # addition below has headroom
    x = x.view(torch.uint16).to(torch.int32)

    # Value added to the raw bits before masking. It sits in the upper
    # mantissa bits, so inputs whose mantissa is large enough carry over
    # into the next exponent (round up) ...
    BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)

    # ... while everything else keeps its exponent once the mantissa is
    # masked away (round down), which better represents smaller numbers
    # and prevents overflow. The mask covers the sign bit plus all
    # exponent bits.
    BFLOAT16_SIGN_EXPONENT_MASK = (
        (1 << (BFLOAT16_DATA.exponent + 1)) - 1
    ) << BFLOAT16_DATA.mantissa

    block_max_uint = torch.bitwise_and(
        x + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK
    )
    # Back to 16 bits and reinterpret as bfloat16 again
    return block_max_uint.to(torch.uint16).view(torch.bfloat16)


def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
    """
    Generate mxfp4 scales. The scales require the following steps
    1. Round to a power of 2 (see round_to_power_2)
    2. Convert to a biased exponent: 127 + floor(log2(value)) - 2

    Called when calculating qparams using observers.

    NOTE(review): the trailing ``- 2`` is presumably the largest exponent
    of FP4 E2M1 (max value 6.0 = 1.5 x 2**2) - confirm.
    NOTE(review): an input of 0 gives log2(0) = -inf, so the returned
    exponent is -inf - confirm callers never pass all-zero groups.

    :param x: bfloat16 tensor to round to a power of 2
    :returns: scales as biased exponents
    """
    # Round to a power of 2
    scale_power_2 = round_to_power_2(x)
    return 127 + torch.floor(torch.log2(scale_power_2)) - 2