import math

import torch
import torch.nn.functional as F
from torch import nn

from transformers.models.llama.modeling_llama import *
def activation_quant(x, n_bits=8):
    """Fake-quantize activations to signed ``n_bits`` integers per token.

    Uses absmax scaling along the last dimension: each row is scaled so its
    largest magnitude maps to the top of the signed integer range, rounded,
    clamped, and rescaled back.  The result stays in floating point (a
    "fake" quantization suitable for quantization-aware training).

    Args:
        x: activation tensor; quantization is per-slice over the last dim.
        n_bits: target bit width (default 8, i.e. range [-128, 127]).

    Returns:
        Tensor of the same shape/dtype holding the dequantized values.
    """
    q_max = 2 ** (n_bits - 1) - 1
    q_min = -(2 ** (n_bits - 1))
    # Clamp the per-row absmax so an all-zero row cannot divide by zero.
    absmax = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    scale = q_max / absmax
    return (x * scale).round().clamp(q_min, q_max) / scale
def weight_quant(w):
    """Fake-quantize a weight tensor to ternary values {-1, 0, +1}.

    Implements the BitNet b1.58 recipe: scale by the reciprocal of the
    mean absolute value, round to the nearest integer, clamp to [-1, 1],
    then rescale.  The output is floating point of the same shape as ``w``.

    Args:
        w: weight tensor (any shape).

    Returns:
        Dequantized ternary approximation of ``w``.
    """
    # Epsilon clamp keeps the scale finite for an all-zero weight tensor.
    mean_abs = w.abs().mean().clamp(min=1e-5)
    scale = 1 / mean_abs
    ternary = (w * scale).round().clamp(-1, 1)
    return ternary / scale
class BitLinear(nn.Linear):
    """Linear layer with ternary weights and low-bit activations (BitNet
    b1.58 style): RMSNorm -> quantize activations -> quantize weights ->
    linear.  Both quantizations use the straight-through estimator (STE)
    so the layer stays differentiable for quantization-aware training.
    """

    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        """All positional/keyword args are forwarded to ``nn.Linear``.

        Args:
            weight_bits: nominal weight bit width.  Stored for reference;
                ``weight_quant`` is ternary regardless of this value.
            input_bits: activation bit width passed to ``activation_quant``.
        """
        super().__init__(*kargs, **kwargs)
        # Bug fix: the original accepted these kwargs but discarded them
        # and hard-coded 8 bits in forward().
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    @staticmethod
    def _rms_norm(x, eps=1e-6):
        # Unit-gain RMSNorm computed in float32, numerically identical to
        # an untrained LlamaRMSNorm.  The original built a fresh
        # LlamaRMSNorm module on every forward call, allocating a new
        # parameter tensor and moving it to the device each time.
        input_dtype = x.dtype
        x = x.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x.to(input_dtype)

    def forward(self, x):
        w = self.weight  # shape [out_features, in_features]
        x = x.to(w.device)
        x_norm = self._rms_norm(x)
        # STE trick: forward pass uses the quantized values, while
        # detach() makes the backward pass see the identity function.
        x_quant = x_norm + (activation_quant(x_norm, self.input_bits) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        # Bug fix: pass self.bias through -- the original dropped it even
        # though nn.Linear defaults to bias=True.
        return F.linear(x_quant, w_quant, self.bias)