import json
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import init_empty_weights
from huggingface_hub import HfApi

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, get_module_from_name, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


# INT8 symmetric linear layer: stores int8 weights plus one floating-point
# scale per output channel and dequantizes on the fly in the forward pass.
class Int8SymmetricLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias, dtype=torch.float32):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Buffers rather than Parameters: they are not trained, but still move
        # with .to() / .cuda() and are serialized in the state dict.
        self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=dtype))

        if bias:
            self.register_buffer("bias", torch.zeros((self.out_features), dtype=dtype))
        else:
            self.bias = None

    def forward(self, x):
        # Dequantize to the scale's dtype, then fall back to a regular matmul.
        dequant_weight = self.weight * self.weight_scale
        output = F.linear(x, dequant_weight)
        if self.bias is not None:
            output = output + self.bias
        return output
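

# A minimal sanity-check sketch (not used by the quantizer itself): the
# symmetric round trip this layer assumes, i.e. per-row absmax scaling to
# int8 and dequantization with at most scale / 2 absolute error per element.
def _int8_symmetric_roundtrip_demo():
    w = torch.randn(4, 8)
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
    w_q = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
    w_dq = w_q * scale
    assert (w - w_dq).abs().max() <= scale.max() / 2 + 1e-6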


# Function to replace standard linear layers with INT8 symmetric quantized layers
def _replace_with_int8_symmetric_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    pre_quantized=False,
):
    """
    Recursively replaces nn.Linear modules with Int8SymmetricLinear modules.
    """
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Also check the full dotted path against `modules_to_not_convert`,
            # so layers nested under an excluded key are skipped as well.
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                # Build the replacement on the meta device; real (quantized)
                # tensors are materialized later, when the checkpoint is loaded.
                with init_empty_weights(include_buffers=True):
                    in_features = module.in_features
                    out_features = module.out_features
                    model._modules[name] = Int8SymmetricLinear(
                        in_features, out_features, module.bias is not None, dtype=module.weight.dtype
                    )
                    has_been_replaced = True
                    model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_int8_symmetric_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
                pre_quantized=pre_quantized,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_int8_symmetric_linear(
    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
    """
    Main function to replace model layers with INT8 symmetric quantized versions.
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))

    model, has_been_replaced = _replace_with_int8_symmetric_linear(
        model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
    )

    if not has_been_replaced:
        raise ValueError(
            "You are loading your model using INT8 symmetric quantization but no linear modules were found in your model."
        )

    return model
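

# A minimal sketch of the replacement pass on a toy module (illustrative only;
# init_empty_weights puts the swapped-in layers on the meta device, so they
# cannot run until real tensors are loaded):
#
#   toy = nn.Sequential(nn.Linear(8, 8))
#   toy, replaced = _replace_with_int8_symmetric_linear(toy, modules_to_not_convert=[])
#   assert replaced and isinstance(toy[0], Int8SymmetricLinear)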


@register_quantization_config("int8_symmetric")
class Int8SymmetricConfig(QuantizationConfigMixin):
    """
    Configuration for INT8 symmetric quantization.
    """

    def __init__(self, modules_to_not_convert: Optional[list[str]] = None, **kwargs):
        self.quant_method = "int8_symmetric"
        self.modules_to_not_convert = modules_to_not_convert

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Serializes only the attributes that differ from the default config.
        """
        config_dict = self.to_dict()
        default_config_dict = Int8SymmetricConfig().to_dict()

        serializable_config_dict = {}
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict
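

# When saved with save_pretrained, this config is serialized into the model's
# config.json under "quantization_config". A sketch of the expected entry
# (assuming the default QuantizationConfigMixin serialization of __dict__):
#
#   "quantization_config": {
#     "modules_to_not_convert": null,
#     "quant_method": "int8_symmetric"
#   }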


@register_quantizer("int8_symmetric")
class Int8SymmetricQuantizer(HfQuantizer):
    """
    Implementation of INT8 symmetric quantization.
    """

    # No calibration data is needed, and each weight is quantized individually
    # as it is loaded (via create_quantized_param).
    requires_calibration = False
    requires_parameters_quantization = True

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def _process_model_before_weight_loading(self, model, **kwargs):
        """
        Replace model's linear layers with quantized versions before loading weights.
        """
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        model = replace_with_int8_symmetric_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )

    def check_quantized_param(
        self,
        model,
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: dict[str, Any],
        **kwargs,
    ):
        """
        Returns True if the parameter still needs to be quantized via
        `create_quantized_param`, and False if it can be loaded as-is.
        """
        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, Int8SymmetricLinear):
            # Pre-quantized checkpoints (and biases) are loaded directly.
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.int8:
                    raise ValueError("Expected quantized weights but got an unquantized weight")
                return False
            else:
                # Quantizing on the fly: only float weights are expected here.
                if tensor_name == "weight_scale":
                    raise ValueError("Expected unquantized weights but got a quantized weight_scale")
                return True
        return False

    def create_quantized_param(
        self,
        model,
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: dict[str, Any],
        unexpected_keys: Optional[list[str]] = None,
    ):
        """
        Quantizes weights to INT8 symmetric format.
        """
        # Per-output-row absmax, clamped to avoid dividing by zero on all-zero rows.
        abs_max_per_row = torch.max(torch.abs(param_value), dim=1, keepdim=True)[0].clamp(min=1e-5)

        # Symmetric scheme: the zero-point is 0 and the scale maps each row's absmax to 127.
        weight_scale = abs_max_per_row / 127.0

        # Round to nearest; the clamp is a safety net since values already fall in [-127, 127].
        weight_quantized = torch.round(param_value / weight_scale).clamp(-128, 127).to(torch.int8)
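        # Worked example: with a row absmax of 0.508, the scale is 0.508 / 127 = 0.004,
        # so a weight of 0.311 maps to round(0.311 / 0.004) = 78 in int8 and dequantizes
        # back to 78 * 0.004 = 0.312, an error of 0.001 < scale / 2.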

        module, tensor_name = get_module_from_name(model, param_name)
        module._buffers[tensor_name] = weight_quantized.to(target_device)
        module._buffers["weight_scale"] = weight_scale.to(target_device)

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        """
        Drops quantization-specific buffers such as `weight_scale` from the
        missing-key list, so they do not raise warnings when quantizing an
        unquantized checkpoint on the fly.
        """
        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, Int8SymmetricLinear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def _process_model_after_weight_loading(self, model, **kwargs):
        """
        Post-processing hook after weights are loaded; nothing extra is needed here.
        """
        return True

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        return False


# Example usage
if __name__ == "__main__":
    # Quantize on the fly while loading the full-precision checkpoint.
    model_int8 = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B", quantization_config=Int8SymmetricConfig(), torch_dtype=torch.float, device_map="cpu"
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    input_text = "once there is"
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    output = model_int8.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(generated_text)

    # Save the quantized checkpoint and upload it to the Hub
    output_model_dir = "Llama-3.2-1B-INT8-CUSTOM"
    model_int8.save_pretrained(output_model_dir)
    tokenizer.save_pretrained(output_model_dir)
    api = HfApi()
    repo_id = "medmekk/Llama-3.2-1B-INT8-CUSTOM"
    api.create_repo(repo_id, private=False, exist_ok=True)  # exist_ok lets the script be re-run
    api.upload_folder(folder_path=output_model_dir, repo_id=repo_id, repo_type="model")
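
    # Because "int8_symmetric" is registered above, the uploaded checkpoint can be
    # reloaded with a plain from_pretrained call, as long as this module has been
    # imported first (a sketch, assuming the upload succeeded):
    #
    #   reloaded = AutoModelForCausalLM.from_pretrained(repo_id, device_map="cpu")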