# Modified from https://github.com/bitsandbytes-foundation/bitsandbytes/blob/888788d75db8ff8e8888838307119f98d1235c24/bitsandbytes/nn/modules.py#L377
# TODO: support IPEX

import warnings
from typing import Any, Optional

import torch
from bitsandbytes.functional import dequantize_4bit
from bitsandbytes.nn.modules import Params4bit, fix_4bit_weight_quant_state_from_module
from torch import nn

from ..functional import moe_fused_linear
from ..moe_fused_linear import MoeFusedLinear


# TODO: Fuse this
def moe_fused_linear_4bit(input: torch.Tensor, weight: Params4bit, m_sizes: torch.Tensor) -> torch.Tensor:
    assert not weight.requires_grad
    # Dequantize and cast the weight to input.dtype
    # The grouped GEMM kernels use a float32 accumulator
    weight = dequantize_4bit(weight, weight.quant_state).to(input.dtype)
    return moe_fused_linear(input, weight, m_sizes)


class MoeFusedLinear4bit(MoeFusedLinear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        *,
        weight: Optional[nn.Parameter] = None,  # Used for initializing from a non-quantized module
        compute_dtype: Optional[torch.dtype] = None,
        compress_statistics: bool = True,
        quant_type: str = "fp4",
        quant_storage: torch.dtype = torch.uint8,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(in_features, out_features, num_experts, device=device)
        # Wrap the weight in Params4bit; if a non-quantized weight is provided, use it
        # instead of the weight initialized by the parent class
        self.weight = Params4bit(
            self.weight if weight is None else weight,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            quant_storage=quant_storage,
            module=self,
        )
        # self.persistent_buffers = []  # TODO: consider as a way to save the quant state
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = compute_dtype is not None
        self.quant_state = None
        self.quant_storage = quant_storage

    def set_compute_type(self, x: torch.Tensor) -> None:
        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in; switch
            # to this dtype for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compute dtype passed into the layer
            if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
                # single-batch inference with input torch.float16 and compute_dtype float32
                # -> slow inference when it could be fast; warn the user about this
                warnings.warn(
                    "Input type into MoeFusedLinear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). "
                    "This will lead to slow inference.",
                )
                warnings.filterwarnings("ignore", message=".*inference.")
            if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
                warnings.warn(
                    "Input type into MoeFusedLinear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). "
                    "This will lead to slow inference or training speed.",
                )
                warnings.filterwarnings("ignore", message=".*inference or training")

    def _save_to_state_dict(self, destination: dict[str, Any], prefix: str, keep_vars: bool) -> None:
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # Save the quantization state items alongside the packed weight
        if getattr(self.weight, "quant_state", None) is not None:
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()

    def forward(self, x: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
        # Restore the weight's quant_state from the module if it was lost
        # (e.g. after loading from a state dict)
        fix_4bit_weight_quant_state_from_module(self)

        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        x = moe_fused_linear_4bit(x, self.weight, m_sizes)

        # Cast the output back to the original input dtype
        x = x.to(inp_dtype)

        return x
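

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how this layer might be used, assuming a CUDA device and the
# project's grouped-GEMM kernel are available. The shapes, the "nf4" quant type, and
# the interpretation of `m_sizes` as per-expert token counts (with tokens grouped
# contiguously by expert) are assumptions for illustration, not guarantees of this
# module's API beyond the constructor and forward signatures defined above.
if __name__ == "__main__":
    num_experts, in_features, out_features = 4, 64, 128
    layer = MoeFusedLinear4bit(
        in_features,
        out_features,
        num_experts,
        compute_dtype=torch.bfloat16,
        quant_type="nf4",
    )
    # Following the usual bitsandbytes Params4bit flow, quantization is expected to
    # happen when the weight is moved from CPU to GPU
    layer = layer.cuda()

    # 16 tokens total, assumed grouped contiguously by expert: 4 tokens per expert
    x = torch.randn(16, in_features, dtype=torch.bfloat16, device="cuda")
    m_sizes = torch.tensor([4, 4, 4, 4], device="cuda")
    out = layer(x, m_sizes)
    print(out.shape)  # expected: (16, out_features) under the assumptions above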