File size: 15,718 Bytes
6f0b660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 |
from ..utils import is_accelerate_available, is_torch_available, logging
if is_accelerate_available():
from accelerate import init_empty_weights
if is_torch_available():
import torch
import torch.nn as nn
import torch.nn.functional as F
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# The weights are ternary, so each value fits in 2 bits. They are packed into uint8 tensors, hence every
# packed item holds 4 values.
VALUES_PER_ITEM = 4


def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
    """
    Packs a tensor of quantized weights into a compact format using 2 bits per value.

    The input is split along its first dimension into up to `VALUES_PER_ITEM` consecutive chunks of
    `ceil(rows / VALUES_PER_ITEM)` rows each. Chunk `i` is written to bits `2 * i` and `2 * i + 1` of the
    packed tensor. The input tensor itself is left unmodified.

    Parameters:
    -----------
    quantized_weights : torch.Tensor
        A tensor (at least 1-D) containing ternary quantized weights with values in {-1, 0, 1}. These values
        are shifted to {0, 1, 2} before being packed.

    Returns:
    --------
    torch.Tensor
        A `torch.uint8` tensor on the same device as the input, of shape
        `(ceil(rows / VALUES_PER_ITEM), *quantized_weights.shape[1:])`, where each element stores up to 4
        quantized values (2 bits each).
    """
    original_shape = quantized_weights.shape
    num_rows = original_shape[0]
    row_dim = (num_rows + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM
    # `original_shape[1:]` is empty for 1-D inputs, so this covers both the 1-D and N-D cases.
    packed_tensor_shape = (row_dim, *original_shape[1:])

    # Shift {-1, 0, 1} to {0, 1, 2} out of place so that the caller's tensor is not modified.
    unpacked = (quantized_weights + 1).to(torch.uint8)
    packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)

    for i in range(VALUES_PER_ITEM):
        start = i * row_dim
        if start >= num_rows:
            # No rows left for this 2-bit slot (also covers an input with zero rows).
            break
        end = min(start + row_dim, num_rows)
        # The last chunk may be shorter than `row_dim`; only the first `end - start` packed rows receive it.
        packed[: (end - start)] |= unpacked[start:end] << (2 * i)

    return packed
@torch.compile
def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """
    Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.

    Parameters:
    -----------
    packed : torch.Tensor
        A tensor containing packed weights where each element represents 4 quantized values (2 bits per value).
    dtype : torch.dtype
        The dtype of the returned tensor.

    Returns:
    --------
    torch.Tensor
        A tensor of unpacked weights with `VALUES_PER_ITEM` times as many rows as `packed`, where each value is
        converted from its packed 2-bit representation back to {-1, 0, 1}.

    Example:
    --------
        packed = torch.tensor([[0b10100001, 0b00011000],
                               [0b10010000, 0b00001010]], dtype=torch.uint8)

        # Unpack the values
        unpacked = unpack_weights(packed, dtype=torch.int8)

        # Resulting unpacked tensor
        print(unpacked)
        # Output: tensor([[ 0, -1],
        #                 [-1,  1],
        #                 [-1,  1],
        #                 [-1,  1],
        #                 [ 1,  0],
        #                 [ 0, -1],
        #                 [ 1, -1],
        #                 [ 1, -1]])

    Explanation of the example:
    ---------------------------
    Every element is unpacked across the first dimension, so we only follow the first column. Its first
    packed value is `packed[0][0] = 0b10100001`:
        - bits 0-1: `01` →  0 at unpacked[0][0]
        - bits 2-3: `00` → -1 at unpacked[2][0]
        - bits 4-5: `10` →  1 at unpacked[4][0]
        - bits 6-7: `10` →  1 at unpacked[6][0]
    The second packed value of that column (`packed[1][0] = 0b10010000`) provides unpacked[1][0],
    unpacked[3][0], unpacked[5][0] and unpacked[7][0].

    We subtract 1 because packing works on the non-negative values 0, 1 and 2: the original ternary weights
    (-1, 0 and 1) are shifted by +1 when they are packed, and the shift is undone here.
    """
    packed_rows = packed.shape[0]
    # The trailing dimensions are untouched by packing; this slice is empty for 1-D inputs.
    trailing_dims = packed.shape[1:]
    result = torch.zeros(
        (packed_rows * VALUES_PER_ITEM, *trailing_dims),
        device=packed.device,
        dtype=torch.uint8,
    )

    for slot in range(VALUES_PER_ITEM):
        shift = 2 * slot
        first_row = slot * packed_rows
        # Isolate the 2 bits of this slot and move them down to the lowest bits.
        result[first_row : first_row + packed_rows] = (packed & (3 << shift)) >> shift

    return result.to(dtype) - 1
