| # SPDX-License-Identifier: Apache-2.0 | |
| import re | |
| from collections.abc import Iterable, Mapping | |
| from types import MappingProxyType | |
| from typing import Any, Optional | |
| import torch | |
| from aiter.ops.triton.quant import dynamic_mxfp4_quant | |
| from torch import nn | |
def deep_compare(dict1: Any, dict2: Any) -> bool:
    """Recursively compare two (possibly nested) configuration values.

    Values of different types never compare equal. Dicts must have identical
    key sets and recursively equal values. Lists are compared
    order-insensitively with set semantics (ordering and duplicates are
    ignored). Any other values are compared with ``==``.

    Args:
        dict1: First value to compare.
        dict2: Second value to compare.

    Returns:
        True if the two values are considered equal, otherwise False.
    """
    if type(dict1) is not type(dict2):
        return False
    if isinstance(dict1, dict):
        if dict1.keys() != dict2.keys():
            return False
        return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
    if isinstance(dict1, list):
        try:
            return set(dict1) == set(dict2)
        except TypeError:
            # Unhashable elements (e.g. nested dicts or lists) cannot be put
            # in a set. Fall back to a recursive order-insensitive check with
            # the same "ignore duplicates" semantics: every element of each
            # list must match some element of the other list.
            return all(
                any(deep_compare(a, b) for b in dict2) for a in dict1
            ) and all(any(deep_compare(b, a) for a in dict1) for b in dict2)
    return dict1 == dict2
