| import torch |
| import torch.nn as nn |
| import quiptools_cuda |
| from lib.utils import dtype_from_str, get_hadK |
| from lib import codebook |
| import time |
|
|
|
|
class QuantizedLinear(nn.Module):
    """Linear layer whose weight is stored as packed codebook indices.

    Instead of a dense ``weight`` tensor, this module registers buffers that a
    quantized-codebook kernel consumes at forward time:

    - ``Qidxs``: packed codebook indices (shape depends on ``pack_out``).
    - ``SU`` / ``SV`` / ``Wscale``: per-column / per-row sign vectors and a
      scalar weight scale applied by the codebook class.
    - ``had_left`` / ``had_right`` (+ ``K_left`` / ``K_right``): Hadamard
      factors from ``get_hadK`` for the input/output transforms
      (non-persistent; recomputed from feature sizes rather than saved).
    - optional ``A`` / ``B``: low-rank factors when ``rank > 0``
      (presumably a low-rank correction term — see the codebook class).
    - optional ``scaleWH`` and ``ocs_dupe_inds`` for rescaling /
      outlier-channel-split variants.

    Buffer names and shapes are part of the checkpoint format and must not
    change. The concrete codebook class is resolved lazily on the first
    ``forward`` call from the saved ``codebook_id``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 codesz,
                 packsz,
                 pack_out,
                 idx_dtype,
                 codebook_version,
                 outlier_channel_split=False,
                 rank=-1,
                 rescale_WH=False,
                 bias=False):
        """Allocate placeholder buffers; real values come from a checkpoint.

        Args:
            in_features: input dimension.
            out_features: output dimension.
            codesz: codebook vector length (columns are grouped by this).
            packsz: packing factor for the index tensor (1 means unpacked).
            pack_out: if True, packing is applied along the output dimension
                instead of the input dimension.
            idx_dtype: string name of the index dtype (see ``dtype_from_str``).
            codebook_version: version stamp checked against the codebook
                implementation at first forward.
            outlier_channel_split: register ``ocs_dupe_inds`` for input
                channel duplication.
            rank: if > 0, register low-rank factors ``A`` and ``B``.
            rescale_WH: register a per-input-channel scale ``scaleWH``.
            bias: register a bias buffer.
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.outlier_channel_split = outlier_channel_split
        self.rank = rank
        self.rescale_WH = rescale_WH

        self.has_bias = bias
        if self.has_bias:
            # Placeholder values; overwritten when the checkpoint is loaded.
            self.register_buffer('bias', torch.ones(out_features))
        else:
            # Mirror the A/B/scaleWH convention: the attribute always exists,
            # so callers can read `.bias` without an AttributeError.
            self.bias = None

        if self.outlier_channel_split:
            # Index map that duplicates outlier input channels.
            self.register_buffer('ocs_dupe_inds', torch.arange(in_features))

        if self.rank > 0:
            self.register_buffer('A', torch.zeros(out_features, rank))
            self.register_buffer('B', torch.zeros(rank, in_features))
        else:
            self.A = None
            self.B = None

        if self.rescale_WH:
            self.register_buffer("scaleWH", torch.ones(in_features))
        else:
            self.scaleWH = None

        # Packed index tensor. The packing factor divides either the output
        # rows (pack_out=True) or the per-row index count (pack_out=False).
        if pack_out:
            self.register_buffer(
                "Qidxs",
                torch.zeros(out_features // packsz,
                            in_features // codesz,
                            dtype=dtype_from_str(idx_dtype)))
        else:
            self.register_buffer(
                "Qidxs",
                torch.zeros(out_features,
                            in_features // (codesz * packsz),
                            dtype=dtype_from_str(idx_dtype)))

        self.register_buffer("codebook_id", torch.tensor(0))
        self.register_buffer("SU", torch.ones(in_features))
        self.register_buffer("SV", torch.ones(out_features))
        self.register_buffer("Wscale", torch.ones(()))

        self.built_codebook_class = False
        self.built_graph = False
        self.codebook_version = codebook_version

        # Hadamard transform factors for the two feature dimensions.
        # Non-persistent: cheap to recompute, so they are excluded from the
        # state dict.
        had_left, K_left = get_hadK(in_features)
        had_right, K_right = get_hadK(out_features)
        self.register_buffer('had_left', had_left, persistent=False)
        self.register_buffer('had_right', had_right, persistent=False)
        self.K_left = K_left
        self.K_right = K_right
        self.packed = (packsz != 1)

    def forward(self, input):
        """Apply the quantized linear transform to ``input``.

        On first call, resolves the codebook class from the loaded
        ``codebook_id`` buffer and validates its version against the saved
        weights; raises if they disagree.
        """
        if not self.built_codebook_class:
            # Lazy: codebook_id is only known after the checkpoint is loaded,
            # and the class needs the device Qidxs ended up on.
            self.codebook_class = codebook.get_quantized_class(self.codebook_id.item())(
                self.Qidxs.device)
            if self.codebook_class.codebook.version != self.codebook_version:
                raise Exception(
                    f"Saved weights version ({self.codebook_version}) does not match the "\
                    f"codebook version ({self.codebook_class.codebook.version}). "\
                    "Please download the latest weights from https://huggingface.co/relaxml")
            self.built_codebook_class = True

        if self.outlier_channel_split:
            # Duplicate outlier channels before the quantized matmul.
            input = input[..., self.ocs_dupe_inds]

        result = self.codebook_class(input,
                                     self.Qidxs,
                                     self.SU,
                                     self.SV,
                                     self.Wscale,
                                     self.had_left,
                                     self.had_right,
                                     self.K_left,
                                     self.K_right,
                                     rank=self.rank,
                                     A=self.A,
                                     B=self.B,
                                     rescale_WH=self.rescale_WH,
                                     scaleWH=self.scaleWH,
                                     packed=self.packed)
        if self.has_bias:
            return result + self.bias
        return result
| |
|
|