|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
def modified_weight_quant(w):
    """Quantize a weight tensor to 1.58 bits (ternary values -1, 0, +1).

    Per-tensor quantization; no grouping is needed.

    Args:
        w: a weight tensor with shape [d, k]

    Returns:
        u: the quantized weight with shape [d, k], entries in {-1, 0, 1}
    """
    return torch.round(torch.clamp(w, min=-1.0, max=1.0))
|
|
|
|
|
def normalize(w, eps=1e-12):
    """Normalize each row of ``w`` to unit L2 norm.

    Args:
        w: a tensor with shape [d, k]; rows are normalized along dim=1.
        eps: lower bound on the norm to avoid division by zero for
            all-zero rows (which previously produced NaN/Inf).

    Returns:
        A tensor of the same shape with (near-)unit-norm rows; an
        all-zero row is returned unchanged (as zeros).
    """
    # clamp_min guards against a zero row; vector_norm is the
    # non-deprecated replacement for torch.norm.
    norms = torch.linalg.vector_norm(w, dim=1, keepdim=True)
    return w / norms.clamp_min(eps)
|
|
|
|
|
class QLinear(nn.Linear):
    """Linear layer with 1.58-bit (ternary) quantized weights.

    The weight is quantized on the fly in ``forward`` via a
    straight-through estimator. This is only for training, and kernel
    optimization is needed for efficiency.
    """

    def __init__(self, *kargs, **kwargs):
        """Build the underlying ``nn.Linear`` plus quantization state.

        Accepts exactly the same arguments as ``nn.Linear``.
        """
        super().__init__(*kargs, **kwargs)
        # Learnable per-output-channel scale applied after the
        # quantized matmul.
        self.scales = nn.Parameter(torch.ones(self.out_features))
        # Quantization function applied to the weight each forward pass.
        self.quantizer = modified_weight_quant

    def forward(self, x):
        """Apply the quantized linear transform.

        Args:
            x: an input tensor with shape [n, in_features]

        Returns:
            y: an output tensor with shape [n, out_features]
        """
        w = self.weight
        x = x.to(w.device)

        # Straight-through estimator: the forward pass uses the
        # quantized weight, while gradients flow to the
        # full-precision weight (the quantization delta is detached).
        w_quant = w + (self.quantizer(w) - w).detach()
        y = F.linear(x, w_quant)

        # Per-channel scaling, then the (full-precision) bias.
        y = y * self.scales
        if self.bias is not None:
            y = y + self.bias
        return y