|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def quantize(x, scale, zero, maxq):
    """Quantize ``x`` and map it straight back to the real-valued grid.

    Args:
        x: Tensor to quantize.
        scale: Step size (regular mode) or positive level (ternary mode).
        zero: Zero point (regular mode) or negative level (ternary mode).
        maxq: Largest integer level; a negative value selects ternary mode.

    Returns:
        A tensor holding the de-quantized approximation of ``x``.
    """
    if not (maxq < 0):
        # Regular affine grid: round, shift by the zero point, clip to
        # [0, maxq], then undo the shift and scaling.
        levels = torch.round(x / scale) + zero
        levels = torch.clamp(levels, 0, maxq)
        return scale * (levels - zero)

    # Ternary mode: values above scale/2 snap to `scale`, values below
    # zero/2 snap to `zero`, everything else contributes 0.
    upper = (x > scale / 2).float() * scale
    lower = (x < zero / 2).float() * zero
    return upper + lower
|
|
|
|
|
class Quantizer(nn.Module):
    """Derive and store quantization parameters (``scale``/``zero``) for a tensor.

    ``configure`` chooses the scheme, ``find_params`` computes the parameters
    from data, and ``quantize`` applies them via the module-level
    ``quantize`` function.
    """

    def __init__(self, shape=1):
        """Register the ``maxq``, ``scale`` and ``zero`` buffers.

        ``scale`` and ``zero`` start as zero tensors of ``shape``, so
        ``ready()`` is false on a freshly constructed instance.
        """
        super().__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(
        self, bits, perchannel=False, sym=True,
        mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False
    ):
        """Select the quantization scheme.

        Args:
            bits: Bit width; ``maxq`` becomes ``2 ** bits - 1``.
            perchannel: Compute one parameter set per channel instead of one
                for the whole tensor.
            sym: Use a range symmetric around zero.
            mse: Search over shrunken ranges for the lowest error.
            norm: Exponent applied to the absolute error during the search.
            grid: Number of grid steps that defines the shrink increments.
            maxshrink: Fraction of ``grid`` that is actually searched.
            trits: If true, ``maxq`` is set to -1 (ternary mode).
        """
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            # A negative maxq is the marker for ternary quantization.
            self.maxq = torch.tensor(-1)

    def find_params(self, x, weight=False):
        """Compute ``scale`` and ``zero`` from the values in ``x``.

        ``x`` is first rearranged into a 2-D tensor with one row per
        parameter group. The resulting parameters are finally reshaped so
        that they broadcast against a tensor with the original shape of ``x``.

        Args:
            x: Tensor whose value range determines the parameters.
            weight: If true, groups run along dim 0 of ``x``; otherwise along
                dim 1 (4-D and 2-D inputs) or the last dim (3-D inputs).
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        ndim = len(shape)

        # Bring x into (groups, values-per-group) layout.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif ndim == 4:
            x = x.permute([1, 0, 2, 3]).flatten(1)
        elif ndim == 3:
            x = x.reshape((-1, shape[-1])).t()
        elif ndim == 2:
            x = x.t()

        # The range always contains 0.
        origin = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], origin)
        xmax = torch.maximum(x.max(1)[0], origin)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            negative = xmin < 0
            if torch.any(negative):
                xmin[negative] = -xmax[negative]

        # Groups that are entirely zero get a unit range so the scale below
        # does not collapse to 0.
        all_zero = (xmin == 0) & (xmax == 0)
        xmin[all_zero] = -1
        xmax[all_zero] = +1

        if self.maxq < 0:
            # Ternary mode stores the two non-zero levels directly.
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            self.zero = (
                torch.full_like(self.scale, (self.maxq + 1) / 2)
                if self.sym
                else torch.round(-xmin / self.scale)
            )

        if self.mse:
            # Try progressively shrunken ranges and keep, per group, the
            # parameters with the smallest summed |error| ** norm.
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                lo = shrink * xmin
                hi = shrink * xmax
                cand_scale = (hi - lo) / self.maxq
                cand_zero = self.zero if self.sym else torch.round(-lo / cand_scale)
                err = quantize(
                    x, cand_scale.unsqueeze(1), cand_zero.unsqueeze(1), self.maxq
                )
                err -= x
                err.abs_()
                err.pow_(self.norm)
                err = torch.sum(err, 1)
                improved = err < best
                if torch.any(improved):
                    best[improved] = err[improved]
                    self.scale[improved] = cand_scale[improved]
                    self.zero[improved] = cand_zero[improved]

        if not self.perchannel:
            # Replicate the single parameter set once per channel.
            if weight:
                copies = shape[0]
            else:
                copies = shape[2] if ndim == 3 else shape[1]
            self.scale = self.scale.repeat(copies)
            self.zero = self.zero.repeat(copies)

        if weight:
            broadcast = [-1] + [1] * (ndim - 1)
            self.scale = self.scale.reshape(broadcast)
            self.zero = self.zero.reshape(broadcast)
            return

        if ndim == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif ndim == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif ndim == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Return ``x`` quantized with the stored parameters, or ``x`` itself if not ready."""
        if not self.ready():
            return x
        return quantize(x, self.scale, self.zero, self.maxq)

    def enabled(self):
        """Return a boolean tensor that is true when ``maxq`` is positive."""
        return self.maxq > 0

    def ready(self):
        """Return a boolean tensor that is true when no entry of ``scale`` is 0."""
        return torch.all(self.scale != 0)
