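"""Patches the Transformers bitsandbytes 4-bit quantizer so that fused MoE linear
layers (`MoeFusedLinear`) are replaced with `MoeFusedLinear4bit` during model
loading, alongside the usual `nn.Linear`/`Conv1D` -> `Linear4bit` replacement.
Importing this module has no side effects; call `patch_bnb_quantizer()` to
install the patch."""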
from inspect import signature
from typing import Any, Optional, Union

from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit
from torch import nn
from transformers import BitsAndBytesConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer
from transformers.utils import logging

from ..modular_qwen3_moe_fused import MoeFusedLinear
from .layer import MoeFusedLinear4bit

logger = logging.get_logger(__name__)

# Modified from https://github.com/huggingface/transformers/blob/508a7040556dc6b45f09174c662a9632284b2445/src/transformers/integrations/bitsandbytes.py#L150
def _replace_with_bnb_moe_fused_linear(
    model: nn.Module,
    modules_to_not_convert: list[str],
    current_key_name: list[str],
    quantization_config: BitsAndBytesConfig,
    has_been_replaced: bool,
) -> bool:
    """Recursively replace `nn.Linear`, `Conv1D`, and `MoeFusedLinear` modules with
    their bitsandbytes 4-bit counterparts. Returns True if at least one module was
    replaced."""
    for name, module in model.named_children():
        current_key_name.append(name)
        if isinstance(module, (nn.Linear, Conv1D, MoeFusedLinear)) and name not in modules_to_not_convert:
            # Check that neither the current key nor any of its parents is in `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                num_experts = None
                if isinstance(module, MoeFusedLinear):
                    in_features = module.in_features
                    out_features = module.out_features
                    num_experts = module.num_experts
                elif isinstance(module, Conv1D):
                    # Conv1D stores its weight transposed relative to nn.Linear
                    in_features, out_features = module.weight.shape
                else:
                    in_features = module.in_features
                    out_features = module.out_features
                if isinstance(module, MoeFusedLinear):
                    model._modules[name] = MoeFusedLinear4bit(
                        in_features,
                        out_features,
                        num_experts,
                        compute_dtype=quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                        quant_type=quantization_config.bnb_4bit_quant_type,
                        quant_storage=quantization_config.bnb_4bit_quant_storage,
                    )
                else:
                    # Older bitsandbytes releases do not accept `quant_storage`
                    extra_kwargs = (
                        {"quant_storage": quantization_config.bnb_4bit_quant_storage}
                        if "quant_storage" in list(signature(Linear4bit).parameters)
                        else {}
                    )
                    model._modules[name] = Linear4bit(
                        in_features,
                        out_features,
                        module.bias is not None,
                        quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                        quant_type=quantization_config.bnb_4bit_quant_type,
                        **extra_kwargs,
                    )
                has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            has_been_replaced = _replace_with_bnb_moe_fused_linear(
                module, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return has_been_replaced
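
# Example of the skip logic above: with `modules_to_not_convert=["lm_head"]`, a module
# named "lm_head" is kept in full precision by the name check, and any submodule under
# a dotted key such as "lm_head.proj" (hypothetical) is kept by the prefix check, while
# an unrelated name like "lm_head_extra" would still be converted.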


def replace_with_bnb_moe_fused_linear(
    model: nn.Module, modules_to_not_convert: Optional[list[str]], quantization_config: BitsAndBytesConfig
) -> None:
    """Replace quantizable linear modules with their 4-bit counterparts.

    The model is modified in place.
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    with init_empty_weights():
        has_been_replaced = _replace_with_bnb_moe_fused_linear(
            model, modules_to_not_convert, [], quantization_config, False
        )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 4bit but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )


def _process_model_before_weight_loading(
    self: Bnb4BitHfQuantizer,
    model: PreTrainedModel,
    device_map: Union[str, dict[str, Any]],
    keep_in_fp32_modules: Optional[list[str]] = None,
    **kwargs,
) -> None:
    """Drop-in replacement for `Bnb4BitHfQuantizer._process_model_before_weight_loading`
    that routes module replacement through `replace_with_bnb_moe_fused_linear`."""
    self.modules_to_not_convert = self.get_modules_to_not_convert(
        model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
    )
    # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
    if isinstance(device_map, dict) and len(device_map) > 1:
        keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
        self.modules_to_not_convert.extend(keys_on_cpu)
    replace_with_bnb_moe_fused_linear(
        model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
    )
    model.config.quantization_config = self.quantization_config


def patch_bnb_quantizer() -> None:
    """Install the fused-MoE-aware `_process_model_before_weight_loading` on `Bnb4BitHfQuantizer`."""
    Bnb4BitHfQuantizer._process_model_before_weight_loading = _process_model_before_weight_loading
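
# A minimal usage sketch (this module uses relative imports, so call it from your own
# script after importing the package). The import path and checkpoint id below are
# placeholders, not part of this module:
#
#     import torch
#     from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#     from <this_package>.patch_bnb import patch_bnb_quantizer  # hypothetical import path
#
#     patch_bnb_quantizer()  # install the hook before any quantized from_pretrained call
#     model = AutoModelForCausalLM.from_pretrained(
#         "path/to/qwen3-moe-fused-checkpoint",  # placeholder id or local path
#         quantization_config=BitsAndBytesConfig(
#             load_in_4bit=True,
#             bnb_4bit_compute_dtype=torch.bfloat16,
#             bnb_4bit_use_double_quant=True,
#             bnb_4bit_quant_type="nf4",
#         ),
#         device_map="auto",
#     )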