| import torch |
| import torch.nn as nn |
| import quiptools_cuda |
| from lib.utils import dtype_from_str, get_hadK |
| from lib import codebook |
| from .quantized_linear import QuantizedLinear |
| import time |
|
|
|
|
class FusedQuantizedLinear(QuantizedLinear):
    """Quantized linear layer computing several fused projections in one matmul.

    Multiple logical linear layers (e.g. attention q/k/v) are fused into a
    single ``QuantizedLinear``; the fused output is split back into the
    individual projections along ``fuse_dim`` with sizes ``fuse_sizes``, and
    each piece is multiplied by its per-projection scale from the
    ``fuse_scales`` buffer.
    """

    def __init__(self, fuse_dim, fuse_sizes, *QL_args, **QL_kwargs):
        """
        Args:
            fuse_dim: dimension along which the fused output is split.
            fuse_sizes: sequence of split sizes; must sum to the fused
                output's extent along ``fuse_dim``.
            *QL_args: positional args forwarded to ``QuantizedLinear``.
            **QL_kwargs: keyword args forwarded to ``QuantizedLinear``.
        """
        super().__init__(*QL_args, **QL_kwargs)
        self.fuse_dim = fuse_dim
        self.fuse_sizes = fuse_sizes
        # One scalar multiplier per fused sub-output. Registered as a buffer
        # so it is saved/loaded with the state dict and follows .to()/.cuda().
        self.register_buffer('fuse_scales', torch.ones(len(self.fuse_sizes)))
        self.n = len(self.fuse_sizes)

    def forward(self, input):
        """Run the fused quantized matmul, then split and rescale the output.

        Returns:
            Tuple of ``len(fuse_sizes)`` tensors; the i-th entry is the i-th
            split of the fused output scaled by ``fuse_scales[i]``.
        """
        fused_output = super().forward(input)
        split_outputs = torch.split(fused_output, self.fuse_sizes, self.fuse_dim)
        # Pair each split with its scale directly rather than indexing by
        # position (iterating the 1-D fuse_scales buffer yields scalar tensors).
        return tuple(out * scale
                     for out, scale in zip(split_outputs, self.fuse_scales))
|
|