""" SSMoELM Packed Inference — 12MB メモリ推論 packed uint8 weights をメモリに保持し、forward時にオンデマンドで unpack する。 使い方: python inference_packed.py --prompt "Hello" """ import argparse import math from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from safetensors.numpy import load_file from tokenizers import Tokenizer D_MODEL = 768 N_LAYERS = 6 N_HEADS = 12 KV_HEADS = 3 HEAD_DIM = 64 N_EXPERTS = 8 N_ACTIVE = 2 D_FF = 256 VOCAB_SIZE = 8192 CTX_LEN = 2048 BOS_ID, EOS_ID, EOT_ID = 0, 1, 6 # ── Packed Linear Modules ──────────────────────────────────────────────────── class Linear1bit(nn.Module): """1-bit packed linear layer: scale(fp16) + packed bits(uint8) → fp32 matmul on-the-fly""" def __init__(self, out_f: int, in_f: int): super().__init__() self.in_features = in_f self.register_buffer("scale", torch.zeros(out_f, dtype=torch.float16)) self.register_buffer("packed", torch.zeros(out_f, (in_f + 7) // 8, dtype=torch.uint8)) def _unpack(self) -> torch.Tensor: # packed: [out, ceil(in/8)] → [out, in] values ±1 bits = ((self.packed.unsqueeze(-1) >> torch.arange(7, -1, -1, device=self.packed.device, dtype=torch.uint8)) & 1) # [out, ceil(in/8), 8] bits = bits.reshape(self.packed.shape[0], -1)[:, :self.in_features] # [out, in] w = bits.float() * 2.0 - 1.0 # {0,1} → {-1,1} w = w * self.scale.float().unsqueeze(-1) # row-wise scale return w # [out, in] def forward(self, x: torch.Tensor) -> torch.Tensor: w = self._unpack() out = F.linear(x, w) del w return out class Linear4bit(nn.Module): """4-bit nibble-packed linear layer""" def __init__(self, out_f: int, in_f: int): super().__init__() self.in_features = in_f self.register_buffer("scale", torch.zeros(out_f, dtype=torch.float16)) self.register_buffer("packed", torch.zeros(out_f, (in_f + 1) // 2, dtype=torch.uint8)) def _unpack(self) -> torch.Tensor: lo = (self.packed & 0x0F).to(torch.int8) - 8 # [out, in//2] hi = ((self.packed >> 4) & 0x0F).to(torch.int8) - 8 w = torch.stack([lo, hi], dim=-1).reshape(self.packed.shape[0], -1) w = w[:, :self.in_features].float() w = w * self.scale.float().unsqueeze(-1) return w def forward(self, x: torch.Tensor) -> torch.Tensor: w = self._unpack() out = F.linear(x, w) del w return out class EmbeddingPacked(nn.Module): """4-bit packed embedding table""" def __init__(self, vocab: int, d: int): super().__init__() self.vocab = vocab self.d = d self.register_buffer("scale", torch.zeros(vocab, dtype=torch.float16)) self.register_buffer("packed", torch.zeros(vocab, (d + 1) // 2, dtype=torch.uint8)) def get_weight(self) -> torch.Tensor: lo = (self.packed & 0x0F).to(torch.int8) - 8 hi = ((self.packed >> 4) & 0x0F).to(torch.int8) - 8 w = torch.stack([lo, hi], dim=-1).reshape(self.vocab, -1)[:, :self.d].float() return w * self.scale.float().unsqueeze(-1) def forward(self, idx: torch.Tensor) -> torch.Tensor: return self.get_weight()[idx] # ── Model ───────────────────────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, d: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(d)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: rms = (x.float().pow(2).mean(-1, keepdim=True) + self.eps).rsqrt() return (self.weight * (x.float() * rms)).to(x.dtype) def precompute_rope(head_dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor: inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) freqs = torch.outer(torch.arange(max_len).float(), inv_freq) return torch.cat([freqs, freqs], dim=-1) def rotate_half(x: torch.Tensor) -> torch.Tensor: h = x.shape[-1] // 2 return torch.cat([-x[..., h:], x[..., :h]], dim=-1) def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: cos = freqs.cos()[None, :, None, :].to(x.dtype) sin = freqs.sin()[None, :, None, :].to(x.dtype) return x * cos + rotate_half(x) * sin class Attention(nn.Module): def __init__(self, layer_idx: int): super().__init__() boundary = {0, 5} vo_4bit = layer_idx in boundary d = D_MODEL self.n_heads = N_HEADS self.kv_heads = KV_HEADS self.head_dim = HEAD_DIM self.n_rep = N_HEADS // KV_HEADS self.q_proj = Linear4bit(N_HEADS * HEAD_DIM, d) self.k_proj = Linear4bit(KV_HEADS * HEAD_DIM, d) self.v_proj = Linear4bit(KV_HEADS * HEAD_DIM, d) if vo_4bit else Linear1bit(KV_HEADS * HEAD_DIM, d) self.o_proj = Linear4bit(d, N_HEADS * HEAD_DIM) if vo_4bit else Linear1bit(d, N_HEADS * HEAD_DIM) def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: B, T, _ = x.shape q = self.q_proj(x).reshape(B, T, self.n_heads, self.head_dim) k = self.k_proj(x).reshape(B, T, self.kv_heads, self.head_dim) v = self.v_proj(x).reshape(B, T, self.kv_heads, self.head_dim) q, k = apply_rope(q, freqs), apply_rope(k, freqs) k = k.repeat_interleave(self.n_rep, dim=2) v = v.repeat_interleave(self.n_rep, dim=2) out = F.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True) return self.o_proj(out.transpose(1, 2).reshape(B, T, -1)) class SwiGLU(nn.Module): def __init__(self, d_model: int, d_ff: int, bits: int): super().__init__() L = Linear4bit if bits == 4 else Linear1bit self.gate = L(d_ff, d_model) self.up = L(d_ff, d_model) self.down = L(d_model, d_ff) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down(F.silu(self.gate(x)) * self.up(x)) class MoELayer(nn.Module): def __init__(self): super().__init__() self.shared = SwiGLU(D_MODEL, D_FF, bits=4) # stacked routed expert weights (1-bit) self.gate_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_FF), requires_grad=False) for _ in range(N_EXPERTS)]) self.gate_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_FF, (D_MODEL+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)]) self.up_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_FF), requires_grad=False) for _ in range(N_EXPERTS)]) self.up_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_FF, (D_MODEL+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)]) self.down_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_MODEL), requires_grad=False) for _ in range(N_EXPERTS)]) self.down_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_MODEL, (D_FF+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)]) self.router = nn.Parameter(torch.zeros(N_EXPERTS, D_MODEL)) def _unpack1bit(self, scale: torch.Tensor, packed: torch.Tensor, in_f: int) -> torch.Tensor: bits = ((packed.unsqueeze(-1) >> torch.arange(7, -1, -1, device=packed.device, dtype=torch.uint8)) & 1) bits = bits.reshape(packed.shape[0], -1)[:, :in_f].float() * 2.0 - 1.0 return bits * scale.float().unsqueeze(-1) def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, d = x.shape shared_out = self.shared(x) logits = x @ self.router.T top_idx = logits.topk(N_ACTIVE, dim=-1).indices top_w = F.softmax(logits.gather(-1, top_idx).float(), dim=-1).to(x.dtype) # process only active experts (memory-efficient) x_flat = x.reshape(B * T, d) out = torch.zeros_like(x_flat) for k in range(N_ACTIVE): e_idx = top_idx[..., k].reshape(-1) # [B*T] w_k = top_w[..., k].reshape(-1, 1) # [B*T, 1] for e in range(N_EXPERTS): mask = (e_idx == e) if not mask.any(): continue x_e = x_flat[mask] # unpack this expert's weights on-the-fly wg = self._unpack1bit(self.gate_scale[e], self.gate_packed[e], D_MODEL) wu = self._unpack1bit(self.up_scale[e], self.up_packed[e], D_MODEL) wd = self._unpack1bit(self.down_scale[e], self.down_packed[e], D_FF) h = F.silu(F.linear(x_e, wg)) * F.linear(x_e, wu) out[mask] += F.linear(h, wd) * w_k[mask] del wg, wu, wd, h return shared_out + out.reshape(B, T, d) class TransformerLayer(nn.Module): def __init__(self, layer_idx: int): super().__init__() self.attn_norm = RMSNorm(D_MODEL) self.ffn_norm = RMSNorm(D_MODEL) self.attn = Attention(layer_idx) self.moe = MoELayer() def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.attn_norm(x), freqs) x = x + self.moe(self.ffn_norm(x)) return x class SSMoELMPacked(nn.Module): def __init__(self): super().__init__() self.embed = EmbeddingPacked(VOCAB_SIZE, D_MODEL) self.layers = nn.ModuleList([TransformerLayer(i) for i in range(N_LAYERS)]) self.norm = RMSNorm(D_MODEL) self.register_buffer("freqs", precompute_rope(HEAD_DIM, CTX_LEN)) def forward(self, x: torch.Tensor) -> torch.Tensor: T = x.shape[1] h = self.embed(x).float() freqs = self.freqs[:T] for layer in self.layers: h = layer(h, freqs) h = self.norm(h) w = self.embed.get_weight() return h @ w.T @torch.inference_mode() def generate(self, input_ids: list[int], max_new_tokens: int = 200, temperature: float = 0.8, top_p: float = 0.9, eos_ids: tuple[int, ...] = (EOS_ID, EOT_ID)) -> list[int]: ids = list(input_ids) generated = [] for _ in range(max_new_tokens): x = torch.tensor([ids[-CTX_LEN:]], dtype=torch.long) logits = self(x)[0, -1] if temperature > 0: logits_np = logits.numpy().astype(np.float64) logits_np = (logits_np - logits_np.max()) / temperature probs = np.exp(logits_np); probs /= probs.sum() idx = np.argsort(-probs); cumsum = np.cumsum(probs[idx]) cutoff = np.searchsorted(cumsum, top_p) + 1 probs[idx[cutoff:]] = 0.0; probs /= probs.sum() next_id = int(np.random.choice(idx, p=probs)) else: next_id = int(logits.argmax().item()) if next_id in eos_ids: break generated.append(next_id) ids.append(next_id) return generated # ── Load ────────────────────────────────────────────────────────────────────── def load_packed_model(path: str) -> SSMoELMPacked: data = load_file(path) data = {k.replace("/", "."): v for k, v in data.items()} model = SSMoELMPacked() def _set(module, scale_name, packed_name, key_base): s_key = f"{key_base}__scale" p_key_bin = f"{key_base}__bin" p_key_int4 = f"{key_base}__int4" if s_key in data: getattr(module, scale_name).data.copy_(torch.from_numpy(data[s_key].astype(np.float16))) if p_key_bin in data: getattr(module, packed_name).data.copy_(torch.from_numpy(data[p_key_bin])) elif p_key_int4 in data: getattr(module, packed_name).data.copy_(torch.from_numpy(data[p_key_int4])) # embed _set(model.embed, "scale", "packed", "embed_weight") for i, layer in enumerate(model.layers): pfx = f"layers.{i}" # attention for proj_name, key in [("q_proj","q_weight"),("k_proj","k_weight"), ("v_proj","v_weight"),("o_proj","o_weight")]: proj = getattr(layer.attn, proj_name) _set(proj, "scale", "packed", f"{pfx}.attn.{key}") # norms for norm_name, key in [("attn_norm","attn_norm"),("ffn_norm","ffn_norm")]: norm = getattr(layer, norm_name) w = data.get(f"{pfx}.{key}.weight") if w is not None: norm.weight.data.copy_(torch.from_numpy(w.astype(np.float32))) # shared expert se = layer.moe.shared for attr, wname in [("gate","gate_weight"),("up","up_weight"),("down","down_weight")]: _set(getattr(se, attr), "scale", "packed", f"{pfx}.moe.shared_expert.{wname}") # router rw = data.get(f"{pfx}.moe.router_weight") if rw is not None: layer.moe.router.data.copy_(torch.from_numpy(rw.astype(np.float32))) # routed experts: stacked weights [E, out, in] → scale [E*out], packed [E*out, ...] for attr, wname, out_f, in_f in [ ("gate", "gate", D_FF, D_MODEL), ("up", "up", D_FF, D_MODEL), ("down", "down", D_MODEL, D_FF), ]: s_key = f"{pfx}.moe.{wname}_weight__scale" p_key = (f"{pfx}.moe.{wname}_weight__bin" if f"{pfx}.moe.{wname}_weight__bin" in data else f"{pfx}.moe.{wname}_weight__int4") if s_key not in data or p_key not in data: continue s_arr = data[s_key] # [E*out_f] p_arr = data[p_key] # [E*out_f, packed_cols] for e in range(N_EXPERTS): sl = slice(e * out_f, (e + 1) * out_f) getattr(layer.moe, f"{attr}_scale")[e].data.copy_( torch.from_numpy(s_arr[sl].astype(np.float16))) getattr(layer.moe, f"{attr}_packed")[e].data.copy_( torch.from_numpy(p_arr[sl])) return model.eval() # ── CLI ─────────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt", default="model.safetensors", help="Path to packed safetensors") parser.add_argument("--tokenizer", default="tokenizer.json", help="Path to tokenizer.json") parser.add_argument("--prompt", default="The quick brown fox", help="Text prompt") parser.add_argument("--max-tokens", type=int, default=200, help="Max new tokens to generate") parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature") parser.add_argument("--top-p", type=float, default=0.9, help="Top-p nucleus sampling") args = parser.parse_args() print(f"Loading {args.ckpt} ...") model = load_packed_model(args.ckpt) packed_bytes = sum(p.numel() * p.element_size() for p in model.buffers()) \ + sum(p.numel() * p.element_size() for p in model.parameters()) print(f"Memory: {packed_bytes/1024/1024:.1f} MB") tok = Tokenizer.from_file(args.tokenizer) ids = [BOS_ID] + tok.encode(args.prompt).ids print(f"\nPrompt: {args.prompt}") print("Output: ", end="", flush=True) out = model.generate(ids, args.max_tokens, args.temperature, args.top_p) print(tok.decode(out)) if __name__ == "__main__": main()