| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | from compressed_tensors.quantization.quant_args import ( |
| | BFLOAT16_DATA, |
| | FP4_E2M1_DATA, |
| | QuantizationArgs, |
| | ) |
| |
|
| |
|
# Public names exported by ``from <module> import *``.
# NOTE: "should_generatre_mxfp4_scales" is spelled exactly as the function
# is defined in this module; the entry must keep matching that definition.
__all__ = [
    "maybe_convert_from_mxfp4_exp",
    "generate_mxfp4_scales",
    "round_to_power_2",
    "should_generatre_mxfp4_scales",
]
| |
|
| | |
| |
|
| |
|
def should_generatre_mxfp4_scales(args: QuantizationArgs):
    """
    Tell whether the given quantization arguments select the mxfp4 scheme,
    i.e. 4-bit float quantization with a group size of 32.

    :param args: quantization arguments to inspect; only ``num_bits``,
        ``type`` and ``group_size`` are read
    :return: True when all three fields match the mxfp4 settings,
        False otherwise
    """
    # Reject on the first mismatching field, checked in the same order as
    # a chained ``and`` would evaluate them.
    if args.num_bits != 4:
        return False
    if args.type != "float":
        return False
    return args.group_size == 32
| |
|
| |
|
def maybe_convert_from_mxfp4_exp(
    args: QuantizationArgs, scale: torch.Tensor
) -> torch.Tensor:
    """
    Turn mxfp4 exponent scales into the power-of-2 values they encode, so
    they can be applied to weights and activations during QDQ.

    When ``args`` describes the mxfp4 scheme, every element ``e`` of
    ``scale`` is read as an exponent biased by 127 and replaced by
    ``2 ** (e - 127)``. The power is computed in float32 and the result is
    cast back to the dtype ``scale`` had on entry. For any other scheme the
    input is returned untouched.

    NOTE(review): because the result is cast back to the incoming dtype, an
    integer-typed ``scale`` (e.g. uint8) could not hold fractional powers of
    two after the cast - confirm which dtype callers pass in.

    :param args: quantization arguments used to detect the mxfp4 scheme
    :param scale: exponent scale tensor (biased by 127)
    :return: converted scale for mxfp4, otherwise ``scale`` unchanged
    """
    if not should_generatre_mxfp4_scales(args):
        return scale

    input_dtype = scale.dtype
    # Remove the bias in int32 so the subtraction cannot wrap around for
    # narrow unsigned inputs.
    unbiased_exponent = scale.to(torch.int32) - 127
    dense_scale = 2.00 ** (unbiased_exponent.to(torch.float))
    return dense_scale.to(input_dtype)
| |
|
| |
|
def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
    """
    Round bfloat16 values to a power of 2 by operating on their bit pattern.

    A small offset is added to the raw 16-bit pattern and the mantissa bits
    are then masked away, leaving only the sign and exponent fields. When
    the mantissa is large enough, the offset carries into the exponent and
    the value rounds up to the next power of 2; otherwise the mantissa is
    simply dropped and the value rounds down.

    E.g (no carry into the exponent):
        0.0825 = 1.32 (mantissa) x 2**-4 (exponent)
        exponent field: -4 + 127 = 123 = 01111011
        mantissa field: 0.32 ~= 0101001
        bit pattern:    0b01111011_0101001
        after masking:  0b01111011_0000000 == 0.0625

    :param x: bfloat16 tensor to round
    :return: bfloat16 tensor of the same shape holding powers of 2
    """
    assert x.dtype == torch.bfloat16

    # Offset that decides when the mantissa rounds up; its bit position is
    # derived from the bfloat16 and FP4 E2M1 mantissa widths.
    rounding_offset = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)

    # Ones over the sign + exponent fields, zeros over the mantissa field.
    sign_exponent_bits = (1 << (BFLOAT16_DATA.exponent + 1)) - 1
    sign_exponent_mask = sign_exponent_bits << BFLOAT16_DATA.mantissa

    # Reinterpret the bits as integers; widen to int32 so the addition
    # below has headroom.
    raw_bits = x.view(torch.uint16).to(torch.int32)
    rounded_bits = torch.bitwise_and(raw_bits + rounding_offset, sign_exponent_mask)

    # Narrow back to 16 bits and reinterpret as bfloat16.
    return rounded_bits.to(torch.uint16).view(torch.bfloat16)
| |
|
| |
|
def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
    """
    Produce mxfp4 scales in exponent form.

    Steps:
        1. Round the input to a power of 2 (``round_to_power_2``)
        2. Take ``floor(log2(...))`` to obtain the exponent
        3. Add the bias of 127 and subtract 2

    Intended to be called when calculating qparams using observers.

    NOTE(review): the constant 2 is presumably the largest exponent of the
    FP4 E2M1 format - confirm. A zero input yields ``log2(0) = -inf`` and is
    not special-cased here.

    :param x: bfloat16 tensor to derive scales from (``round_to_power_2``
        asserts the dtype)
    :return: scales expressed as biased exponents
    """
    power_of_two = round_to_power_2(x)
    exponent = torch.floor(torch.log2(power_of_two))
    # Keep the original evaluation order: bias first, then the offset.
    biased_exponent = 127 + exponent
    return biased_exponent - 2
| |
|