import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer


class ReplacedLinearLayer(nn.Module):
    """Drop-in replacement for nn.Linear / transformers Conv1D that stores
    int8 weights plus per-output-channel float32 scales and dequantizes on
    the fly during forward."""

    def __init__(self, input_dim, output_dim, if_conv=False):
        super().__init__()

        # int8 weight matrix (one row per output channel) and one float32
        # scale per output channel; the scales must stay floating point so
        # the dequantized weights keep their original magnitude.
        self.register_buffer('weights', torch.zeros(output_dim, input_dim, dtype=torch.int8))
        self.register_buffer('scales', torch.zeros(output_dim, dtype=torch.float32))

        self.bias = None
        self.if_conv = if_conv  # Conv1D stores its weight as (in, out)

    def forward(self, x):
        # Dequantize into the activation dtype, run the matmul, then rescale
        # each output channel by its stored scale.
        w = self.weights.to(x.dtype)
        x = F.linear(x, w) * self.scales.to(x.dtype)
        if self.bias is not None:
            x = x + self.bias
        return x

    @torch.no_grad()
    def do_quantization(self, W):
        # Conv1D keeps its weight as (in_features, out_features); transpose
        # so that rows correspond to output channels, matching nn.Linear.
        W32 = (W.T if self.if_conv else W).to(torch.float32)

        # Symmetric per-row absmax quantization into [-127, 127].
        scales = W32.abs().max(dim=-1).values / 127
        scales = scales.clamp(min=1e-8)  # guard against all-zero rows
        self.scales = scales
        self.weights = torch.round(W32 / scales[:, None]).to(torch.int8)

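
# Quick sanity check (an illustrative addition, not part of the quantizer):
# quantizing a random nn.Linear and comparing its output to the int8
# replacement should show a small but nonzero error, since absmax int8
# keeps only about two decimal digits of weight precision.
def _check_roundtrip_error():
    torch.manual_seed(0)
    ref = nn.Linear(64, 32, bias=False)
    q = ReplacedLinearLayer(64, 32, if_conv=False)
    q.do_quantization(ref.weight)

    x = torch.randn(4, 64)
    with torch.no_grad():
        err = (q(x) - ref(x)).abs().max().item()
    print(f'max |error| after int8 round-trip: {err:.5f}')
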
def perform_quantization(module, regex='.*'):
    """Replace every nn.Linear / Conv1D whose qualified name matches `regex`
    with a ReplacedLinearLayer holding its int8-quantized weights."""
    pattern = re.compile(regex)
    for name, node in module.named_modules():
        for child_name, child in node.named_children():
            full_name = f'{name}.{child_name}' if name else child_name
            is_conv = isinstance(child, transformers.pytorch_utils.Conv1D)
            if not (isinstance(child, nn.Linear) or is_conv) or not pattern.match(full_name):
                continue

            fp32_weight, fp32_bias = child.weight, child.bias

            # nn.Linear stores its weight as (out, in); Conv1D as (in, out).
            if is_conv:
                in_dim, out_dim = fp32_weight.shape
            else:
                out_dim, in_dim = fp32_weight.shape

            quant_module = ReplacedLinearLayer(in_dim, out_dim, if_conv=is_conv)
            quant_module.do_quantization(fp32_weight)
            if fp32_bias is not None:
                quant_module.bias = fp32_bias
            setattr(node, child_name, quant_module)
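

# Usage sketch, assuming a GPT-2 checkpoint; any causal LM whose projection
# layers are nn.Linear or Conv1D should work the same way. The regex below
# (attention and MLP projections only, leaving lm_head in full precision)
# is an illustrative choice, not something fixed by the code above.
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    perform_quantization(model, regex=r'.*(attn|mlp).*')

    inputs = tokenizer('Quantization keeps the model', return_tensors='pt')
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))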