from inspect import signature
from typing import Any, Optional, Union

from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit
from torch import nn
from transformers import BitsAndBytesConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer
from transformers.utils import logging

from ..modular_qwen3_moe_fused import MoeFusedLinear
from .layer import MoeFusedLinear4bit


logger = logging.get_logger(__name__)


# Modified from https://github.com/huggingface/transformers/blob/508a7040556dc6b45f09174c662a9632284b2445/src/transformers/integrations/bitsandbytes.py#L150
def _replace_with_bnb_moe_fused_linear(
    model: nn.Module,
    modules_to_not_convert: list[str],
    current_key_name: list[str],
    quantization_config: BitsAndBytesConfig,
    has_been_replaced: bool,
) -> bool:
    for name, module in model.named_children():
        current_key_name.append(name)

        if isinstance(module, (nn.Linear, Conv1D, MoeFusedLinear)) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
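                # e.g. `modules_to_not_convert = ["lm_head"]` skips "lm_head" itself
                # and any module whose dotted path contains "lm_head."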
                if isinstance(module, MoeFusedLinear):
                    model._modules[name] = MoeFusedLinear4bit(
                        module.in_features,
                        module.out_features,
                        module.num_experts,
                        compute_dtype=quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                        quant_type=quantization_config.bnb_4bit_quant_type,
                        quant_storage=quantization_config.bnb_4bit_quant_storage,
                    )
                else:
                    if isinstance(module, Conv1D):
                        # Conv1D (GPT-2 style) stores its weight transposed relative to nn.Linear
                        in_features, out_features = module.weight.shape
                    else:
                        in_features = module.in_features
                        out_features = module.out_features
                    # Older bitsandbytes versions do not accept `quant_storage`
                    extra_kwargs = (
                        {"quant_storage": quantization_config.bnb_4bit_quant_storage}
                        if "quant_storage" in list(signature(Linear4bit).parameters)
                        else {}
                    )
                    model._modules[name] = Linear4bit(
                        in_features,
                        out_features,
                        module.bias is not None,
                        quantization_config.bnb_4bit_compute_dtype,
                        compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                        quant_type=quantization_config.bnb_4bit_quant_type,
                        **extra_kwargs,
                    )

                has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            has_been_replaced = _replace_with_bnb_moe_fused_linear(
                module, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced
            )

        # Remove the last key for recursion
        current_key_name.pop(-1)

    return has_been_replaced
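

# For illustration (hypothetical module path): after the traversal above, a fused
# submodule such as `model.layers.0.mlp.experts.gate_proj` of type MoeFusedLinear
# has been swapped in place for a MoeFusedLinear4bit with the same shape and
# expert count, while any module whose dotted name matches an entry of
# `modules_to_not_convert` (e.g. "lm_head") keeps its original class.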


def replace_with_bnb_moe_fused_linear(
    model: nn.Module, modules_to_not_convert: Optional[list[str]], quantization_config: BitsAndBytesConfig
) -> None:
    """Replace supported linear layers with their 4-bit counterparts. `model` is modified in place."""
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    with init_empty_weights():
        has_been_replaced = _replace_with_bnb_moe_fused_linear(
            model, modules_to_not_convert, [], quantization_config, False
        )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )


def _process_model_before_weight_loading(
    self: Bnb4BitHfQuantizer,
    model: PreTrainedModel,
    device_map: Union[str, dict[str, Any]],
    keep_in_fp32_modules: Optional[list[str]] = None,
    **kwargs,
) -> None:
    self.modules_to_not_convert = self.get_modules_to_not_convert(
        model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
    )

    # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
    if isinstance(device_map, dict) and len(device_map) > 1:
        keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
        self.modules_to_not_convert.extend(keys_on_cpu)
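        # e.g. with device_map = {"": 0, "model.layers.40": "cpu"} (an illustrative
        # mapping), "model.layers.40" is also excluded and kept in its original dtype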

    replace_with_bnb_moe_fused_linear(
        model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
    )

    model.config.quantization_config = self.quantization_config


def patch_bnb_quantizer() -> None:
    """Patch the 4-bit BnB quantizer so `from_pretrained` also quantizes fused MoE layers."""
    Bnb4BitHfQuantizer._process_model_before_weight_loading = _process_model_before_weight_loading
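

# Example usage (a sketch; the checkpoint name is illustrative, and any
# Qwen3-MoE checkpoint with fused expert layers would apply):
#
#     import torch
#     from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#
#     patch_bnb_quantizer()  # must run before `from_pretrained`
#     model = AutoModelForCausalLM.from_pretrained(
#         "Qwen/Qwen3-30B-A3B",
#         quantization_config=BitsAndBytesConfig(
#             load_in_4bit=True,
#             bnb_4bit_quant_type="nf4",
#             bnb_4bit_compute_dtype=torch.bfloat16,
#         ),
#     )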