| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from dataclasses import dataclass |
| | from typing import Literal, Optional |
| |
|
try:
    from torchao import quantize_
    from torchao.quantization import int4_weight_only
except ImportError as _torchao_err:
    # Keep the real import failure: torchao may be installed but not export
    # these names (e.g. a version mismatch), in which case "not installed"
    # alone would be misleading. `_torchao_err` is unbound after this block,
    # so rebind it to a module-level name the fallbacks can chain from.
    _TORCHAO_IMPORT_ERROR = _torchao_err

    def quantize_(model, *args, **kwargs):
        """Stand-in for ``torchao.quantize_`` used when torchao cannot be imported.

        Accepts any arguments so callers always get the install hint rather
        than a TypeError from a signature mismatch.

        Raises:
            ImportError: always, chained to the original import failure.
        """
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        ) from _TORCHAO_IMPORT_ERROR

    def int4_weight_only(*args, **kwargs):
        """Stand-in for ``torchao.quantization.int4_weight_only`` when torchao is unavailable.

        Accepts any arguments (torchao's real signature has several optional
        ones) so the install hint is raised regardless of how it is called.

        Raises:
            ImportError: always, chained to the original import failure.
        """
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        ) from _TORCHAO_IMPORT_ERROR
| |
|
| |
|
def gelu_approx(x):
    """Apply the GELU activation using its tanh approximation.

    A thin wrapper over ``F.gelu(..., approximate="tanh")`` so every call site
    shares a single definition of the activation.
    """
    approximation = "tanh"
    return F.gelu(x, approximate=approximation)
| |
|
| |
|
@dataclass
class LinearWeights:
    """Parameters of an affine projection ``x @ weight.T + bias``.

    Both fields are handed unchanged to ``F.linear`` by ``linear`` in this module.
    """

    # Projection matrix; F.linear expects shape (out_features, in_features).
    weight: torch.Tensor
    # Additive bias; F.linear expects shape (out_features,).
    bias: torch.Tensor
| |
|
| |
|
def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    """Apply the affine projection held in ``w`` to ``x``.

    Equivalent to ``x @ w.weight.T + w.bias``, computed by ``F.linear``.
    """
    weight, bias = w.weight, w.bias
    return F.linear(x, weight, bias)
| |
|
| |
|
def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    """Unpack two 4-bit values per element of ``W_q`` and dequantize them.

    The high nibble of ``W_q[i, j]`` becomes row ``i`` of the first half of
    the unpacked matrix, and the low nibble becomes row ``i`` of the second
    half. The result is then rescaled in place as ``(value - zero) * scale``,
    with ``zero``/``scale`` broadcast against the unpacked rows.

    Args:
        W_q: 2-D packed tensor (presumably uint8 — confirm with the producer).
        scale: Per-row scale, broadcastable to ``(2 * W_q.shape[0], W_q.shape[1])``.
        zero: Per-row zero point, same broadcasting as ``scale``.
        orig_shape: Shape the dequantized tensor is reshaped to.
        dtype: Dtype of the returned tensor.

    Returns:
        The dequantized tensor, reshaped to ``orig_shape``.
    """
    n_rows = W_q.shape[0]
    high_nibbles = (W_q & 0b11110000) >> 4
    low_nibbles = W_q & 0b00001111
    result = torch.empty([2 * n_rows, W_q.shape[1]], dtype=dtype, device=W_q.device)
    result[:n_rows], result[n_rows:] = high_nibbles, low_nibbles
    return result.sub_(zero).mul_(scale).reshape(orig_shape)
| |
|
| |
|
class QuantizedLinear(nn.Module):
    """Linear layer stored as packed 4-bit weights, re-quantized lazily with torchao.

    Until the first ``forward`` call the module holds the packed checkpoint
    format in ``weight`` (a ParameterDict with ``packed``, ``scale`` and
    ``zero_point``) plus ``bias``. ``unpack`` dequantizes this to bfloat16,
    wraps it in an ``nn.Linear`` stored as ``self.linear``, applies torchao's
    int4 weight-only quantization (group size 128), and only then removes
    ``weight``/``bias`` from the module.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        dtype: Accepted by the constructor but currently unused; the
            dequantized weight and bias are always bfloat16.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dtype: torch.dtype,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): assumes out_features * in_features is a multiple of 256;
        # otherwise the floor divisions below silently drop elements and the
        # reshape performed during unpack() fails.
        n_elements = out_features * in_features
        self.weight = nn.ParameterDict(
            {
                # Two 4-bit values per byte, stored as rows of 128 bytes.
                "packed": nn.Parameter(
                    torch.empty(n_elements // (128 * 2), 128, dtype=torch.uint8),
                    requires_grad=False,
                ),
                # One scale / zero point per 128 dequantized values.
                "scale": nn.Parameter(
                    torch.empty(n_elements // 128, 1),
                    requires_grad=False,
                ),
                "zero_point": nn.Parameter(
                    torch.empty(n_elements // 128, 1),
                    requires_grad=False,
                ),
            }
        )
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
        # Becomes True once unpack() has replaced the packed weights.
        self.unpacked = False

    def unpack(self):
        """Dequantize the packed weights and re-quantize them with torchao.

        Does nothing if already unpacked. Every fallible step (building the
        torchao config, dequantizing, quantizing) runs before the packed
        ``weight``/``bias`` are removed. So if torchao is unavailable or
        quantization fails, the module keeps its packed weights and
        ``unpack`` can be retried.
        """
        if self.unpacked:
            return

        # Build the quantization config first: when torchao is missing, this is
        # where the ImportError surfaces, before any module state is touched.
        config = int4_weight_only(group_size=128)

        # Frozen like every other parameter of this module.
        weight = nn.Parameter(
            dequantize_tensor(
                self.weight["packed"],
                self.weight["scale"],
                self.weight["zero_point"],
                (self.out_features, self.in_features),
                torch.bfloat16,
            ),
            requires_grad=False,
        )
        # Construct on the meta device to avoid allocating throwaway weights;
        # both parameters are replaced immediately below.
        with torch.device("meta"):
            linear = nn.Linear(
                self.in_features, self.out_features, dtype=torch.bfloat16
            )
        linear.weight = weight
        linear.bias = nn.Parameter(self.bias.to(torch.bfloat16), requires_grad=False)
        self.linear = linear

        quantize_(self, config)

        # Quantization succeeded: drop the packed representation.
        del self.weight, self.bias
        self.unpacked = True
        torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer, unpacking the packed weights on the first call."""
        if not self.unpacked:
            self.unpack()
        return self.linear(x)
