| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
def modified_weight_quant(w):
    """Per-tensor quantization to 1.58 bits (ternary values {-1, 0, 1}).

    No grouping is needed for quantization: the whole tensor is clipped to
    [-1, 1] and rounded to the nearest integer.

    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        A tensor of the same shape holding the quantized weights.
    """
    clipped = torch.clamp(w, min=-1.0, max=1.0)
    return torch.round(clipped)
def normalize(w):
    """Normalize each row of `w` to unit L2 norm.

    Fixes two issues with the naive `w / torch.norm(w, dim=1, keepdim=True)`:
    - a zero-norm row would divide by zero and produce NaN; the norm is
      clamped to the smallest positive value of the dtype, so zero rows map
      to zeros instead,
    - `torch.norm` is deprecated; `torch.linalg.vector_norm` is the
      documented replacement.

    Args:
        w: a floating-point tensor with shape [d, k]
    Returns:
        A tensor of shape [d, k] whose nonzero rows have unit L2 norm;
        all-zero rows stay all-zero.
    """
    norms = torch.linalg.vector_norm(w, dim=1, keepdim=True)
    # clamp_min with the dtype's tiny value leaves every non-degenerate norm
    # untouched while making the zero-row division well defined (0 / tiny = 0).
    return w / norms.clamp_min(torch.finfo(w.dtype).tiny)
class QLinear(nn.Linear):
    """Linear layer with 1.58-bit (ternary) weight quantization.

    The forward pass uses weights quantized by `modified_weight_quant` via a
    straight-through estimator (STE), followed by learned per-output-channel
    scales. This is only for training; kernel optimization is needed for
    efficient inference.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One learned scale per output channel, applied after the matmul.
        self.scales = nn.Parameter(torch.ones(self.out_features))
        self.quantizer = modified_weight_quant

    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, in_features]
        Returns:
            y: an output tensor with shape [n, out_features]
        """
        w = self.weight
        x = x.to(w.device)
        # STE weight quantization: forward sees the quantized weights, while
        # the detach() lets gradients flow straight through to `w`.
        w_quant = w + (self.quantizer(w) - w).detach()
        y = F.linear(x, w_quant)
        # Apply per-channel scales post matmul; note the bias is added after
        # scaling, so it is deliberately NOT scaled.
        y = y * self.scales
        if self.bias is not None:
            y = y + self.bias
        return y