import torch
import torch.nn as nn
import quiptools_cuda
from lib.utils import dtype_from_str, get_hadK
from lib import codebook
from .quantized_linear import QuantizedLinear
import time


class FusedQuantizedLinear(QuantizedLinear):
    """QuantizedLinear whose output is split into separately scaled chunks.

    ``QuantizedLinear.forward`` runs once on the input. Its output is then
    split along ``fuse_dim`` into pieces of sizes ``fuse_sizes``, and each
    piece is multiplied by its own entry of the ``fuse_scales`` buffer.
    (Presumably this fuses several projections that share one input into a
    single quantized layer — NOTE(review): confirm against callers.)
    """

    def __init__(self, fuse_dim, fuse_sizes, *QL_args, **QL_kwargs):
        """Build the underlying QuantizedLinear and the per-chunk scale buffer.

        Args:
            fuse_dim: Dimension of the parent layer's output along which it
                is split in ``forward``.
            fuse_sizes: Chunk sizes handed to ``torch.split``. Their sum is
                expected to equal the output's size along ``fuse_dim``;
                ``torch.split`` raises at forward time otherwise.
            *QL_args: Positional arguments forwarded unchanged to
                ``QuantizedLinear``.
            **QL_kwargs: Keyword arguments forwarded unchanged to
                ``QuantizedLinear``.
        """
        super().__init__(*QL_args, **QL_kwargs)
        self.fuse_dim = fuse_dim
        self.fuse_sizes = fuse_sizes
        # One scale per output chunk, initialised to 1. It is a buffer (not a
        # Parameter), so it is stored in / restored from the state_dict but
        # is not returned by parameters().
        self.register_buffer('fuse_scales', torch.ones(len(self.fuse_sizes)))
        # Number of fused chunks; kept as a public attribute for existing readers.
        self.n = len(self.fuse_sizes)

    def forward(self, input):
        """Run the fused quantized layer and return its scaled output chunks.

        Args:
            input: Input forwarded unchanged to ``QuantizedLinear.forward``.

        Returns:
            A tuple with one tensor per entry of ``fuse_sizes``. Each tensor is
            the parent layer's output sliced along ``fuse_dim`` and multiplied
            by the matching element of ``fuse_scales``.
        """
        fused_output = super().forward(input)
        split_outputs = torch.split(fused_output, self.fuse_sizes, self.fuse_dim)
        # Pair each chunk with its scale directly rather than indexing both by
        # position; iterating the 1-D buffer yields the same 0-dim scalars.
        return tuple(
            chunk * scale for chunk, scale in zip(split_outputs, self.fuse_scales)
        )