def should_ignore_layer(
    layer_name: Optional[str],
    ignore: Iterable[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    """Decide whether a layer is excluded from quantization.

    Args:
        layer_name: Fully qualified module name, e.g.
            ``model.layers.0.self_attn.qkv_proj``. ``None`` is never ignored.
        ignore: Exact layer names, or ``re:``-prefixed regex patterns, that
            should be skipped. Any iterable is accepted (it is read once).
        fused_mapping: Maps a fused projection name (e.g. ``qkv_proj``) to
            the names of its unfused shards in the checkpoint.

    Returns:
        True if the layer (or, for a fused layer, every one of its shards)
        matches ``ignore``.

    Raises:
        ValueError: If a fused layer maps to no shards, or if its shards
            disagree about whether they are ignored.
    """
    if layer_name is None:
        return False

    # Materialize once: the targets are scanned once per shard below, which
    # would silently exhaust a generator after the first shard.
    targets = tuple(ignore)

    # layer_name = model.layers.0.self_attn.qkv_proj
    # parent = model.layers.0.self_attn, proj_name = qkv_proj
    parent, dot, proj_name = layer_name.rpartition(".")

    # Unfused layers like down_proj and o_proj already match the names used
    # in the safetensors checkpoint.
    if proj_name not in fused_mapping:
        return check_equal_or_regex_match(layer_name=layer_name, targets=targets)

    # Fused layers like gate_up_proj or qkv_proj are not fused in the
    # safetensors checkpoint, so convert the fused name into its shard names
    # and make sure every shard resolves to the same decision.
    shard_proj_names = fused_mapping[proj_name]
    if not shard_proj_names:
        raise ValueError(
            f"Fused layer {layer_name} has no shard names in fused_mapping."
        )
    # Swap only the final path component; str.replace would also rewrite any
    # earlier occurrence of proj_name inside the module path.
    shard_names = [
        f"{parent}{dot}{shard_proj_name}" for shard_proj_name in shard_proj_names
    ]

    # The first shard sets the decision; every other shard must agree.
    ignore_layer = check_equal_or_regex_match(
        layer_name=shard_names[0], targets=targets
    )
    for shard_name in shard_names[1:]:
        ignore_shard = check_equal_or_regex_match(
            layer_name=shard_name, targets=targets
        )
        if ignore_shard != ignore_layer:
            raise ValueError(
                f"Found a different quantization schemes for "
                f"{shard_proj_names} in {layer_name}. vLLM "
                "requires all to use the same scheme."
            )
    return ignore_layer
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
    """
    Return True if ``layer_name`` matches at least one entry in ``targets``.

    An entry matches when it is exactly equal to ``layer_name`` or, if the
    entry starts with ``re:``, when the pattern after that prefix matches
    ``layer_name`` via ``re.match`` (anchored at the start of the name).
    Scanning stops at the first matching entry.
    """
    return any(
        _is_equal_or_regex_match(layer_name, candidate) for candidate in targets
    )
def _is_equal_or_regex_match(
    value: str, target: str, check_contains: bool = False
) -> bool:
    """
    Test ``value`` against a single ``target`` specification.

    - ``re:<pattern>``: True when ``re.match(pattern, value)`` succeeds; the
      match is anchored at the start of ``value`` but not at its end.
    - otherwise, with ``check_contains``: True when ``target`` occurs inside
      ``value``, ignoring case.
    - otherwise: True only on exact string equality.
    """
    if target.startswith("re:"):
        return re.match(target[3:], value) is not None
    if check_contains:
        return target.lower() in value.lower()
    return target == value
# Utility for quantizing tensors with more than two dimensions.
def b_dynamic_mxfp4_quant(x):
    """Apply ``dynamic_mxfp4_quant`` along the last dim of an N-D tensor.

    All leading dimensions are flattened into rows, the 2-D result is
    quantized, and the outputs are viewed back with the original leading
    dimensions. Works for any rank >= 2 (the original accepted only 3-D).

    Args:
        x: Tensor of shape ``(*lead, d)``. Presumably ``d`` must be a
            multiple of 32 for the scale view below to be valid — confirm
            against aiter's ``dynamic_mxfp4_quant``.

    Returns:
        ``(x_q, x_scales)`` with shapes ``(*lead, d // 2)`` (two fp4 values
        packed per uint8) and ``(*lead, d // 32)`` (one scale per 32 inputs).
    """
    *lead, d = x.shape
    x_q, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
    return x_q.view(*lead, d // 2), x_scales.view(*lead, d // 32)
def mxfp4_to_f32(x, is_threed):
    """Unpack fp4 codes stored two per uint8 into their float32 values.

    Each input byte holds two 4-bit codes: the low nibble is emitted first,
    then the high nibble, so the last dimension doubles in size. No block
    scales are applied here.

    Args:
        x: Tensor of packed fp4 codes (uint8), any rank >= 1.
        is_threed: Kept for backward compatibility only. Unpacking always
            runs along the last dimension, which is correct for both the
            2-D and 3-D cases, so the value no longer changes the result.

    Returns:
        float32 tensor on ``x.device`` whose last dim is twice as long.
    """
    # Two fp4 codes are packed into each uint8: duplicate every byte, then
    # keep the low nibble in even slots and the high nibble in odd slots.
    codes = x.repeat_interleave(2, dim=-1)
    codes[..., ::2] = codes[..., ::2] & 0xF
    codes[..., 1::2] = codes[..., 1::2] >> 4
    # Value of each 4-bit code; codes 8-15 are the negated codes 0-7.
    mxfp4_list = [
        0.0,
        0.5,
        1.0,
        1.5,
        2.0,
        3.0,
        4.0,
        6.0,
        -0.0,
        -0.5,
        -1.0,
        -1.5,
        -2.0,
        -3.0,
        -4.0,
        -6.0,
    ]
    # Build the table on the input's device instead of a hard-coded "cuda",
    # so CPU tensors and non-default GPUs work too.
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[codes.long()]
def e8m0_to_f32(x):
    """Decode E8M0 scale exponents into float32 scale factors.

    E8M0 stores only an 8-bit biased exponent (0 mantissa bits), so a raw
    value ``e`` decodes to ``2 ** (e - 127)``. The exponent value 255 is
    treated as NaN.

    Args:
        x: Tensor of raw E8M0 exponent values (presumably uint8 — confirm at
            the call site).

    Returns:
        float32 tensor of the same shape as ``x``.
    """
    exponent = x.to(torch.float32)
    x_f32 = 2 ** (exponent - 127)
    # Test the raw exponent, not the decoded value. The previous check
    # `x_f32 == 128` turned the valid scale 2**7 (exponent 134) into NaN,
    # and missed exponent 255 (2**128, which overflows to inf in float32).
    x_f32[exponent == 255] = float("nan")
    return x_f32
def _split_and_quant_kv_b(
    self_attn: nn.Module, w: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a bf16 kv_b_proj weight into W_KC / W_VC and MXFP4-quantize both."""
    w_kc, w_vc = w.unflatten(
        0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
    ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
    # W_KC is quantized along its second-to-last dim: transpose so that dim
    # is last for the quantizer, then transpose the results back.
    w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
    w_kc = w_kc.transpose(-2, -1)
    w_s_kc = w_s_kc.transpose(-2, -1)
    w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
    # These keep the logical shapes and only change the memory layout
    # (strides) of the scales — presumably what the consuming kernel expects.
    w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
    w_s_vc = w_s_vc.contiguous().transpose(1, 2)
    return w_kc, w_s_kc, w_vc, w_s_vc


def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
    """Produce MXFP4-quantized W_KC / W_VC tensors and scales from kv_b_proj.

    Args:
        self_attn: Attention module. Its ``qk_nope_head_dim`` and
            ``v_head_dim`` are read, and, for uint8 weights,
            ``kv_b_proj.weight_scale`` as well.
        w: kv_b_proj weight, either bf16 (dynamic quant) or uint8 holding
            packed MXFP4 codes (static quant).
        quant_format: Quantization format string; must contain ``"mxfp4"``.

    Returns:
        ``(w_kc, w_s_kc, w_vc, w_s_vc)``: quantized weights and scales.

    Raises:
        ValueError: If ``quant_format`` is not an mxfp4 format or ``w`` has
            an unsupported dtype (previously an ``UnboundLocalError``).
    """
    if "mxfp4" not in quant_format:
        raise ValueError(
            f"Unsupported quant_format {quant_format!r}: only mxfp4 formats "
            "are handled."
        )
    if w.dtype == torch.bfloat16:
        # Dynamic quant: split the bf16 tensor into w_kc / w_vc, then quantize
        # each to uint8 weights plus uint8 scales.
        return _split_and_quant_kv_b(self_attn, w)
    if w.dtype == torch.uint8:
        # Static quant: w is already MXFP4, but packed data (half the width)
        # with one scale per 32 values cannot be split and transposed
        # correctly. Dequantize to bf16 first, then split and re-quantize.
        w = mxfp4_to_f32(w, True).to(torch.bfloat16)
        w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
        w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
        return _split_and_quant_kv_b(self_attn, w * w_scales)
    raise ValueError(
        f"Unsupported weight dtype {w.dtype} for quant_format {quant_format!r}."
    )
Xet Storage Details
- Size: 7.45 kB
- Xet hash: c53888de55d98bd3c22cf693afda00cadeb2052a1d7027645cb7afd8ab8dd136
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.