class BitLinear(nn.Module):
    """
    Linear layer whose ternary weights are stored packed (2 bits per value) in a `torch.uint8` buffer.

    On every forward pass the weights are unpacked with `unpack_weights`, the activations are quantized by
    `activation_quant`, the matrix product is computed in `self.dtype`, and the result is rescaled with the
    activation scale and the `weight_scale` buffer.

    Parameters:
        in_features (`int`): Size of the last dimension of the input.
        out_features (`int`): Size of the last dimension of the output.
        bias (`bool`): Whether to register a (zero-initialized) `bias` buffer.
        device: Device on which the buffers are created.
        dtype: Dtype of `weight_scale`, `bias` and of the computation in `forward`. Falls back to
            `torch.get_default_dtype()` when `None`.
        use_rms_norm (`bool`): Whether to apply a `LlamaRMSNorm` to the input before quantizing it.
        rms_norm_eps (`float`): Epsilon passed to the `LlamaRMSNorm`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        device=None,
        dtype=None,
        use_rms_norm: bool = False,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()
        # `forward` casts the unpacked weights and the quantized activations to `self.dtype`, so it has to be
        # a concrete dtype. With `dtype=None` the buffers below are created in the default dtype; use the same.
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.in_features = in_features
        self.out_features = out_features
        # Packed weights: every uint8 row holds `VALUES_PER_ITEM` rows of the unpacked weight matrix.
        # NOTE(review): the floor division presumably expects `out_features` to be a multiple of
        # `VALUES_PER_ITEM` - confirm against the checkpoints this layer is loaded from.
        self.register_buffer(
            "weight",
            torch.zeros(
                (out_features // VALUES_PER_ITEM, in_features),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.register_buffer(
            "weight_scale",
            torch.ones(
                (1),
                dtype=self.dtype,
                device=device,
            ),
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=self.dtype, device=device))
        else:
            self.bias = None

        # Optional RMSNorm (applied on the activations before quantization).
        self.rms_norm = None
        if use_rms_norm:
            from ..models.llama.modeling_llama import LlamaRMSNorm

            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)

    @torch.compile
    def activation_quant(self, input, num_bits=8):
        """
        Performs symmetric quantization of the input activations, with one scale per slice along the last
        dimension.

        Parameters:
        -----------
        input : torch.Tensor
            Input activations to be quantized.
        num_bits : int, optional (default=8)
            Number of bits to use for quantization, determining the quantization range
            `[-2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1]`.

        Returns:
        --------
        result : torch.Tensor
            Quantized activation tensor (rounded and clamped to the quantization range), cast to `torch.int8`.
        scale : torch.Tensor
            The scaling factors the input was multiplied by: the upper bound of the range divided by the
            maximum absolute value along the last dimension (clamped to at least 1e-5), with that dimension
            kept as size 1.
        """
        Qn = -(2 ** (num_bits - 1))
        Qp = 2 ** (num_bits - 1) - 1
        scale = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        result = (input * scale).round().clamp(Qn, Qp)
        return result.to(torch.int8), scale

    @torch.compile
    def post_quant_process(self, input, input_scale, weight_scale):
        """Rescales the output of the quantized matmul by dividing it by `input_scale * weight_scale`."""
        out = input / (input_scale * weight_scale)
        return out

    def forward(self, input):
        """Applies the optional RMSNorm, quantizes the input, and computes the rescaled linear projection."""
        # Apply RMSNorm on the input if requested.
        if self.rms_norm is not None:
            input = self.rms_norm(input)

        w_quant = unpack_weights(self.weight, dtype=self.dtype)
        input_quant, input_scale = self.activation_quant(input)
        y = F.linear(input_quant.to(self.dtype), w_quant)
        # Arguments follow the `post_quant_process(input, input_scale, weight_scale)` signature.
        y = self.post_quant_process(y, input_scale, self.weight_scale)
        if self.bias is not None:
            y += self.bias.view(1, -1).expand_as(y)
        return y
class WeightQuant(torch.autograd.Function):
    """
    Custom autograd function that fake-quantizes weights to ternary values.

    Forward: the weights are scaled by the inverse of their mean absolute value (clamped to at least 1e-5),
    rounded and clamped to {-1, 0, 1}, then divided by the same scale again and cast back to the input dtype.

    Backward: Straight-Through Estimator (STE) - the incoming gradient is passed through as a copy.
    """

    @staticmethod
    @torch.compile
    def forward(ctx, weight):
        original_dtype = weight.dtype
        # The computation is carried out in float32 and cast back at the end.
        weight_fp32 = weight.float()
        mean_abs = weight_fp32.abs().mean().clamp_(min=1e-5)
        scale = 1.0 / mean_abs
        ternary = (weight_fp32 * scale).round().clamp(-1, 1)
        return (ternary / scale).to(original_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: quantization is treated as the identity for gradient purposes.
        return grad_output.clone()
class ActQuant(torch.autograd.Function):
    """
    Custom autograd function that fake-quantizes activations to 8 bits.

    Forward: for every slice along the last dimension, the activations are scaled by 127 divided by the
    maximum absolute value of that slice (clamped to at least 1e-5), rounded and clamped to [-128, 127],
    then divided by the same scale again and cast back to the input dtype.

    Backward: Straight-Through Estimator (STE) - the incoming gradient is passed through as a copy.
    """

    @staticmethod
    @torch.compile
    def forward(ctx, activation):
        source_dtype = activation.dtype
        # The computation is carried out in float32 and cast back at the end.
        activation_fp32 = activation.float()
        row_max = activation_fp32.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
        scale = 127 / row_max
        quantized = (activation_fp32 * scale).round().clamp(-128, 127)
        return (quantized / scale).to(source_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: quantization is treated as the identity for gradient purposes.
        return grad_output.clone()
class AutoBitLinear(nn.Linear):
    """
    `nn.Linear` variant that quantizes its activations with `ActQuant` and, depending on the mode, its weights.

    - `online_quant=True`: the weights are quantized on the fly with `WeightQuant` on every forward pass.
    - `online_quant=False`: the weights are used as stored and the output is multiplied by the `weight_scale`
      buffer. A load-state-dict pre-hook unpacks packed (`torch.uint8`) checkpoint weights with
      `unpack_weights` before they are loaded.

    Parameters:
        in_features (`int`): Size of the last dimension of the input.
        out_features (`int`): Size of the last dimension of the output.
        bias (`bool`, defaults to `True`): Whether the layer has a bias.
        device: Device of the parameters and of the `weight_scale` buffer.
        dtype: Dtype of the parameters and of the `weight_scale` buffer.
        online_quant (`bool`, defaults to `False`): Whether to quantize the weights during `forward`.
        use_rms_norm (`bool`, defaults to `False`): Whether to apply a `LlamaRMSNorm` to the input first.
        rms_norm_eps (`float`, defaults to 1e-6): Epsilon passed to the `LlamaRMSNorm`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        online_quant: bool = False,
        use_rms_norm: bool = False,
        rms_norm_eps: float = 1e-6,
    ):
        # Forward `device` / `dtype` so that the weight and bias honour them, like `weight_scale` below does.
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        self.online_quant = online_quant

        # Optional RMSNorm
        self.rms_norm = None
        if use_rms_norm:
            from ..models.llama.modeling_llama import LlamaRMSNorm

            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)

        if not online_quant:
            self.register_buffer(
                "weight_scale",
                torch.ones(
                    (1),
                    dtype=dtype,
                    device=device,
                ),
            )
            self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict,
        prefix,
        *args,
        **kwargs,
    ):
        """
        Load-state-dict pre-hook: replaces a packed weight entry in `state_dict` by its unpacked version.

        Only `torch.uint8` tensors (the dtype produced by `pack_weights`) are unpacked; any other checkpoint
        dtype is left untouched, so a plain floating-point checkpoint is never run through the bitwise
        unpacking. `state_dict` is modified in place and also returned.
        """
        weight_key = prefix + "weight"
        checkpoint_weight = state_dict.get(weight_key)
        if checkpoint_weight is not None and checkpoint_weight.dtype == torch.uint8:
            state_dict[weight_key] = unpack_weights(checkpoint_weight, dtype=self.weight.dtype)
        return state_dict

    def forward(self, input):
        """Applies the optional RMSNorm, quantizes activations (and weights if online), then projects."""
        # Optional RMSNorm on activations prior to quantization.
        if self.rms_norm is not None:
            input = self.rms_norm(input)

        if self.online_quant:
            weight = WeightQuant.apply(self.weight)
        else:
            weight = self.weight

        input = ActQuant.apply(input)
        output = F.linear(input, weight, self.bias)
        if not self.online_quant:
            # NOTE(review): the scale is applied after the bias has been added, so the bias is scaled as
            # well - confirm this is intended.
            output = output * self.weight_scale
        return output