| |
|
| |
|
@dataclass
class LayerNormWeights:
    """Elementwise affine parameters applied after layer normalization."""

    # Scale (gamma) multiplied into the normalized activations.
    weight: torch.Tensor
    # Shift (beta); layer_norm() in this module also uses its shape as the normalized shape.
    bias: torch.Tensor
| |
|
| |
|
def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    """Layer-normalize ``x`` over its trailing dims using the affine parameters in ``w``.

    The normalized shape is taken from ``w.bias.shape``; ``F.layer_norm``'s
    default epsilon is used.
    """
    normalized_shape = w.bias.shape
    return F.layer_norm(x, normalized_shape, weight=w.weight, bias=w.bias)
| |
|
| |
|
@dataclass
class MLPWeights:
    """Parameters of the two-projection feed-forward block.

    NOTE(review): ``fc1``/``fc2`` are annotated as LinearWeights; confirm
    whether callers pass these dataclasses or callable modules such as
    ``nn.Linear``.
    """

    # First projection, applied to the block input before the activation.
    fc1: LinearWeights
    # Second projection, applied after the activation.
    fc2: LinearWeights
    # Activation name; the annotation only permits "gelu_approx".
    # NOTE(review): mlp() in this module does not read this field.
    act: Literal["gelu_approx"] = "gelu_approx"
| |
|
| |
|
def _mlp_linear(x: torch.Tensor, layer) -> torch.Tensor:
    """Apply one MLP projection: call ``layer`` if callable, else ``F.linear`` with its weight/bias."""
    if callable(layer):
        return layer(x)
    return F.linear(x, layer.weight, layer.bias)


def _lora_update(x: torch.Tensor, adapter: dict) -> torch.Tensor:
    """Return the low-rank LoRA term ``B(A(x))`` for one adapter mapping."""
    return F.linear(F.linear(x, adapter["A"]), adapter["B"])


def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
    """Two-layer feed-forward block: fc1 -> tanh-approximate GELU -> fc2.

    ``w.fc1``/``w.fc2`` may be callable modules (e.g. ``nn.Linear``), which
    are invoked directly, or plain containers exposing ``weight`` and ``bias``
    (e.g. ``LinearWeights``), which are applied with ``F.linear``. Previously
    only callables worked, despite the ``LinearWeights`` annotation.

    Args:
        x: Input activations.
        w: Projection weights for the block.
        lora: Optional LoRA adapters keyed ``"fc1"``/``"fc2"``. Each is a
            mapping with matrices ``"A"`` (applied first) and ``"B"``. The
            adapter output is added to the matching projection's output.

    Returns:
        Output of the second projection (plus its LoRA term, if given).
    """
    hidden = _mlp_linear(x, w.fc1)
    if lora is not None:
        # The LoRA term sees the same input as fc1, not fc1's output.
        hidden = hidden + _lora_update(x, lora["fc1"])

    hidden = gelu_approx(hidden)

    out = _mlp_linear(hidden, w.fc2)
    if lora is not None:
        out = out + _lora_update(hidden, lora["fc2"])

    return out
| |
|
| |
|
@dataclass
class AttentionWeights:
    """Parameters for multi-head self-attention as consumed by ``attn``."""

    # Fused query/key/value projection; attn() splits its output into three equal chunks.
    qkv: LinearWeights
    # Output projection applied after the heads are merged back together.
    proj: LinearWeights
| |
|
| |
|
def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    """Unmasked multi-head self-attention.

    Args:
        x: Input unpacked below as (batch, seq_len, d_model).
        w: Fused QKV projection and output projection.
        n_heads: Number of heads; each gets ``d_model // n_heads`` channels.

    Returns:
        The merged attention output passed through ``w.proj``.
    """
    batch, seq_len, width = x.shape
    head_dim = width // n_heads

    def to_heads(t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, width) -> (batch, heads, seq, head_dim)
        return t.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

    fused = linear(x, w.qkv)
    q, k, v = map(to_heads, fused.chunk(3, dim=-1))
    # No attn_mask / is_causal: every position attends to every other.
    attended = F.scaled_dot_product_attention(q, k, v)
    # (batch, heads, seq, head_dim) -> (batch, seq, width)
    merged = attended.transpose(1, 2).reshape(batch, seq_len, width)
    return linear(merged, w.proj)
| |
|