import math

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)
try:
    import quant_cuda
except ImportError:
    logger.warning("CUDA extension not installed.")
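# NOTE: `quant_cuda` is assumed here to be the GPTQ CUDA kernel extension (the
# vecquant{2,3,4,8}matmul entry points match GPTQ-for-LLaMa); it must be built
# and installed separately, otherwise forward() below fails with a NameError.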


class QuantizedLinear(torch.nn.Module):
    # GPTQ-style linear layer: the packed int32 buffers registered below are
    # expected to be filled from a quantized checkpoint, and the quant_cuda
    # kernels dequantize them on the fly in forward().
    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False,
                 *args, **kwargs):
        # `weight` is only used for its shape; `dtype` and `empty_init` are
        # accepted for API compatibility but unused.
        super().__init__()
        self.weight_bit_width = weight_bit_width
        shape = weight.shape
        self.shape = shape
        self.group_size = 128
        # shape[0] // 256 * (weight_bit_width * 8) == shape[0] * weight_bit_width // 32,
        # i.e. the number of int32 words needed to pack shape[0] values at this bit width.
        self.register_buffer('qzeros', torch.zeros(
            (math.ceil(shape[1] / self.group_size), shape[0] // 256 * (weight_bit_width * 8)), dtype=torch.int))
        self.register_buffer('scales', torch.zeros(
            (math.ceil(shape[1] / self.group_size), shape[0]), dtype=torch.float))
        self.register_buffer('qweight', torch.zeros(
            (shape[1] // 256 * (weight_bit_width * 8), shape[0]), dtype=torch.int))
        # Keep the bias so it can be re-applied in forward().
        self.register_buffer('bias', bias.detach().to(device) if bias is not None else None)
    def forward(self, x):
        # The kernels accumulate in fp32; cast the activations in and cast the
        # result back to the caller's dtype.
        intermediate_dtype = torch.float32
        outshape = list(x.shape)
        outshape[-1] = self.shape[0]
        x = x.reshape(-1, x.shape[-1])
        y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        if self.weight_bit_width == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        elif self.weight_bit_width == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.group_size)
        else:
            raise NotImplementedError("Only 2, 3, 4 and 8 bits are supported.")
        if self.bias is not None:
            y = y + self.bias.to(intermediate_dtype)
        y = y.to(output_dtype)
        return y.reshape(outshape)
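
# Worked example of the buffer arithmetic above (using LLaMA-7B's down_proj,
# out_features=4096, in_features=11008, at 4 bits as an illustration):
#   qweight -> (11008 // 256 * 32, 4096) = (1376, 4096) int32 words,
#   scales  -> (ceil(11008 / 128), 4096) = (86, 4096),
#   qzeros  -> (86, 4096 // 256 * 32) = (86, 512) words, 8 zero points per word.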
def quantize(model, weight_bit_width, empty_init=False, device=None):
    # Swap every attention and MLP projection of a LLaMA-style decoder layer
    # for a QuantizedLinear of the same shape.
    for layer in model.layers:
        for parent, names in ((layer.self_attn, ('q_proj', 'k_proj', 'v_proj', 'o_proj')),
                              (layer.mlp, ('gate_proj', 'down_proj', 'up_proj'))):
            for name in names:
                proj = getattr(parent, name)
                setattr(parent, name, QuantizedLinear(
                    weight_bit_width=weight_bit_width,
                    weight=proj.weight,
                    bias=proj.bias,
                    dtype=proj.weight.dtype,
                    device=proj.weight.device if device is None else device,
                    empty_init=empty_init,
                ))
    return model
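

# Minimal usage sketch. The model name and checkpoint path are hypothetical,
# and this assumes a HuggingFace LLaMA-style model where the decoder layers
# live under `model.model`; quantize() only installs empty QuantizedLinear
# shells, so the packed buffers must then be loaded from a GPTQ state dict:
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("path/to/llama-model")
#   model.model = quantize(model.model, weight_bit_width=4)
#   model.load_state_dict(torch.load("path/to/gptq_state_dict.pt"), strict=False)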