from inspect import signature
from typing import Any, Optional, Union

from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit
from torch import nn
from transformers import BitsAndBytesConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer
from transformers.utils import logging

from ..modular_qwen3_moe_fused import MoeFusedLinear
from .layer import MoeFusedLinear4bit

logger = logging.get_logger(__name__)

# Modified from https://github.com/huggingface/transformers/blob/508a7040556dc6b45f09174c662a9632284b2445/src/transformers/integrations/bitsandbytes.py#L150
def _replace_with_bnb_moe_fused_linear(
    model: nn.Module,
    modules_to_not_convert: list[str],
    current_key_name: list[str],
    quantization_config: BitsAndBytesConfig,
    has_been_replaced: bool,
) -> bool:
    for name, module in model.named_children():
        current_key_name.append(name)

        if isinstance(module, (nn.Linear, Conv1D, MoeFusedLinear)) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                num_experts = None
                if isinstance(module, MoeFusedLinear):
                    in_features = module.in_features
                    out_features = module.out_features
                    num_experts = module.num_experts
                elif isinstance(module, Conv1D):
                    in_features, out_features = module.weight.shape
                else:
                    in_features = module.in_features
                    out_features = module.out_features

                if isinstance(module, MoeFusedLinear):
                    model._modules[name] = MoeFusedLinear4bit(
                        in_features,
                        out_features,
                        num_experts,
                        compute_dtype=quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                        quant_type=quantization_config.bnb_4bit_quant_type,
                        quant_storage=quantization_config.bnb_4bit_quant_storage,
                    )
                else:
                    extra_kwargs = (
                        {"quant_storage": quantization_config.bnb_4bit_quant_storage}
                        if "quant_storage" in list(signature(Linear4bit).parameters)
                        else {}
                    )
                    model._modules[name] = Linear4bit(
                        in_features,
                        out_features,
                        module.bias is not None,
                        quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                        quant_type=quantization_config.bnb_4bit_quant_type,
                        **extra_kwargs,
                    )
                has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            has_been_replaced = _replace_with_bnb_moe_fused_linear(
                module, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return has_been_replaced

# model is modified in place
def replace_with_bnb_moe_fused_linear(
    model: nn.Module, modules_to_not_convert: Optional[list[str]], quantization_config: BitsAndBytesConfig
) -> None:
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    with init_empty_weights():
        has_been_replaced = _replace_with_bnb_moe_fused_linear(
            model, modules_to_not_convert, [], quantization_config, False
        )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

def _process_model_before_weight_loading(
    self: Bnb4BitHfQuantizer,
    model: PreTrainedModel,
    device_map: Union[str, dict[str, Any]],
    keep_in_fp32_modules: Optional[list[str]] = None,
    **kwargs,
) -> None:
    self.modules_to_not_convert = self.get_modules_to_not_convert(
        model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
    )

    # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
    if isinstance(device_map, dict) and len(device_map) > 1:
        keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
        self.modules_to_not_convert.extend(keys_on_cpu)

    replace_with_bnb_moe_fused_linear(
        model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
    )

    model.config.quantization_config = self.quantization_config

# Monkey-patch the 4-bit bitsandbytes quantizer so that fused-MoE linear layers are also
# converted to their 4-bit counterparts during model loading.
def patch_bnb_quantizer() -> None:
    Bnb4BitHfQuantizer._process_model_before_weight_loading = _process_model_before_weight_loading
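

# Usage sketch (illustrative only, not exercised by this module): the idea is to call
# `patch_bnb_quantizer()` before `from_pretrained` so that `Bnb4BitHfQuantizer` routes
# module replacement through `replace_with_bnb_moe_fused_linear`. The checkpoint path
# below is a placeholder, not a real model id.
#
#     import torch
#     from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#
#     patch_bnb_quantizer()
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         "path/to/qwen3-moe-fused-4bit",  # placeholder checkpoint
#         quantization_config=bnb_config,
#     )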