text-generation-webui/venv/lib/python3.10/site-packages/bitsandbytes/nn/triton_based_modules.py
import torch
import torch.nn as nn
import time

from functools import partial

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze


class _switchback_global(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_global(W)

        # save for backward.
        ctx.save_for_backward = X, W

        # matmul, fused dequant and add bias
        # call "mixed" because we are mixing rowwise quantized and global quantized
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape input to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        X, W = ctx.save_for_backward
        if ctx.needs_input_grad[0]:
            # rowwise quantize for G, global quantize for W
            # for W, we also fuse the transpose operation because only A @ B^T is supported
            # so we transpose once then call .t() in the matmul
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias
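
# A minimal sketch of how the autograd function above is driven (assumes a CUDA
# device with Triton installed; shapes, dtypes, and values are illustrative only):
#
#   X = torch.randn(8, 16, 64, device='cuda', dtype=torch.float16, requires_grad=True)  # [N, L, D]
#   W = torch.randn(128, 64, device='cuda', dtype=torch.float16, requires_grad=True)    # [D_out, D]
#   b = torch.randn(128, device='cuda', dtype=torch.float16, requires_grad=True)
#   Y = _switchback_global.apply(X, W, b)   # [N, L, D_out]
#   Y.sum().backward()                      # populates X.grad, W.grad, b.grad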


class _switchback_vectorrize(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        ctx.save_for_backward = X, W

        # rowwise quantize for X
        # columnwise quantize for W (first rowwise, transpose later)
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_rowwise(W)

        # matmul, fused dequant and add bias
        # call kernel which expects rowwise quantized X and W
        return int8_matmul_rowwise_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.save_for_backward
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        if ctx.needs_input_grad[0]:
            # rowwise quantize for G, columnwise quantize for W and fused transpose
            # we call .t() for weight later because only A @ B^T is supported
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_columnwise_and_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # backward pass uses standard weight grad
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias


class _switchback_global_mem_efficient(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))
        X_3D_sz = X_3D.size()

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        del X
        W_int8, state_W = quantize_global(W)

        # save for backward.
        ctx.save_for_backward = X_int8, state_X, W_int8, state_W

        # matmul, fused dequant and add bias
        # call "mixed" because we are mixing rowwise quantized and global quantized
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D_sz[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape input to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))
        G_3D_sz = G_3D.size()

        grad_X = grad_W = grad_bias = None

        X_int8, state_X, W_int8, state_W = ctx.save_for_backward
        if ctx.needs_input_grad[1]:
            real_X = dequantize_rowwise(X_int8, state_X)
            del X_int8
            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
            del real_X
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        if ctx.needs_input_grad[0]:
            G_int8, state_G = quantize_rowwise(G)
            del G
            W_int8 = W_int8.t().contiguous()
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D_sz[:-1], -1
            )

        return grad_X, grad_W, grad_bias
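
# Note: compared to _switchback_global, this variant saves the int8-quantized
# activations (plus their per-row scale state) for backward and dequantizes them
# only when grad_W is actually needed, trading a little extra compute for a
# smaller activation-memory footprint (int8 instead of the full-precision X).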


class SwitchBackLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        vector_wise_quantization: bool = False,
        mem_efficient: bool = False,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)

        if not is_triton_available():
            raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
                              Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')

        # By default, we use the global quantization.
        self.vector_wise_quantization = vector_wise_quantization
        if self.vector_wise_quantization:
            self._fn = _switchback_vectorrize
            if mem_efficient:
                print('mem efficient is not supported for vector-wise quantization.')
                exit(1)
        else:
            if mem_efficient:
                self._fn = _switchback_global_mem_efficient
            else:
                self._fn = _switchback_global

    def prepare_for_eval(self):
        # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
        # Note this is experimental and not tested thoroughly.
        # Note this needs to be explicitly called with something like
        # def cond_prepare(m):
        #     if hasattr(m, "prepare_for_eval"):
        #         m.prepare_for_eval()
        # model.apply(cond_prepare)
        print('=> preparing for eval.')
        if self.vector_wise_quantization:
            W_int8, state_W = quantize_rowwise(self.weight)
        else:
            W_int8, state_W = quantize_global(self.weight)

        self.register_buffer("W_int8", W_int8)
        self.register_buffer("state_W", state_W)

        del self.weight

    def forward(self, x):
        if self.training:
            return self._fn.apply(x, self.weight, self.bias)
        else:
            # If it hasn't been "prepared for eval", run the standard forward pass.
            if not hasattr(self, "W_int8"):
                return self._fn.apply(x, self.weight, self.bias)

            # Otherwise, use pre-computed weights.
            X = x.view(-1, x.size(-1))
            X_int8, state_X = quantize_rowwise(X)

            if self.vector_wise_quantization:
                return int8_matmul_rowwise_dequantize(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)
            else:
                return int8_matmul_mixed_dequanitze(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)


SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
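
# The aliases above are drop-in replacements for nn.Linear. A usage sketch
# (assumes a CUDA device with Triton available; sizes are illustrative):
#
#   layer = SwitchBackLinearGlobal(1024, 4096).cuda().half()
#   out = layer(torch.randn(4, 128, 1024, device='cuda', dtype=torch.float16))
#
#   # for inference, optionally pre-quantize the weights once:
#   layer.eval()
#   layer.prepare_for_eval()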


# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))

        ctx.save_for_backward(X, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output_3D):
        input, weight, bias = ctx.saved_tensors

        grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class StandardLinear(nn.Linear):
    def forward(self, x):
        return StandardLinearFunction.apply(x, self.weight, self.bias)
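
# Quick sanity check (a sketch): StandardLinear should match nn.Linear, since
# StandardLinearFunction recomputes the same y = x @ W^T + b with a hand-written
# backward pass:
#
#   ref = nn.Linear(8, 4)
#   std = StandardLinear(8, 4)
#   std.load_state_dict(ref.state_dict())
#   x = torch.randn(2, 3, 8)
#   assert torch.allclose(std(x), ref(x), atol=1e-6)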