File size: 4,210 Bytes
ca700c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Modified from https://github.com/bitsandbytes-foundation/bitsandbytes/blob/888788d75db8ff8e8888838307119f98d1235c24/bitsandbytes/nn/modules.py#L377
# TODO: support IPEX

import re
import warnings
from typing import Any, Optional

import torch
from bitsandbytes.functional import dequantize_4bit
from bitsandbytes.nn.modules import Params4bit, fix_4bit_weight_quant_state_from_module
from torch import nn

from ..functional import moe_fused_linear
from ..moe_fused_linear import MoeFusedLinear


# TODO: Fuse this
def moe_fused_linear_4bit(input: torch.Tensor, weight: Params4bit, m_sizes: torch.Tensor) -> torch.Tensor:
    """Run the grouped MoE linear on a bitsandbytes 4-bit weight.

    The packed weight is dequantized with its own ``quant_state`` into a dense tensor.
    That tensor is cast to ``input.dtype`` and handed to ``moe_fused_linear``.
    The grouped GEMM kernels use a float32 accumulator, which is why the weight is
    simply cast to the input dtype here.

    Args:
        input: Activations, forwarded unchanged to ``moe_fused_linear``.
        weight: Quantized weight carrying a ``quant_state``. It must not require grad.
        m_sizes: Forwarded unchanged to ``moe_fused_linear``. Presumably these are
            per-expert row counts (TODO confirm).

    Returns:
        The result of ``moe_fused_linear`` applied to the dequantized weight.
    """
    # Only frozen (non-trainable) quantized weights are accepted on this path.
    assert not weight.requires_grad
    dense_weight = dequantize_4bit(weight, weight.quant_state)
    dense_weight = dense_weight.to(input.dtype)
    return moe_fused_linear(input, dense_weight, m_sizes)


class MoeFusedLinear4bit(MoeFusedLinear):
    """bitsandbytes 4-bit variant of ``MoeFusedLinear``.

    Adapted from ``bitsandbytes.nn.Linear4bit``. The expert weight tensor is wrapped in
    ``Params4bit``. In ``forward`` it is dequantized and passed to the grouped GEMM
    through ``moe_fused_linear_4bit``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        *,
        weight: Optional[nn.Parameter] = None,  # Used for initializing from a non-quantized module
        compute_dtype: Optional[torch.dtype] = None,
        compress_statistics: bool = True,
        quant_type: str = "fp4",
        quant_storage: torch.dtype = torch.uint8,
        device: Optional[torch.device] = None,
    ) -> None:
        """Build the layer and wrap its weight in ``Params4bit``.

        Args:
            in_features, out_features, num_experts, device: Forwarded to ``MoeFusedLinear``.
            weight: Optional unquantized weight, e.g. taken from a non-quantized module.
                When given, it is wrapped instead of the weight allocated by the parent.
            compute_dtype: Dtype the input is cast to in ``forward``. If None, it may be
                adopted from the first input (see ``set_compute_type``).
            compress_statistics, quant_type, quant_storage: Forwarded to ``Params4bit``.
        """
        super().__init__(in_features, out_features, num_experts, device=device)
        # ``weight`` used to be accepted and then silently dropped, so a module built
        # "from a non-quantized module" kept the parent's freshly allocated weight.
        # NOTE(review): assumes ``weight`` has the same layout as MoeFusedLinear.weight.
        source_weight = self.weight if weight is None else weight
        self.weight = Params4bit(
            source_weight,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            quant_storage=quant_storage,
            # Presumably lets bitsandbytes link the parameter back to this module.
            module=self,
        )
        # TODO: consider persistent buffers as a way to save the quant state.
        self.compute_dtype = compute_dtype
        # With an explicit compute_dtype, forward never calls set_compute_type.
        self.compute_type_is_set = compute_dtype is not None
        # Presumably read by bitsandbytes' fix_4bit_weight_quant_state_from_module to
        # restore weight.quant_state if it gets lost. TODO confirm.
        self.quant_state = None
        self.quant_storage = quant_storage

    def set_compute_type(self, x: torch.Tensor) -> None:
        """Choose ``self.compute_dtype`` from the first input seen.

        float32 and bfloat16 inputs are adopted directly as the compute dtype. For
        float16 inputs the configured compute dtype is kept. A warning is emitted if
        that dtype is None or float32, because this combination is slow.
        """
        if x.dtype in (torch.float32, torch.bfloat16):
            # The input is in a dtype that is safe to compute in, so switch to it
            # for speed and stability.
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16 and self.compute_dtype in (None, torch.float32):
            if x.numel() == x.shape[-1]:
                # Single-row input (single-batch inference) that could run fast in fp16.
                message = (
                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). "
                    "This will lead to slow inference."
                )
            else:
                message = (
                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). "
                    "This will lead to slow inference or training speed."
                )
            warnings.warn(message)
            # Warn only once per message. The filter is process-wide, so it matches this
            # exact text rather than the loose ".*inference" patterns used before. Those
            # patterns also silenced unrelated warnings containing the word "inference".
            warnings.filterwarnings("ignore", message=re.escape(message))

    def _save_to_state_dict(self, destination: dict[str, Any], prefix: str, keep_vars: bool) -> None:
        """Save the module state, then append the packed 4-bit quant_state entries.

        Each entry of ``weight.quant_state.as_dict(packed=True)`` is stored under
        ``prefix + "weight." + key``. It is detached unless ``keep_vars`` is set.
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)

        if getattr(self.weight, "quant_state", None) is not None:
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()

    def forward(self, x: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
        """Apply the 4-bit grouped linear to ``x``.

        Args:
            x: Input activations.
            m_sizes: Forwarded unchanged to ``moe_fused_linear_4bit``.

        Returns:
            The layer output, cast back to ``x``'s original dtype.
        """
        # bitsandbytes helper. Presumably it repairs weight.quant_state from
        # self.quant_state if the parameter lost it. TODO confirm.
        fix_4bit_weight_quant_state_from_module(self)

        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        x = moe_fused_linear_4bit(x, self.weight, m_sizes)
        x = x.to(inp_dtype)
        return x