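# SSMoELM-Base / inference.py
# (Hugging Face file-page header captured with the source; kept as comments so
#  the module stays valid Python.)
# brulee-1's picture
# Upload inference.py with huggingface_hub
# 1697ebf verified
# Raw
# History Blame Contribute Delete
# 16.3 kB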
"""
SSMoELM Packed Inference β€” 12MB パヒγƒͺζŽ¨θ«–
packed uint8 weights をパヒγƒͺγ«δΏζŒγ—γ€forward時にγ‚ͺγƒ³γƒ‡γƒžγƒ³γƒ‰γ§ unpack する。
使い方:
python inference_packed.py --prompt "Hello"
"""
import argparse
import math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.numpy import load_file
from tokenizers import Tokenizer
# Model hyper-parameters. The module/buffer shapes built below (and hence the
# tensor shapes the packed checkpoint must provide) are derived from these.
D_MODEL = 768                  # residual-stream width
N_LAYERS = 6                   # number of TransformerLayer blocks
N_HEADS = 12                   # query heads per attention layer
KV_HEADS = 3                   # key/value heads; each serves N_HEADS // KV_HEADS query heads
HEAD_DIM = D_MODEL // N_HEADS  # per-head width (= 64)
N_EXPERTS = 8                  # routed experts per MoE layer
N_ACTIVE = 2                   # routed experts selected per token (top-k)
D_FF = 256                     # hidden width of every SwiGLU expert
VOCAB_SIZE = 8192              # rows of the packed embedding table
CTX_LEN = 2048                 # RoPE table length and generation context window

# Special token ids. EOS_ID and EOT_ID both end generation by default.
BOS_ID = 0
EOS_ID = 1
EOT_ID = 6                     # presumably "end of turn"; confirm against tokenizer.json
# ── Packed Linear Modules ────────────────────────────────────────────────────
class Linear1bit(nn.Module):
    """Bias-free linear layer whose weights are stored as packed sign bits.

    Buffers:
        scale:  [out_f] float16, one magnitude per output row.
        packed: [out_f, ceil(in_f / 8)] uint8; bit 7 (MSB) of byte j holds input
                column 8*j, bit 0 holds column 8*j + 7. A set bit means +scale,
                a clear bit means -scale.
    The dense float32 weight is rebuilt on every forward call and not kept.
    """

    def __init__(self, out_f: int, in_f: int):
        super().__init__()
        self.in_features = in_f
        n_bytes = (in_f + 7) // 8
        self.register_buffer("scale", torch.zeros(out_f, dtype=torch.float16))
        self.register_buffer("packed", torch.zeros(out_f, n_bytes, dtype=torch.uint8))

    def _unpack(self) -> torch.Tensor:
        """Return the dense weight [out_f, in_f] as float32 values in {-scale, +scale}."""
        rows = self.packed.shape[0]
        # Shift amounts 7..0 extract bits MSB-first from each byte.
        shifts = torch.arange(7, -1, -1, device=self.packed.device, dtype=torch.uint8)
        bits = (self.packed[..., None] >> shifts) & 1            # [out, ceil(in/8), 8]
        # Flatten to one bit per column, drop padding bits, map {0, 1} -> {-1, +1}.
        signs = bits.reshape(rows, -1)[:, :self.in_features].float() * 2.0 - 1.0
        return signs * self.scale.float()[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute x @ W.T with W unpacked on the fly (no bias)."""
        return F.linear(x, self._unpack())
class Linear4bit(nn.Module):
    """Bias-free linear layer with 4-bit weights, two per byte.

    Buffers:
        scale:  [out_f] float16, one multiplier per output row.
        packed: [out_f, ceil(in_f / 2)] uint8; byte j holds input column 2*j in
                its low nibble and column 2*j + 1 in its high nibble. A nibble n
                decodes to (n - 8) * scale, i.e. integer levels -8..7.
    """

    def __init__(self, out_f: int, in_f: int):
        super().__init__()
        self.in_features = in_f
        n_bytes = (in_f + 1) // 2
        self.register_buffer("scale", torch.zeros(out_f, dtype=torch.float16))
        self.register_buffer("packed", torch.zeros(out_f, n_bytes, dtype=torch.uint8))

    def _unpack(self) -> torch.Tensor:
        """Return the dense weight [out_f, in_f] as float32."""
        rows = self.packed.shape[0]
        # (low, high) nibble pairs, each in [0, 15]; shifting a uint8 right by 4
        # already leaves only the high nibble.
        nibbles = torch.stack((self.packed & 0x0F, self.packed >> 4), dim=-1)
        signed = nibbles.to(torch.int8) - 8                        # [-8, 7]
        # Interleave to column order, drop the padding nibble for odd in_f.
        dense = signed.reshape(rows, -1)[:, :self.in_features].float()
        return dense * self.scale.float()[:, None]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute x @ W.T with W unpacked on the fly (no bias)."""
        return F.linear(x, self._unpack())
class EmbeddingPacked(nn.Module):
    """4-bit packed embedding table.

    Buffers:
        scale:  [vocab] float16, one multiplier per row.
        packed: [vocab, ceil(d / 2)] uint8; byte j of a row holds element 2*j in
                its low nibble and element 2*j + 1 in its high nibble. A nibble n
                decodes to (n - 8) * scale[row], i.e. integer levels -8..7.
    """

    def __init__(self, vocab: int, d: int):
        super().__init__()
        self.vocab = vocab
        self.d = d
        self.register_buffer("scale", torch.zeros(vocab, dtype=torch.float16))
        self.register_buffer("packed", torch.zeros(vocab, (d + 1) // 2, dtype=torch.uint8))

    def _dequantize(self, packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Decode packed rows [..., ceil(d/2)] with row scales [...] into float32 [..., d]."""
        lo = (packed & 0x0F).to(torch.int8) - 8
        hi = ((packed >> 4) & 0x0F).to(torch.int8) - 8
        # Interleave (lo, hi) so byte j yields elements 2j, 2j+1; drop the pad nibble for odd d.
        w = torch.stack([lo, hi], dim=-1).flatten(-2)[..., :self.d].float()
        return w * scale.float().unsqueeze(-1)

    def get_weight(self) -> torch.Tensor:
        """Return the whole dequantized table, shape [vocab, d], float32."""
        return self._dequantize(self.packed, self.scale)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Look up token ids ``idx`` (any shape) and return float32 [*idx.shape, d].

        Only the requested rows are gathered and unpacked. The whole table is not
        materialized, and the values are identical to ``get_weight()[idx]``.
        """
        return self._dequantize(self.packed[idx], self.scale[idx])
# ── Model ─────────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learnable gain.

    No mean-centering and no bias. Statistics are computed in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float()
        mean_sq = xf.square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return (self.weight * (xf * inv_rms)).to(x.dtype)
def precompute_rope(head_dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor:
    """Build the rotary-embedding angle table.

    Returns a float32 tensor of shape [max_len, head_dim] whose row p is
    p * inv_freq repeated twice along the last dim, with
    inv_freq[i] = base ** (-2i / head_dim) for i in [0, head_dim // 2).
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(max_len).float()
    angles = positions[:, None] * inv_freq[None, :]   # outer product
    return angles.repeat(1, 2)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dim halves (a, b) to (-b, a), splitting at size // 2."""
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``x``.

    ``freqs`` is a 2-D angle table [T, D] and is broadcast as [1, T, 1, D].
    So ``x`` is expected to be laid out as [batch, T, heads, D]. The result is
    x * cos(freqs) + rotate_half(x) * sin(freqs), computed in x's dtype.
    """
    cos = freqs.cos()[None, :, None, :].to(x.dtype)
    sin = freqs.sin()[None, :, None, :].to(x.dtype)
    # Inline half-rotation: (a, b) -> (-b, a) along the last dim.
    half = x.shape[-1] // 2
    rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
    return x * cos + rotated * sin
class Attention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    q/k projections are always 4-bit. The v/o projections are 4-bit in layers 0
    and 5 and 1-bit in every other layer. Each of the KV_HEADS key/value heads is
    repeated to serve N_HEADS // KV_HEADS consecutive query heads.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.n_heads = N_HEADS
        self.kv_heads = KV_HEADS
        self.head_dim = HEAD_DIM
        self.n_rep = N_HEADS // KV_HEADS
        q_width = N_HEADS * HEAD_DIM
        kv_width = KV_HEADS * HEAD_DIM
        vo_cls = Linear4bit if layer_idx in (0, 5) else Linear1bit
        self.q_proj = Linear4bit(q_width, D_MODEL)
        self.k_proj = Linear4bit(kv_width, D_MODEL)
        self.v_proj = vo_cls(kv_width, D_MODEL)
        self.o_proj = vo_cls(D_MODEL, q_width)

    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        """x: [B, T, D_MODEL]; freqs: RoPE angle table for the T positions."""
        batch, seq, _ = x.shape

        def split_heads(t: torch.Tensor, n: int) -> torch.Tensor:
            return t.reshape(batch, seq, n, self.head_dim)

        q = apply_rope(split_heads(self.q_proj(x), self.n_heads), freqs)
        k = apply_rope(split_heads(self.k_proj(x), self.kv_heads), freqs)
        v = split_heads(self.v_proj(x), self.kv_heads)
        # Expand KV heads so every query head has a matching key/value head.
        k, v = (t.repeat_interleave(self.n_rep, dim=2) for t in (k, v))
        attn = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        merged = attn.transpose(1, 2).reshape(batch, seq, -1)
        return self.o_proj(merged)
class SwiGLU(nn.Module):
    """SwiGLU feed-forward, down(silu(gate(x)) * up(x)), with packed weights.

    ``bits == 4`` builds all three projections as Linear4bit; any other value
    builds them as Linear1bit.
    """

    def __init__(self, d_model: int, d_ff: int, bits: int):
        super().__init__()
        if bits == 4:
            make = Linear4bit
        else:
            make = Linear1bit
        self.gate, self.up = make(d_ff, d_model), make(d_ff, d_model)
        self.down = make(d_model, d_ff)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate(x))
        return self.down(gated * self.up(x))
class MoELayer(nn.Module):
    """Mixture-of-experts feed-forward layer.

    output = shared(x) + sum, over the N_ACTIVE routed experts chosen per token,
    of softmax_weight * expert(x).

    The shared expert is a 4-bit SwiGLU. The N_EXPERTS routed experts are 1-bit
    SwiGLUs whose weights stay packed (8 sign bits per uint8, MSB first, one
    scale per output row) and are unpacked on the fly in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.shared = SwiGLU(D_MODEL, D_FF, bits=4)
        # Routed expert weights (1-bit), one frozen entry per expert:
        # *_scale[e] is [out], *_packed[e] is [out, ceil(in / 8)].
        self.gate_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_FF), requires_grad=False) for _ in range(N_EXPERTS)])
        self.gate_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_FF, (D_MODEL+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)])
        self.up_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_FF), requires_grad=False) for _ in range(N_EXPERTS)])
        self.up_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_FF, (D_MODEL+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)])
        self.down_scale = nn.ParameterList([nn.Parameter(torch.zeros(D_MODEL), requires_grad=False) for _ in range(N_EXPERTS)])
        self.down_packed = nn.ParameterList([nn.Parameter(torch.zeros(D_MODEL, (D_FF+7)//8, dtype=torch.uint8), requires_grad=False) for _ in range(N_EXPERTS)])
        self.router = nn.Parameter(torch.zeros(N_EXPERTS, D_MODEL))

    def _unpack1bit(self, scale: torch.Tensor, packed: torch.Tensor, in_f: int) -> torch.Tensor:
        """Decode packed[out, ceil(in_f/8)] sign bits into scale-weighted +/-1 weights [out, in_f]."""
        bits = ((packed.unsqueeze(-1)
                 >> torch.arange(7, -1, -1, device=packed.device, dtype=torch.uint8)) & 1)
        bits = bits.reshape(packed.shape[0], -1)[:, :in_f].float() * 2.0 - 1.0
        return bits * scale.float().unsqueeze(-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, T, D_MODEL] -> [B, T, D_MODEL].

        Experts are visited in the outer loop, so each routed expert's weights
        are unpacked at most once per call, however many top-k slots select it.
        Per-token contributions are accumulated in expert-index order.
        """
        B, T, d = x.shape
        shared_out = self.shared(x)

        # Router: top-N_ACTIVE experts per token, softmax over just their logits.
        logits = x @ self.router.T                               # [B, T, E]
        top_idx = logits.topk(N_ACTIVE, dim=-1).indices          # [B, T, K]
        top_w = F.softmax(logits.gather(-1, top_idx).float(), dim=-1).to(x.dtype)

        x_flat = x.reshape(B * T, d)
        idx_flat = top_idx.reshape(B * T, N_ACTIVE)
        w_flat = top_w.reshape(B * T, N_ACTIVE)
        out = torch.zeros_like(x_flat)
        for e in range(N_EXPERTS):
            # topk returns distinct indices, so each row has at most one True.
            hit = idx_flat == e                                  # [B*T, K]
            rows = hit.any(dim=-1)                               # tokens routed to e
            if not rows.any():
                continue
            # Routing weight of expert e for each selected token (other slots zeroed).
            gate_w = (w_flat[rows] * hit[rows]).sum(dim=-1, keepdim=True)
            x_e = x_flat[rows]
            # Unpack this expert's weights once for all tokens that chose it.
            wg = self._unpack1bit(self.gate_scale[e], self.gate_packed[e], D_MODEL)
            wu = self._unpack1bit(self.up_scale[e], self.up_packed[e], D_MODEL)
            wd = self._unpack1bit(self.down_scale[e], self.down_packed[e], D_FF)
            h = F.silu(F.linear(x_e, wg)) * F.linear(x_e, wu)
            out[rows] += F.linear(h, wd) * gate_w
            del wg, wu, wd, h
        return shared_out + out.reshape(B, T, d)
class TransformerLayer(nn.Module):
    """Pre-norm decoder block: x + attn(norm(x)), then h + moe(norm(h))."""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.attn_norm = RMSNorm(D_MODEL)
        self.ffn_norm = RMSNorm(D_MODEL)
        self.attn = Attention(layer_idx)
        self.moe = MoELayer()

    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.attn_norm(x), freqs)
        h = x + attn_out
        return h + self.moe(self.ffn_norm(h))
class SSMoELMPacked(nn.Module):
    """Decoder-only MoE language model whose weights stay packed in memory.

    The 4-bit packed embedding table is also used (dequantized) as the output
    projection, i.e. the input and output embeddings are tied.
    """

    def __init__(self):
        super().__init__()
        self.embed = EmbeddingPacked(VOCAB_SIZE, D_MODEL)
        self.layers = nn.ModuleList([TransformerLayer(i) for i in range(N_LAYERS)])
        self.norm = RMSNorm(D_MODEL)
        # RoPE angle table for every position up to CTX_LEN.
        self.register_buffer("freqs", precompute_rope(HEAD_DIM, CTX_LEN))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids [B, T] (T <= CTX_LEN) to float32 logits [B, T, VOCAB_SIZE]."""
        T = x.shape[1]
        h = self.embed(x).float()
        freqs = self.freqs[:T]
        for layer in self.layers:
            h = layer(h, freqs)
        h = self.norm(h)
        w = self.embed.get_weight()
        return h @ w.T

    @torch.inference_mode()
    def generate(self, input_ids: list[int], max_new_tokens: int = 200,
                 temperature: float = 0.8, top_p: float = 0.9,
                 eos_ids: tuple[int, ...] = (EOS_ID, EOT_ID)) -> list[int]:
        """Autoregressively sample up to ``max_new_tokens`` tokens after ``input_ids``.

        The whole (last CTX_LEN tokens of the) sequence is re-run each step;
        there is no KV cache.

        Args:
            input_ids: prompt token ids.
            max_new_tokens: maximum number of tokens to generate.
            temperature: softmax temperature. ``<= 0`` selects greedy argmax decoding.
            top_p: nucleus threshold. Sampling is restricted to the smallest set
                of most-likely tokens whose cumulative probability reaches ``top_p``.
            eos_ids: token ids that stop generation (not included in the output).

        Returns:
            The generated token ids, excluding the prompt and the stop token.
        """
        ids = list(input_ids)
        generated: list[int] = []
        for _ in range(max_new_tokens):
            x = torch.tensor([ids[-CTX_LEN:]], dtype=torch.long)
            logits = self(x)[0, -1]
            if temperature > 0:
                # float64 on CPU for a numerically stable softmax / cumsum.
                logits_np = logits.float().cpu().numpy().astype(np.float64)
                logits_np = (logits_np - logits_np.max()) / temperature
                probs = np.exp(logits_np)
                probs /= probs.sum()
                order = np.argsort(-probs)               # token ids, most likely first
                cumsum = np.cumsum(probs[order])
                cutoff = np.searchsorted(cumsum, top_p) + 1
                probs[order[cutoff:]] = 0.0
                probs /= probs.sum()
                # probs is indexed by token id, so sample token ids directly.
                # (Pairing the sorted ids with id-indexed probs would draw the
                # wrong tokens.)
                next_id = int(np.random.choice(probs.size, p=probs))
            else:
                next_id = int(logits.argmax().item())
            if next_id in eos_ids:
                break
            generated.append(next_id)
            ids.append(next_id)
        return generated
# ── Load ──────────────────────────────────────────────────────────────────────
def load_packed_model(path: str) -> SSMoELMPacked:
    """Build an SSMoELMPacked and populate it from a packed safetensors file.

    Checkpoint keys may use "/" as a separator; they are normalised to "." first.
    A packed weight ``<name>`` is read from ``<name>__scale`` plus either
    ``<name>__bin`` (1-bit) or ``<name>__int4`` (4-bit). Routed-expert tensors
    are stored stacked along the output dimension ([N_EXPERTS * out, ...]) and
    are split per expert here.

    Tensors absent from the file keep their initial values (zeros for packed
    weights, scales and router; ones for norm gains). A single warning lists
    the missing keys.

    Args:
        path: path to the ``.safetensors`` checkpoint.

    Returns:
        The populated model in eval mode.
    """
    import warnings

    data = load_file(path)
    data = {k.replace("/", "."): v for k, v in data.items()}
    model = SSMoELMPacked()
    missing: list[str] = []

    def _set(module, scale_name, packed_name, key_base):
        """Copy <key_base>__scale and <key_base>__bin|__int4 into the module's buffers."""
        s_key = f"{key_base}__scale"
        p_key_bin = f"{key_base}__bin"
        p_key_int4 = f"{key_base}__int4"
        if s_key in data:
            getattr(module, scale_name).data.copy_(torch.from_numpy(data[s_key].astype(np.float16)))
        else:
            missing.append(s_key)
        if p_key_bin in data:
            getattr(module, packed_name).data.copy_(torch.from_numpy(data[p_key_bin]))
        elif p_key_int4 in data:
            getattr(module, packed_name).data.copy_(torch.from_numpy(data[p_key_int4]))
        else:
            missing.append(f"{key_base}__bin|__int4")

    def _set_float(param, key):
        """Copy a dense float tensor into ``param`` if ``key`` exists; else record it as missing."""
        w = data.get(key)
        if w is None:
            missing.append(key)
            return
        param.data.copy_(torch.from_numpy(w.astype(np.float32)))

    # embed
    _set(model.embed, "scale", "packed", "embed_weight")
    for i, layer in enumerate(model.layers):
        pfx = f"layers.{i}"
        # attention
        for proj_name, key in [("q_proj", "q_weight"), ("k_proj", "k_weight"),
                               ("v_proj", "v_weight"), ("o_proj", "o_weight")]:
            _set(getattr(layer.attn, proj_name), "scale", "packed", f"{pfx}.attn.{key}")
        # norms
        for norm_name in ("attn_norm", "ffn_norm"):
            _set_float(getattr(layer, norm_name).weight, f"{pfx}.{norm_name}.weight")
        # shared expert
        se = layer.moe.shared
        for attr, wname in [("gate", "gate_weight"), ("up", "up_weight"), ("down", "down_weight")]:
            _set(getattr(se, attr), "scale", "packed", f"{pfx}.moe.shared_expert.{wname}")
        # router
        _set_float(layer.moe.router, f"{pfx}.moe.router_weight")
        # routed experts: stacked weights -> scale [E*out], packed [E*out, packed_cols]
        for attr, out_f in (("gate", D_FF), ("up", D_FF), ("down", D_MODEL)):
            base = f"{pfx}.moe.{attr}_weight"
            s_key = f"{base}__scale"
            p_key = f"{base}__bin" if f"{base}__bin" in data else f"{base}__int4"
            if s_key not in data or p_key not in data:
                missing.append(f"{base}__scale|__bin")
                continue
            s_arr = data[s_key]
            p_arr = data[p_key]
            for e in range(N_EXPERTS):
                sl = slice(e * out_f, (e + 1) * out_f)
                getattr(layer.moe, f"{attr}_scale")[e].data.copy_(
                    torch.from_numpy(s_arr[sl].astype(np.float16)))
                getattr(layer.moe, f"{attr}_packed")[e].data.copy_(
                    torch.from_numpy(p_arr[sl]))

    # Final RMSNorm gain: loaded when present. Otherwise it stays at 1 and is not
    # reported, since some exports may omit it.
    # NOTE(review): the exporter's key name for this tensor is not visible here;
    # the candidates below are assumptions — confirm against the checkpoint.
    for key in ("norm.weight", "final_norm.weight"):
        if key in data:
            model.norm.weight.data.copy_(torch.from_numpy(data[key].astype(np.float32)))
            break

    if missing:
        shown = ", ".join(missing[:8]) + (", ..." if len(missing) > 8 else "")
        warnings.warn(
            f"{len(missing)} checkpoint tensor(s) not found in {path}; "
            f"left at initial values: {shown}",
            stacklevel=2,
        )
    return model.eval()
# ── CLI ───────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: load a packed checkpoint and sample a continuation of --prompt."""
    parser = argparse.ArgumentParser()
    options = (
        ("--ckpt", dict(default="model.safetensors", help="Path to packed safetensors")),
        ("--tokenizer", dict(default="tokenizer.json", help="Path to tokenizer.json")),
        ("--prompt", dict(default="The quick brown fox", help="Text prompt")),
        ("--max-tokens", dict(type=int, default=200, help="Max new tokens to generate")),
        ("--temperature", dict(type=float, default=0.8, help="Sampling temperature")),
        ("--top-p", dict(type=float, default=0.9, help="Top-p nucleus sampling")),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    print(f"Loading {args.ckpt} ...")
    model = load_packed_model(args.ckpt)
    # Resident size of every buffer and parameter (packed weights, scales, norms, RoPE table).
    tensors = [*model.buffers(), *model.parameters()]
    packed_bytes = sum(t.numel() * t.element_size() for t in tensors)
    print(f"Memory: {packed_bytes/1024/1024:.1f} MB")

    tok = Tokenizer.from_file(args.tokenizer)
    prompt_ids = [BOS_ID, *tok.encode(args.prompt).ids]
    print(f"\nPrompt: {args.prompt}")
    print("Output: ", end="", flush=True)
    completion = model.generate(prompt_ids, args.max_tokens, args.temperature, args.top_p)
    print(tok.decode(completion))


if __name__ == "__main__":
    main()