"""
SSMoELM Packed Inference — 12MB-memory inference.
Keeps the packed uint8 weights in memory and unpacks them on demand in forward().

Usage:
    python inference_packed.py --prompt "Hello"
"""
| import argparse |
| import math |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from safetensors.numpy import load_file |
| from tokenizers import Tokenizer |
|
|
# ---------------------------------------------------------------------------
# Model hyper-parameters. Buffer shapes below are built from these, so they
# must match the packed checkpoint being loaded.
# ---------------------------------------------------------------------------
D_MODEL = 768      # residual-stream width
N_LAYERS = 6       # number of transformer layers
N_HEADS = 12       # query heads per attention layer
KV_HEADS = 3       # key/value heads (each shared by N_HEADS // KV_HEADS query heads)
HEAD_DIM = 64      # per-head dimension
N_EXPERTS = 8      # routed experts per MoE layer
N_ACTIVE = 2       # routed experts selected per token (top-k)
D_FF = 256         # hidden width of every SwiGLU expert
VOCAB_SIZE = 8192  # rows of the (tied) embedding / output table
CTX_LEN = 2048     # longest context; also the size of the RoPE table

# Special token ids: BOS is prepended to prompts; EOS and EOT both stop generation.
BOS_ID = 0
EOS_ID = 1
EOT_ID = 6
|
|
|
|
| |
|
|
class Linear1bit(nn.Module):
    """Linear layer whose weights are stored as packed sign bits.

    Storage per output row: one fp16 ``scale`` and ``ceil(in_f / 8)`` uint8
    bytes of sign bits, read most-significant bit first (bit 1 -> +1,
    bit 0 -> -1). The fp32 weight matrix is rebuilt on every forward call and
    dropped right after the matmul, so only the packed form stays resident.
    """

    def __init__(self, out_f: int, in_f: int):
        super().__init__()
        self.in_features = in_f
        n_bytes = (in_f + 7) // 8
        self.register_buffer("scale", torch.zeros(out_f, dtype=torch.float16))
        self.register_buffer("packed", torch.zeros(out_f, n_bytes, dtype=torch.uint8))

    def _unpack(self) -> torch.Tensor:
        """Return the dequantized ``(out_f, in_features)`` fp32 weight matrix."""
        packed = self.packed
        # Shift amounts 7..0 pull the bits out of each byte MSB-first.
        shifts = torch.arange(7, -1, -1, device=packed.device, dtype=torch.uint8)
        bit_matrix = (packed[..., None] >> shifts) & 1
        # Drop the padding bits of the last byte of each row.
        signs = bit_matrix.reshape(packed.shape[0], -1)[:, : self.in_features]
        signs = signs.float() * 2.0 - 1.0  # {0, 1} -> {-1, +1}
        return signs * self.scale.float()[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self._unpack()
        result = F.linear(x, weight)
        del weight  # drop the temporary fp32 matrix reference
        return result
|
|
|
|
class Linear4bit(nn.Module):
    """Linear layer with signed 4-bit weights packed two per byte.

    Byte ``j`` of a row holds column ``2j`` in its low nibble and column
    ``2j + 1`` in its high nibble. Nibbles carry an offset of 8, so codes
    0..15 decode to integers -8..7, which are then multiplied by the per-row
    fp16 ``scale``. The fp32 matrix is rebuilt on every forward call.
    """

    def __init__(self, out_f: int, in_f: int):
        super().__init__()
        self.in_features = in_f
        n_bytes = (in_f + 1) // 2
        self.register_buffer("scale", torch.zeros(out_f, dtype=torch.float16))
        self.register_buffer("packed", torch.zeros(out_f, n_bytes, dtype=torch.uint8))

    def _unpack(self) -> torch.Tensor:
        """Return the dequantized ``(out_f, in_features)`` fp32 weight matrix."""
        raw = self.packed
        low = (raw & 0x0F).to(torch.int8) - 8
        high = ((raw >> 4) & 0x0F).to(torch.int8) - 8
        # Interleave so columns come out in order: low0, high0, low1, high1, ...
        codes = torch.stack((low, high), dim=-1).reshape(raw.shape[0], -1)
        codes = codes[:, : self.in_features].float()
        return codes * self.scale.float()[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self._unpack()
        result = F.linear(x, weight)
        del weight  # drop the temporary fp32 matrix reference
        return result
|
|
|
|
class EmbeddingPacked(nn.Module):
    """4-bit nibble-packed embedding table with one fp16 scale per row.

    Each byte of ``packed`` stores two consecutive columns, low nibble first.
    Nibbles carry an offset of 8, so codes 0..15 decode to -8..7 before being
    multiplied by the row's ``scale``.
    """

    def __init__(self, vocab: int, d: int):
        super().__init__()
        self.vocab = vocab
        self.d = d
        self.register_buffer("scale", torch.zeros(vocab, dtype=torch.float16))
        self.register_buffer("packed", torch.zeros(vocab, (d + 1) // 2, dtype=torch.uint8))

    def _dequantize(self, packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Decode packed rows ``(..., ceil(d/2))`` with scales ``(...)`` into ``(..., d)`` fp32."""
        lo = (packed & 0x0F).to(torch.int8) - 8
        hi = ((packed >> 4) & 0x0F).to(torch.int8) - 8
        # Interleave low/high nibbles so columns come out in order.
        w = torch.stack([lo, hi], dim=-1).flatten(-2)[..., : self.d].float()
        return w * scale.float().unsqueeze(-1)

    def get_weight(self) -> torch.Tensor:
        """Return the full dequantized ``(vocab, d)`` fp32 table."""
        return self._dequantize(self.packed, self.scale)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``idx``; returns shape ``idx.shape + (d,)`` in fp32.

        Only the requested rows are dequantized. The previous implementation
        materialised the whole ``(vocab, d)`` fp32 table on every call (about
        24 MB at this file's VOCAB_SIZE x D_MODEL) just to index a few rows.
        """
        return self._dequantize(self.packed[idx], self.scale[idx])
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel gain.

    Statistics are computed in fp32 over the last dimension; the result is
    cast back to the input's dtype.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_sq = x32.square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = x32 * inv_rms
        return (self.weight * normed).to(x.dtype)
|
|
|
|
def precompute_rope(head_dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor:
    """Build the rotary-embedding angle table.

    Row ``t`` holds ``t * base ** (-k / head_dim)`` for every even ``k <
    head_dim``, and that half-row is repeated once more back-to-back. The
    result is fp32 with shape ``(max_len, 2 * ceil(head_dim / 2))``, i.e.
    ``(max_len, head_dim)`` for even ``head_dim``.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(max_len).float()
    angles = positions[:, None] * inv_freq[None, :]
    return angles.repeat(1, 2)
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map ``[a, b] -> [-b, a]`` along the last dimension.

    The split point is ``last_dim // 2``; for odd sizes the second part ``b``
    is the longer one.
    """
    size = x.shape[-1]
    half = size // 2
    first, second = x.split([half, size - half], dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the angles in ``freqs`` (rotary position embedding).

    ``freqs`` is broadcast from ``(T, D)`` to ``(1, T, 1, D)``, so ``x`` is
    expected as ``(batch, T, heads, D)``; that is the layout
    ``Attention.forward`` passes in.
    """
    angles = freqs[None, :, None, :]
    cos = angles.cos().to(x.dtype)
    sin = angles.sin().to(x.dtype)
    return x * cos + rotate_half(x) * sin
|
|
|
|
class Attention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    The q/k projections are always 4-bit. The v/o projections are 4-bit in
    layers 0 and 5 and 1-bit in every other layer.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        d = D_MODEL
        self.n_heads = N_HEADS
        self.kv_heads = KV_HEADS
        self.head_dim = HEAD_DIM
        # Number of query heads that share each key/value head.
        self.n_rep = N_HEADS // KV_HEADS

        q_width = N_HEADS * HEAD_DIM
        kv_width = KV_HEADS * HEAD_DIM
        vo_cls = Linear4bit if layer_idx in (0, 5) else Linear1bit

        self.q_proj = Linear4bit(q_width, d)
        self.k_proj = Linear4bit(kv_width, d)
        self.v_proj = vo_cls(kv_width, d)
        self.o_proj = vo_cls(d, q_width)

    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        batch, seq, _ = x.shape
        q = self.q_proj(x).reshape(batch, seq, self.n_heads, self.head_dim)
        k = self.k_proj(x).reshape(batch, seq, self.kv_heads, self.head_dim)
        v = self.v_proj(x).reshape(batch, seq, self.kv_heads, self.head_dim)
        q = apply_rope(q, freqs)
        k = apply_rope(k, freqs)
        # Replicate each K/V head so every query head has a partner.
        k, v = (t.repeat_interleave(self.n_rep, dim=2) for t in (k, v))
        # scaled_dot_product_attention works on (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = ctx.transpose(1, 2).reshape(batch, seq, -1)
        return self.o_proj(merged)
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    ``bits == 4`` selects 4-bit packed projections; any other value selects
    1-bit packed projections.
    """

    def __init__(self, d_model: int, d_ff: int, bits: int):
        super().__init__()
        linear_cls = Linear4bit if bits == 4 else Linear1bit
        self.gate = linear_cls(d_ff, d_model)
        self.up = linear_cls(d_ff, d_model)
        self.down = linear_cls(d_model, d_ff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gate(x))
        return self.down(activated * self.up(x))
|
|
|
|
class MoELayer(nn.Module):
    """Mixture-of-experts feed-forward layer.

    Output = shared expert (4-bit SwiGLU, applied to every token)
           + routing-weighted sum of the ``N_ACTIVE`` routed 1-bit SwiGLU
             experts chosen per token by a linear router.

    Routed-expert weights stay packed (per-row scale + sign bits, one
    ParameterList entry per expert) and are unpacked inside ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.shared = SwiGLU(D_MODEL, D_FF, bits=4)
        # Routed experts: per-row scale + packed sign bits, one entry per expert.
        self.gate_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_FF), requires_grad=False) for _ in range(N_EXPERTS)])
        self.gate_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_FF, (D_MODEL+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)])
        self.up_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_FF), requires_grad=False) for _ in range(N_EXPERTS)])
        self.up_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_FF, (D_MODEL+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)])
        self.down_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_MODEL), requires_grad=False) for _ in range(N_EXPERTS)])
        self.down_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_MODEL, (D_FF+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)])
        # Router: one logit per expert from the token's hidden state.
        self.router = nn.Parameter(torch.zeros(N_EXPERTS, D_MODEL))

    def _unpack1bit(self, scale: torch.Tensor, packed: torch.Tensor, in_f: int) -> torch.Tensor:
        """Decode MSB-first sign bits into an ``(rows, in_f)`` fp32 matrix scaled per row."""
        bits = ((packed.unsqueeze(-1)
                 >> torch.arange(7, -1, -1, device=packed.device, dtype=torch.uint8)) & 1)
        bits = bits.reshape(packed.shape[0], -1)[:, :in_f].float() * 2.0 - 1.0
        return bits * scale.float().unsqueeze(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, d = x.shape
        shared_out = self.shared(x)

        # Router: top-N_ACTIVE experts per token; softmax over the selected logits only.
        logits = x @ self.router.T
        top_idx = logits.topk(N_ACTIVE, dim=-1).indices
        top_w = F.softmax(logits.gather(-1, top_idx).float(), dim=-1).to(x.dtype)

        x_flat = x.reshape(B * T, d)
        idx_flat = top_idx.reshape(B * T, N_ACTIVE)
        w_flat = top_w.reshape(B * T, N_ACTIVE)
        out = torch.zeros_like(x_flat)

        # Expert-outer loop: each expert's packed weights are unpacked at most
        # once per call, no matter which top-k slot selected it for a token.
        for e in range(N_EXPERTS):
            slot_hit = idx_flat == e            # (B*T, N_ACTIVE)
            rows = slot_hit.any(dim=-1)         # tokens routed to expert e
            if not rows.any():
                continue
            # topk returns distinct experts per token, so at most one slot
            # matches; the masked sum extracts that slot's routing weight.
            route_w = (w_flat * slot_hit).sum(dim=-1, keepdim=True)[rows]
            x_e = x_flat[rows]

            wg = self._unpack1bit(self.gate_scale[e], self.gate_packed[e], D_MODEL)
            wu = self._unpack1bit(self.up_scale[e], self.up_packed[e], D_MODEL)
            wd = self._unpack1bit(self.down_scale[e], self.down_packed[e], D_FF)
            h = F.silu(F.linear(x_e, wg)) * F.linear(x_e, wu)
            out[rows] += F.linear(h, wd) * route_w
            del wg, wu, wd, h

        return shared_out + out.reshape(B, T, d)
|
|
|
|
class TransformerLayer(nn.Module):
    """Pre-norm block: RMSNorm -> attention -> residual, then RMSNorm -> MoE -> residual."""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.attn_norm = RMSNorm(D_MODEL)
        self.ffn_norm = RMSNorm(D_MODEL)
        self.attn = Attention(layer_idx)
        self.moe = MoELayer()

    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.attn_norm(x), freqs)
        h = x + attn_out
        ffn_out = self.moe(self.ffn_norm(h))
        return h + ffn_out
|
|
|
|
class SSMoELMPacked(nn.Module):
    """Packed SSMoELM language model.

    Token ids -> 4-bit packed embedding -> ``N_LAYERS`` transformer layers ->
    final RMSNorm -> logits. The output projection reuses the dequantized
    embedding table (tied weights).
    """

    def __init__(self):
        super().__init__()
        self.embed = EmbeddingPacked(VOCAB_SIZE, D_MODEL)
        self.layers = nn.ModuleList([TransformerLayer(i) for i in range(N_LAYERS)])
        self.norm = RMSNorm(D_MODEL)
        # RoPE angle table for every position up to CTX_LEN.
        self.register_buffer("freqs", precompute_rope(HEAD_DIM, CTX_LEN))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute logits for token ids.

        Args:
            x: Token ids of shape ``(batch, T)``. ``T`` must not exceed
                ``CTX_LEN``: only that many RoPE positions are precomputed.

        Returns:
            Logits of shape ``(batch, T, VOCAB_SIZE)``.
        """
        T = x.shape[1]
        h = self.embed(x).float()
        freqs = self.freqs[:T]
        for layer in self.layers:
            h = layer(h, freqs)
        h = self.norm(h)
        w = self.embed.get_weight()
        return h @ w.T

    @staticmethod
    def _sample_top_p(logits: torch.Tensor, temperature: float, top_p: float) -> int:
        """Draw one token id with temperature scaling and nucleus (top-p) filtering.

        Args:
            logits: 1-D logits over the vocabulary, indexed by token id.
            temperature: Softmax temperature (> 0).
            top_p: Keep the smallest set of most-probable tokens whose
                cumulative probability reaches ``top_p`` (at least one token).

        Returns:
            The sampled token id.
        """
        scaled = logits.numpy().astype(np.float64)
        scaled = (scaled - scaled.max()) / temperature  # max-shift for stable exp
        probs = np.exp(scaled)
        probs /= probs.sum()
        order = np.argsort(-probs)
        cumsum = np.cumsum(probs[order])
        cutoff = int(np.searchsorted(cumsum, top_p)) + 1
        probs[order[cutoff:]] = 0.0
        probs /= probs.sum()
        # ``probs`` is indexed by token id, so sample token ids directly.
        # (The old code sampled from ``order`` with these probabilities, which
        # mismatched every token with another token's probability.)
        return int(np.random.choice(probs.size, p=probs))

    @torch.inference_mode()
    def generate(self, input_ids: list[int], max_new_tokens: int = 200,
                 temperature: float = 0.8, top_p: float = 0.9,
                 eos_ids: tuple[int, ...] = (EOS_ID, EOT_ID)) -> list[int]:
        """Autoregressively extend ``input_ids``.

        There is no KV cache: every step re-runs the model on the last
        ``CTX_LEN`` tokens.

        Args:
            input_ids: Prompt token ids (must be non-empty).
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: ``> 0`` enables top-p sampling; otherwise greedy argmax.
            top_p: Nucleus threshold used when sampling.
            eos_ids: Token ids that stop generation (not included in the output).

        Returns:
            Only the newly generated token ids.

        Raises:
            ValueError: If ``input_ids`` is empty.
        """
        ids = list(input_ids)
        if not ids:
            raise ValueError("input_ids must contain at least one token")
        generated: list[int] = []
        for _ in range(max_new_tokens):
            x = torch.tensor([ids[-CTX_LEN:]], dtype=torch.long)
            logits = self(x)[0, -1]
            if temperature > 0:
                next_id = self._sample_top_p(logits, temperature, top_p)
            else:
                next_id = int(logits.argmax().item())
            if next_id in eos_ids:
                break
            generated.append(next_id)
            ids.append(next_id)
        return generated
|
|
|
|
| |
|
|
def load_packed_model(path: str) -> SSMoELMPacked:
    """Load a packed safetensors checkpoint into a fresh ``SSMoELMPacked``.

    Key layout read here (after "/" in keys is normalised to "."):
      * packed linears and the embedding: ``<base>__scale`` plus either
        ``<base>__bin`` or ``<base>__int4`` for the packed bytes;
      * RMSNorm gains: ``layers.{i}.attn_norm.weight`` / ``layers.{i}.ffn_norm.weight``;
      * router: ``layers.{i}.moe.router_weight``;
      * routed experts: ``layers.{i}.moe.{gate,up,down}_weight__scale`` and
        ``__bin``/``__int4``, with all experts concatenated along the output
        dimension and sliced per expert here.

    Loading is best-effort. A tensor missing from the checkpoint keeps its
    constructor value (zeros for packed weights, scales and the router; ones
    for norm gains). Missing keys are now collected and reported in a single
    ``UserWarning`` instead of being skipped silently.

    Args:
        path: Path to the packed ``.safetensors`` file.

    Returns:
        The populated model in eval mode.
    """
    import warnings

    data = load_file(path)
    data = {k.replace("/", "."): v for k, v in data.items()}

    model = SSMoELMPacked()
    missing: list[str] = []

    def _set(module, scale_name, packed_name, key_base):
        """Copy ``<key_base>__scale`` and the packed bytes into ``module``'s buffers."""
        s_key = f"{key_base}__scale"
        p_key_bin = f"{key_base}__bin"
        p_key_int4 = f"{key_base}__int4"
        if s_key in data:
            getattr(module, scale_name).data.copy_(torch.from_numpy(data[s_key].astype(np.float16)))
        else:
            missing.append(s_key)
        if p_key_bin in data:
            getattr(module, packed_name).data.copy_(torch.from_numpy(data[p_key_bin]))
        elif p_key_int4 in data:
            getattr(module, packed_name).data.copy_(torch.from_numpy(data[p_key_int4]))
        else:
            missing.append(f"{key_base}__bin|__int4")

    _set(model.embed, "scale", "packed", "embed_weight")

    for i, layer in enumerate(model.layers):
        pfx = f"layers.{i}"

        # Attention projections.
        for proj_name, key in [("q_proj", "q_weight"), ("k_proj", "k_weight"),
                               ("v_proj", "v_weight"), ("o_proj", "o_weight")]:
            proj = getattr(layer.attn, proj_name)
            _set(proj, "scale", "packed", f"{pfx}.attn.{key}")

        # RMSNorm gains (stored unpacked).
        for norm_name in ("attn_norm", "ffn_norm"):
            norm_key = f"{pfx}.{norm_name}.weight"
            w = data.get(norm_key)
            if w is None:
                missing.append(norm_key)
            else:
                getattr(layer, norm_name).weight.data.copy_(torch.from_numpy(w.astype(np.float32)))

        # Shared expert (4-bit).
        se = layer.moe.shared
        for attr, wname in [("gate", "gate_weight"), ("up", "up_weight"), ("down", "down_weight")]:
            _set(getattr(se, attr), "scale", "packed", f"{pfx}.moe.shared_expert.{wname}")

        # Router.
        router_key = f"{pfx}.moe.router_weight"
        rw = data.get(router_key)
        if rw is None:
            missing.append(router_key)
        else:
            layer.moe.router.data.copy_(torch.from_numpy(rw.astype(np.float32)))

        # Routed experts: concatenated along the output dim, sliced per expert.
        for attr, out_f in (("gate", D_FF), ("up", D_FF), ("down", D_MODEL)):
            base = f"{pfx}.moe.{attr}_weight"
            s_key = f"{base}__scale"
            p_key = f"{base}__bin" if f"{base}__bin" in data else f"{base}__int4"
            absent = [k for k in (s_key, p_key) if k not in data]
            if absent:
                missing.extend(absent)
                continue
            s_arr = data[s_key]
            p_arr = data[p_key]
            for e in range(N_EXPERTS):
                sl = slice(e * out_f, (e + 1) * out_f)
                getattr(layer.moe, f"{attr}_scale")[e].data.copy_(
                    torch.from_numpy(s_arr[sl].astype(np.float16)))
                getattr(layer.moe, f"{attr}_packed")[e].data.copy_(
                    torch.from_numpy(p_arr[sl]))

    if missing:
        warnings.warn(
            f"{len(missing)} expected tensor(s) not found in {path}; they keep "
            f"their default initialisation: {', '.join(missing)}",
            stacklevel=2,
        )

    return model.eval()
|
|
|
|
| |
|
|
def main():
    """CLI entry point: load the packed checkpoint and generate from a prompt."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="model.safetensors", help="Path to packed safetensors")
    parser.add_argument("--tokenizer", default="tokenizer.json", help="Path to tokenizer.json")
    parser.add_argument("--prompt", default="The quick brown fox", help="Text prompt")
    parser.add_argument("--max-tokens", type=int, default=200, help="Max new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=0.9, help="Top-p nucleus sampling")
    args = parser.parse_args()

    print(f"Loading {args.ckpt} ...")
    model = load_packed_model(args.ckpt)
    # Resident size of every buffer and parameter (packed weights, scales, norms, RoPE table).
    tensors = [*model.buffers(), *model.parameters()]
    total_bytes = sum(t.numel() * t.element_size() for t in tensors)
    print(f"Memory: {total_bytes/1024/1024:.1f} MB")

    tokenizer = Tokenizer.from_file(args.tokenizer)
    prompt_ids = [BOS_ID, *tokenizer.encode(args.prompt).ids]
    print(f"\nPrompt: {args.prompt}")
    print("Output: ", end="", flush=True)
    generated = model.generate(prompt_ids, args.max_tokens, args.temperature, args.top_p)
    print(tokenizer.decode(generated))
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|
|