import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Literal, Optional
try:
from torchao import quantize_
from torchao.quantization import int4_weight_only
except ImportError:
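    # Stub out the torchao entry points so the module still imports; the error
    # is only raised if 4-bit quantization is actually requested.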
def quantize_(model, quant_mode):
raise ImportError(
"torchao is not installed. Please install it with `pip install torchao`."
)
def int4_weight_only(group_size):
raise ImportError(
"torchao is not installed. Please install it with `pip install torchao`."
)
def gelu_approx(x):
return F.gelu(x, approximate="tanh")
@dataclass
class LinearWeights:
weight: torch.Tensor
bias: torch.Tensor
def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
return F.linear(x, w.weight, w.bias)
def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
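    """Dequantize weights packed two int4 values per uint8 byte.

    The first half of the output rows comes from the high nibble of ``W_q`` and
    the second half from the low nibble; the affine transform
    ``(q - zero) * scale`` is then applied and the result reshaped to
    ``orig_shape``.
    """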
_step = W_q.shape[0]
W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
W_r[:_step] = (W_q & 0b11110000) >> 4
W_r[_step:] = W_q & 0b00001111
W_r.sub_(zero).mul_(scale)
return W_r.reshape(orig_shape)
class QuantizedLinear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
dtype: torch.dtype,
):
# TODO: Take group_size as an input instead of hardcoding it here.
super().__init__()
self.in_features = in_features
self.out_features = out_features
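        # Packed int4 storage: each uint8 holds two 4-bit values, and every
        # group of 128 weights shares one scale and one zero point
        # (the hardcoded group size referenced in the TODO above).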
self.weight = nn.ParameterDict(
{
"packed": nn.Parameter(
torch.empty(
out_features * in_features // (128 * 2), 128, dtype=torch.uint8
),
requires_grad=False,
),
"scale": nn.Parameter(
torch.empty(out_features * in_features // 128, 1),
requires_grad=False,
),
"zero_point": nn.Parameter(
torch.empty(out_features * in_features // 128, 1),
requires_grad=False,
),
}
)
self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
self.unpacked = False
def unpack(self):
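        """Convert the packed int4 weights into a torchao-quantized linear.

        The packed tensor is dequantized to bfloat16 once, wrapped in an
        ``nn.Linear``, and then re-quantized in place with
        ``int4_weight_only(group_size=128)``; subsequent calls are no-ops.
        """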
if self.unpacked:
return
self.weight = nn.Parameter(
dequantize_tensor(
self.weight["packed"],
self.weight["scale"],
self.weight["zero_point"],
(self.out_features, self.in_features),
torch.bfloat16,
)
)
with torch.device("meta"):
self.linear = nn.Linear(
self.in_features, self.out_features, dtype=torch.bfloat16
)
self.linear.weight = self.weight
self.linear.bias = nn.Parameter(
self.bias.to(torch.bfloat16), requires_grad=False
)
del self.weight, self.bias
quantize_(self, int4_weight_only(group_size=128))
self.unpacked = True
torch.cuda.empty_cache()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.unpacked:
self.unpack()
return self.linear(x)
@dataclass
class LayerNormWeights:
weight: torch.Tensor
bias: torch.Tensor
def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
@dataclass
class MLPWeights:
fc1: LinearWeights
fc2: LinearWeights
act: Literal["gelu_approx"] = "gelu_approx"
def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
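    """Feed-forward block with an optional LoRA path.

    ``w.fc1`` and ``w.fc2`` are invoked as callables (e.g. ``nn.Linear`` or
    ``QuantizedLinear``). When ``lora`` is given, it is expected to hold
    ``lora["fc1"]["A"]``/``["B"]`` and ``lora["fc2"]["A"]``/``["B"]``; each
    low-rank update ``B(A(x))`` is added to the corresponding base output.
    """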
x0 = w.fc1(x)
if lora is not None:
x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
x = x0 + x1
else:
x = x0
x = gelu_approx(x)
x0 = w.fc2(x)
if lora is not None:
x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
x = x0 + x1
else:
x = x0
return x
def moe_mlp(
x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
) -> torch.Tensor:
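    """Mixture-of-experts feed-forward pass.

    Each token is routed to ``experts_per_token`` experts chosen by
    ``mlp_module.router``, and the expert outputs are combined with the
    softmax-normalized router weights. Two execution paths are used: a batched
    bmm path for single-token decoding (``T == 1``) and a per-expert loop for
    longer sequences.
    """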
B, T, C = x.shape
x = x.reshape(-1, C)
# Router computation
router_logits = mlp_module.router(x)
topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
num_tokens, top_k = topk_idxs.shape
if T == 1:
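        # Decode path: gather the selected experts' weight matrices and run all
        # token-expert pairs through a single batched matmul.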
w1_weight = mlp_module.fc1.weight
w2_weight = mlp_module.fc2.weight
# Flatten to process all token-expert pairs at once
flat_idxs = topk_idxs.view(-1) # [T*A]
flat_weights = topk_weights.view(-1) # [T*A]
# Select expert weights
w1_selected = w1_weight[flat_idxs] # [T*A, H, D]
w2_selected = w2_weight[flat_idxs] # [T*A, D, H]
# Expand input for all token-expert pairs
x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C) # [T*A, D]
# First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
-1
) # [T*A, H]
x1, g = x1_full.chunk(2, dim=-1)
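        # Gated activation: fc1 produces value and gate halves; the output is
        # gelu(value) scaled by (gate + 1).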
x1 = F.gelu(x1) * (g + 1)
# Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1) # [T*A, D]
# Apply weights and reshape
weighted_outs = expert_outs * flat_weights.unsqueeze(-1) # [T*A, D]
weighted_outs = weighted_outs.view(num_tokens, top_k, C) # [T, A, D]
# Sum over experts
mlp_out = weighted_outs.sum(dim=1) # [T, D]
mlp_out = mlp_out.view(B, T, C)
return mlp_out
else:
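        # Prefill path: loop over experts, gather the tokens routed to each,
        # apply the gated MLP, and accumulate weighted outputs with index_add_.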
out = x.new_zeros(x.size())
for expert_id in range(mlp_module.fc1.weight.shape[0]):
token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
if token_pos.numel() == 0:
continue
x_tok = x.index_select(0, token_pos)
gate_tok = topk_weights[token_pos, which_k]
h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
h, g = h_full.chunk(2, dim=-1)
h = F.gelu(h) * (g + 1)
y = F.linear(h, mlp_module.fc2.weight[expert_id])
y.mul_(gate_tok.unsqueeze(-1))
out.index_add_(0, token_pos, y)
return out.view(B, T, C)
@dataclass
class AttentionWeights:
qkv: LinearWeights
proj: LinearWeights
def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
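    """Multi-head self-attention with a fused QKV projection.

    No attention mask is applied here; ``F.scaled_dot_product_attention`` is
    called with its defaults.
    """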
bsz, q_len, d_model = x.shape
head_dim = d_model // n_heads
q, k, v = [
t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
for t in linear(x, w.qkv).chunk(3, dim=-1)
]
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
out = linear(out, w.proj)
return out
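

# Minimal smoke test (illustrative sketch only, not part of the library API):
# exercises attn() with randomly initialized weights. The model width, head
# count, and sequence length below are arbitrary assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    d_model, n_heads = 64, 4
    x = torch.randn(2, 8, d_model)
    w = AttentionWeights(
        qkv=LinearWeights(
            weight=torch.randn(3 * d_model, d_model) * 0.02,
            bias=torch.zeros(3 * d_model),
        ),
        proj=LinearWeights(
            weight=torch.randn(d_model, d_model) * 0.02,
            bias=torch.zeros(d_model),
        ),
    )
    out = attn(x, w, n_heads)
    print(out.shape)  # expected: torch.Size([2, 8, 64])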