def _replace_with_bitnet_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    pre_quantized=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Walks the children of `model`; every `nn.Linear` whose dotted path does not contain any entry of
    `modules_to_not_convert` is replaced in `model._modules` by an `AutoBitLinear` (when
    `quantization_config.linear_class == "autobitlinear"`) or by a `BitLinear` otherwise. The new modules are
    created under `init_empty_weights`. Gradients are disabled on every `BitLinear` and on `AutoBitLinear`
    modules in `"offline"` quantization mode.

    Parameters:
        model (`torch.nn.Module`): Module whose children are inspected (modified in place).
        modules_to_not_convert (`list[str]`, *optional*): Keys to skip; `None` is treated as an empty list.
        current_key_name (`list[str]`, *optional*): Path of the module currently visited, shared across the
            recursion.
        quantization_config (*optional*): Configuration object read for `linear_class`, `quantization_mode`,
            `use_rms_norm` and `rms_norm_eps`.
        has_been_replaced (`bool`): Whether a replacement already happened earlier in the recursion.
        pre_quantized (`bool`): Not used by this function; only forwarded to the recursive calls.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    # The default `None` would otherwise fail when the keys are iterated below.
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)
        qualified_name = ".".join(current_key_name)

        # Check if the current key is not in the `modules_to_not_convert`
        if not any(key in qualified_name for key in modules_to_not_convert):
            with init_empty_weights():
                if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
                    in_features = module.in_features
                    out_features = module.out_features
                    if quantization_config and quantization_config.linear_class == "autobitlinear":
                        model._modules[name] = AutoBitLinear(
                            in_features=in_features,
                            out_features=out_features,
                            bias=module.bias is not None,
                            device=module.weight.device,
                            dtype=module.weight.dtype,
                            online_quant=(quantization_config.quantization_mode == "online"),
                            use_rms_norm=quantization_config.use_rms_norm,
                            rms_norm_eps=quantization_config.rms_norm_eps,
                        )
                        if quantization_config.quantization_mode == "offline":
                            model._modules[name].requires_grad_(False)
                    else:
                        model._modules[name] = BitLinear(
                            in_features=in_features,
                            out_features=out_features,
                            bias=module.bias is not None,
                            device=module.weight.device,
                            dtype=module.weight.dtype,
                            use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
                            rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
                        )
                        model._modules[name].requires_grad_(False)
                    has_been_replaced = True

        # Recurse into the children of the visited module.
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bitnet_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
                pre_quantized=pre_quantized,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