|
|
|
|
|
|
|
|
# The compiled CUDA kernels are optional at import time; everything except
# Quant3Linear.forward works without them.
try:
    import quant_cuda
except ImportError:
    # ImportError (including ModuleNotFoundError) is what a missing or
    # unloadable extension raises. Catching only that keeps
    # KeyboardInterrupt/SystemExit and unrelated errors visible instead of
    # hiding them behind this message.
    print('CUDA extension not installed.')
|
|
|
|
|
|
|
|
class Quant3Linear(nn.Module):
    """Linear layer whose weights are stored as packed 3-bit integers.

    Every group of 32 input rows is packed into three 32-bit words per output
    column. The matrix-vector product itself is delegated to the
    ``quant_cuda`` extension.
    """

    def __init__(self, infeatures, outfeatures, faster=False):
        """Allocate zero-filled buffers for the packed layer.

        Args:
            infeatures: Number of input features; only full groups of 32
                contribute rows to ``qweight``.
            outfeatures: Number of output features.
            faster: If true, ``forward`` uses the half-precision kernel.
        """
        super().__init__()
        packed_rows = infeatures // 32 * 3
        self.register_buffer('zeros', torch.zeros((outfeatures, 1)))
        self.register_buffer('scales', torch.zeros((outfeatures, 1)))
        self.register_buffer('bias', torch.zeros(outfeatures))
        self.register_buffer(
            'qweight', torch.zeros((packed_rows, outfeatures), dtype=torch.int)
        )
        self.faster = faster

    def pack(self, linear, scales, zeros):
        """Quantize the weights of ``linear`` and store them bit-packed.

        Args:
            linear: Layer providing ``weight`` and (optionally) ``bias``.
            scales: Per-output scale; stored as a clone.
            zeros: Per-output zero point; stored pre-multiplied by ``scales``.
        """
        self.zeros = zeros * scales
        self.scales = scales.clone()
        if linear.bias is not None:
            self.bias = linear.bias.clone()

        # Integer levels, transposed so that rows index input features.
        levels = torch.round((linear.weight.data + self.zeros) / self.scales)
        levels = levels.to(torch.int).t().contiguous()
        levels = levels.numpy().astype(np.uint32)

        packed = np.zeros(
            (levels.shape[0] // 32 * 3, levels.shape[1]), dtype=np.uint32
        )

        src = 0  # first input row of the current group of 32
        for top in range(0, packed.shape[0], 3):
            mid = top + 1
            low = top + 2

            # Word 0: rows 0-9 in bits 0-29, the two low bits of row 10 on top.
            for k in range(10):
                packed[top] |= levels[src + k] << (3 * k)
            straddle = levels[src + 10]
            packed[top] |= straddle << 30
            # Word 1: remaining high bit of row 10, rows 11-20 in bits 1-30,
            # then the low bit of row 21 in bit 31.
            packed[mid] |= (straddle >> 2) & 1
            for k in range(10):
                packed[mid] |= levels[src + 11 + k] << (3 * k + 1)
            straddle = levels[src + 21]
            packed[mid] |= straddle << 31
            # Word 2: remaining two high bits of row 21, rows 22-31 in bits 2-31.
            packed[low] |= (straddle >> 1) & 0x3
            for k in range(10):
                packed[low] |= levels[src + 22 + k] << (3 * k + 2)

            src += 32

        # Reinterpret as signed 32-bit before handing the array to torch.
        self.qweight = torch.from_numpy(packed.astype(np.int32))

    def forward(self, x):
        """Apply the layer to a single input vector.

        Args:
            x: Tensor whose last dimension holds all of its elements.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by the
            number of outputs, in the dtype of ``x``.

        Raises:
            ValueError: If ``x`` holds more than one vector.
        """
        if x.shape[-1] != x.numel():
            raise ValueError('Only supports a single token currently.')

        in_dtype = x.dtype
        out_shape = list(x.shape)
        out_shape[-1] = self.bias.numel()
        # y starts as the bias; the kernel is presumably expected to
        # accumulate the product into it in place.
        y = self.bias.clone()
        if self.faster:
            x = x.half()
            quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros)
        else:
            x = x.float()
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros)
        return y.to(in_dtype).reshape(out_shape)
|
|
|
|
|
def make_quant3(module, names, name='', faster=False):
    """Recursively replace the named sub-layers of ``module`` with ``Quant3Linear``.

    Args:
        module: Root module to rewrite in place.
        names: Container of fully qualified (dot-separated) attribute names to
            replace; only membership tests are performed on it.
        name: Qualified name of ``module`` itself; empty for the root.
        faster: Forwarded to every created ``Quant3Linear``.

    Each matched attribute must expose ``in_features`` and ``out_features``.
    Modules that already are ``Quant3Linear`` are left untouched.
    """
    if isinstance(module, Quant3Linear):
        return

    prefix = name + '.' if name != '' else ''

    for attr in dir(module):
        if prefix + attr not in names:
            continue
        # Look the attribute up only after the name matched, so unrelated
        # properties of the module are never evaluated (they may be costly
        # or raise).
        target = getattr(module, attr)
        setattr(
            module, attr,
            Quant3Linear(target.in_features, target.out_features, faster=faster)
        )

    for child_name, child in module.named_children():
        make_quant3(child, names, prefix + child_name, faster=faster)
|
|
|