File size: 833 Bytes
b3c0032 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
import quiptools_cuda
from lib.utils import dtype_from_str, get_hadK
from lib import codebook
from .quantized_linear import QuantizedLinear
import time
class FusedQuantizedLinear(QuantizedLinear):
def __init__(self, fuse_dim, fuse_sizes, *QL_args, **QL_kwargs):
super(FusedQuantizedLinear, self).__init__(*QL_args, **QL_kwargs)
self.fuse_dim = fuse_dim
self.fuse_sizes = fuse_sizes
self.register_buffer('fuse_scales', torch.ones(len(self.fuse_sizes)))
self.n = len(self.fuse_sizes)
def forward(self, input):
fused_output = super(FusedQuantizedLinear, self).forward(input)
split_outputs = torch.split(fused_output, self.fuse_sizes, self.fuse_dim)
return tuple(split_outputs[i] * self.fuse_scales[i] for i in range(self.n))
|