def replace_with_bitnet_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    pre_quantized=False,
):
    """
    A helper function to replace all `torch.nn.Linear` modules by `BitLinear` (or `AutoBitLinear`) modules.

    The function will be run recursively and replace all `torch.nn.Linear` modules except for the ones listed in
    `modules_to_not_convert` (by default the `lm_head`, which is kept as a `torch.nn.Linear` module). The
    replacement is done under the `init_empty_weights` context manager. A warning is logged when no module was
    replaced.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons. The list passed by the caller is not modified.
        current_key_name (`list[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
            `disk`).
        quantization_config (*optional*):
            Quantization configuration. Its `modules_to_not_convert` entries, if any, are added to the modules to
            skip, and it is forwarded to the recursive replacement helper.
        pre_quantized (`bool`, *optional*, defaults to `False`):
            Forwarded as-is to the recursive replacement helper.

    Returns:
        `torch.nn.Module`: The model, with its linear modules replaced.
    """
    # Work on a copy so that the list owned by the caller is never mutated.
    excluded_modules = ["lm_head"] if modules_to_not_convert is None else list(modules_to_not_convert)
    if quantization_config and quantization_config.modules_to_not_convert is not None:
        excluded_modules.extend(quantization_config.modules_to_not_convert)
    # De-duplicate while keeping a deterministic (first-seen) order.
    excluded_modules = list(dict.fromkeys(excluded_modules))

    model, has_been_replaced = _replace_with_bitnet_linear(
        model,
        excluded_modules,
        current_key_name,
        quantization_config,
        pre_quantized=pre_quantized,
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using bitnet but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model
|