smartcore-v1 / code /kod /lamba_cpu.py
kdirgul's picture
lamba_cpu: cuda-guvenli (GQACPU._rope arange + Mamba3CPU.forward zeros device=)
ba7a119 verified
Raw
History Blame Contribute Delete
19.6 kB
"""
LAMBA V1.0 — saf-PyTorch CPU inference (Triton/mamba_ssm YOK). [v1.1]
Fork (mamba_ssm "norm-free" Mamba-3, Triton kernel) ile EĞİTİLEN ağırlıkları (lamba_v1.pt)
GPU'suz çalıştırmak için fork forward matematiğinin saf-PyTorch reimplementasyonu.
Mamba-3 mixer matematiği (vendor mamba3.py + siso_step kernelinden çıkarıldı), per-head:
_A = -softplus(dd_A) (clamp ≤ -1e-4) # ← data-dependent A (norm-free farkı)
DT = softplus(dd_dt + dt_bias)
trap = sigmoid(trap_proj)
α = exp(_A·DT) # base-e ✅ (kernel exp2(ön-ölçekli adt) ile birebir)
β = α·DT·(1-trap) ; γ = trap·DT # trapezoidal
h = α·h + β·(x_prev ⊗ B_prev) + γ·(x ⊗ B) # B,x = K,V ; B üzerinde partial-RoPE
y = h @ C # C = Q ; C üzerinde partial-RoPE
y += D·x ; y *= silu(z)
B,C paylaşımlı (ngroups=1) → head'lere broadcast + per-head bias. RoPE: rope_fraction=0.5
(ilk d_state/2=64 boyut, 32 angle, INTERLEAVED [çift (2j,2j+1)]). Token-token recurrence + decode-cache (step).
✅ KALİBRE TAMAM (2026-06-27): fork'a full-logit fp32 maxdiff 0.06–0.09 (top-5/argmax birebir).
Kritik düzeltme MLP'deydi: GatedMLP = y·silu(gate) (1.yarı=değer, 2.yarı=gate; mixer değil).
Decode-cache (step) full-recompute ile birebir aynı çıktı, ~6× hızlı (O(L²)→O(L)).
Kullanım (CPU): python lamba_cpu.py --ckpt checkpoints/lamba_v1.pt --tokenizer tokenizer/tokenizer.model --query "..."
"""
import os, sys, math, argparse
import torch, torch.nn as nn, torch.nn.functional as F
# Import-time side effect: use as many intra-op threads as os.cpu_count() reports
# (falls back to 4 when the count is unknown, i.e. os.cpu_count() returns None).
torch.set_num_threads(max(1, os.cpu_count() or 4))
def rms_norm(x, w, eps=1e-5):
    """RMS-normalize ``x`` over its last axis and scale by ``w``.

    No mean subtraction and no bias; ``eps`` is added to the mean square before
    the reciprocal square root to keep it finite for all-zero rows.
    """
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    return x * inv_rms * w
# ───────────── Mamba-3 mixer (pure PyTorch, reproducing the fork's math) ─────────────
class Mamba3CPU(nn.Module):
    """Pure-PyTorch Mamba-3 mixer: full-sequence `forward` plus single-token `step`.

    Implements the per-head recurrence summarized in the module docstring with an
    explicit Python loop over time. B and C come from a single shared group
    (broadcast to every head, plus a per-head bias); partial RoPE is applied to both.
    """
    # Calibration flags (chosen by grid search against the fork's outputs, per the author).
    ROPE_INTER = True # fork rotate_pairwise=True ⇒ interleaved pairs (q[0::2], q[1::2])
    BETA_ALPHA = True # calibration: include the alpha factor in β (grid search: True was better)
    BETA_SHIFT = False # calibration: use the previous step's dt/trap in β (grid search: False was better); `step` assumes False
    EXP_E = True # resolved: α = exp(_A·DT) in base e (not exp2) — mixer calibrated (author reports maxdiff 0.14 bf16)
    def __init__(self, cfg):
        """Build projections and per-head parameters from the config dict `cfg`.

        Reads cfg["d_model"], cfg["expand"], cfg["head_dim"], cfg["d_state"], and
        optionally cfg["ngroups"] (default 1) and cfg["rope_fraction"] (default 0.5).
        """
        super().__init__()
        d = cfg["d_model"]
        self.d_inner = cfg["expand"] * d # e.g. 1536 in the shipped config
        self.headdim = cfg["head_dim"] # e.g. 64
        self.nheads = self.d_inner // self.headdim # e.g. 24
        self.d_state = cfg["d_state"] # e.g. 128
        self.ngroups = cfg.get("ngroups", 1) # num_bc_heads = 1
        self.A_floor = 1e-4 # _A is clamped to <= -A_floor in forward/step
        rope_fraction = cfg.get("rope_fraction", 0.5)
        self.rot = int(self.d_state * rope_fraction) # number of rotated dims (e.g. 64)
        if self.rot % 2:
            # RoPE rotates pairs, so force an even count.
            self.rot -= 1
        self.n_ang = self.rot // 2 # one angle per rotated pair (e.g. 32)
        bc = self.d_state * self.ngroups # width of B (and of C) in in_proj's output (e.g. 128; SISO, rank=1)
        # in_proj output layout: z | x | B | C | dt | A | trap | angles (see the split in forward/step).
        d_in = 2 * self.d_inner + 2 * bc + 3 * self.nheads + self.n_ang # e.g. 3432
        self.in_proj = nn.Linear(d, d_in, bias=False)
        self.out_proj = nn.Linear(self.d_inner, d, bias=False)
        self.dt_bias = nn.Parameter(torch.zeros(self.nheads))
        self.D = nn.Parameter(torch.ones(self.nheads))
        self.B_bias = nn.Parameter(torch.zeros(self.nheads, 1, self.d_state))
        self.C_bias = nn.Parameter(torch.zeros(self.nheads, 1, self.d_state))
        self.B_norm = nn.Parameter(torch.ones(bc))
        self.C_norm = nn.Parameter(torch.ones(bc))
    def _rope(self, t, cos, sin):
        """Partial RoPE: rotate the first `rot` dims of `t` (d_state/2 at rope_fraction=0.5); the rest pass through.

        Non-interleaved (ROPE_INTER=False): pairs j <-> j+n. Interleaved (ROPE_INTER=True): pairs (2j, 2j+1).
        `cos`/`sin` must broadcast against t[..., :n_ang].
        """
        rot, n = self.rot, self.n_ang
        rest = t[..., rot:]
        if Mamba3CPU.ROPE_INTER:
            head = t[..., :rot]
            x1, x2 = head[..., 0::2], head[..., 1::2]
            ra, rb = x1 * cos - x2 * sin, x1 * sin + x2 * cos
            # Re-interleave the rotated pairs back to (ra0, rb0, ra1, rb1, ...).
            out = torch.stack([ra, rb], dim=-1).flatten(-2)
            return torch.cat([out, rest], dim=-1)
        a, b = t[..., :n], t[..., n:rot]
        ra, rb = a * cos - b * sin, a * sin + b * cos
        return torch.cat([ra, rb, rest], dim=-1)
    def forward(self, u):
        """u: (B, L, d_model) -> (B, L, d_model), computed with a token-by-token recurrence.

        NOTE(review): with L == 0, `torch.cat(ys)` receives an empty list and raises.
        """
        B, L, _ = u.shape
        H, P, S = self.nheads, self.headdim, self.d_state
        proj = self.in_proj(u)
        z, x, Bm, Cm, dd_dt, dd_A, trap, ang = torch.split(
            proj, [self.d_inner, self.d_inner, S * self.ngroups, S * self.ngroups,
                   H, H, H, self.n_ang], dim=-1)
        x = x.view(B, L, H, P)
        z = z.view(B, L, H, P)
        _A = (-F.softplus(dd_A.float())).clamp(max=-self.A_floor) # (B,L,H); data-dependent, clamped to <= -A_floor
        DT = F.softplus(dd_dt.float() + self.dt_bias) # (B,L,H)
        trap = torch.sigmoid(trap.float()) # (B,L,H); trapezoidal mixing weight
        Bm = rms_norm(Bm.float(), self.B_norm) # (B,L,S) when ngroups=1
        Cm = rms_norm(Cm.float(), self.C_norm)
        # NOTE(review): the (B, L, 1, S) views below only work when ngroups == 1.
        Bm = Bm.view(B, L, 1, S) + self.B_bias.view(1, 1, H, S) # shared B broadcast to every head + per-head bias
        Cm = Cm.view(B, L, 1, S) + self.C_bias.view(1, 1, H, S)
        alpha = torch.exp(_A * DT) if Mamba3CPU.EXP_E else torch.exp2(_A * DT) # base-e vs base-2 decay
        # With BETA_SHIFT, β uses the previous step's dt/trap; the wrap-around value that roll puts at t=0
        # is multiplied by x_prev = 0 below, so it contributes nothing.
        bDT = torch.roll(DT, 1, dims=1) if Mamba3CPU.BETA_SHIFT else DT
        btrap = torch.roll(trap, 1, dims=1) if Mamba3CPU.BETA_SHIFT else trap
        beta = (alpha if Mamba3CPU.BETA_ALPHA else 1.0) * bDT * (1 - btrap) # weight of the previous (x, B) term
        gamma = trap * DT # weight of the current (x, B) term
        # (Author's calibration note: angle accumulation sign / dt scaling; first guess was cum += DT·angles.)
        h = torch.zeros(B, H, P, S, device=u.device) # ssm_state (V,QK) = (P,S), default dtype
        x_prev = torch.zeros(B, H, P, device=u.device)
        Bk_prev = torch.zeros(B, H, S, device=u.device)
        cum = torch.zeros(B, H, self.n_ang, device=u.device) # accumulated RoPE angles
        ys = []
        for t in range(L):
            # fork: angle = angle_state + tanh(angle_proj)·DT·π (mamba3_mimo_rotary_step reference)
            inc = torch.tanh(ang[:, t].float()).unsqueeze(1) * DT[:, t].unsqueeze(-1) * math.pi
            cum = cum + inc # (B,H,n_ang)
            cos, sin = torch.cos(cum), torch.sin(cum)
            Bk = self._rope(Bm[:, t], cos, sin) # (B,H,S)
            Cq = self._rope(Cm[:, t], cos, sin)
            xt = x[:, t] # (B,H,P)
            a = alpha[:, t].view(B, H, 1, 1)
            # Trapezoidal input: β·(x_prev ⊗ B_prev) + γ·(x ⊗ B), outer products of shape (B,H,P,S).
            diff = (beta[:, t].view(B, H, 1, 1) * x_prev.unsqueeze(-1) * Bk_prev.unsqueeze(-2)
                    + gamma[:, t].view(B, H, 1, 1) * xt.unsqueeze(-1) * Bk.unsqueeze(-2))
            h = h * a + diff # (B,H,P,S)
            y = (h * Cq.unsqueeze(-2)).sum(-1) # (B,H,P): readout h @ C
            y = y + self.D.view(1, H, 1) * xt # per-head skip connection
            y = y * F.silu(z[:, t]) # output gate
            ys.append(y.reshape(B, 1, H * P))
            x_prev, Bk_prev = xt, Bk
        return self.out_proj(torch.cat(ys, dim=1)) # fp32; out_proj is called as a module (its .weight is not accessed), which keeps int8 dynamic quant working
    # ───── decode cache (single-token step; one iteration of forward's loop, O(1) in sequence length) ─────
    def init_state(self, B, device=None, dtype=torch.float32):
        """Return zeroed decode state [h, x_prev, Bk_prev, cum] for batch size B."""
        H, P, S = self.nheads, self.headdim, self.d_state
        z = lambda *s: torch.zeros(*s, device=device, dtype=dtype)
        return [z(B, H, P, S), z(B, H, P), z(B, H, S), z(B, H, self.n_ang)] # h, x_prev, Bk_prev, cum
    def step(self, u, state):
        """Advance the recurrence by one token.

        u: (B, d_model); state = [h, x_prev, Bk_prev, cum] as built by init_state.
        Returns (y: (B, d_model), new_state). Performs the same per-token computation as
        one iteration of `forward`'s loop, but always with BETA_SHIFT treated as False.
        """
        B = u.shape[0]
        H, P, S = self.nheads, self.headdim, self.d_state
        h, x_prev, Bk_prev, cum = state
        z, x, Bm, Cm, dd_dt, dd_A, trap, ang = torch.split(
            self.in_proj(u), [self.d_inner, self.d_inner, S * self.ngroups, S * self.ngroups,
                              H, H, H, self.n_ang], dim=-1)
        x = x.view(B, H, P); z = z.view(B, H, P)
        _A = (-F.softplus(dd_A.float())).clamp(max=-self.A_floor)
        DT = F.softplus(dd_dt.float() + self.dt_bias)
        trap = torch.sigmoid(trap.float())
        Bm = rms_norm(Bm.float(), self.B_norm).view(B, 1, S) + self.B_bias.view(1, H, S)
        Cm = rms_norm(Cm.float(), self.C_norm).view(B, 1, S) + self.C_bias.view(1, H, S)
        alpha = torch.exp(_A * DT) if Mamba3CPU.EXP_E else torch.exp2(_A * DT)
        beta = (alpha if Mamba3CPU.BETA_ALPHA else 1.0) * DT * (1 - trap) # BETA_SHIFT=False
        gamma = trap * DT
        cum = cum + torch.tanh(ang.float()).unsqueeze(1) * DT.unsqueeze(-1) * math.pi
        cos, sin = torch.cos(cum), torch.sin(cum)
        Bk = self._rope(Bm, cos, sin); Cq = self._rope(Cm, cos, sin)
        diff = (beta.view(B, H, 1, 1) * x_prev.unsqueeze(-1) * Bk_prev.unsqueeze(-2)
                + gamma.view(B, H, 1, 1) * x.unsqueeze(-1) * Bk.unsqueeze(-2))
        h = h * alpha.view(B, H, 1, 1) + diff
        y = (h * Cq.unsqueeze(-2)).sum(-1) + self.D.view(1, H, 1) * x
        y = (y * F.silu(z)).reshape(B, H * P)
        return self.out_proj(y), [h, x, Bk, cum] # fp32 (int8-quant compatible)
# ───────────── GQA mixer (saf-PyTorch; hybrid_mamba3 ile aynı) ─────────────
def _rot_half(x):
a, b = x.chunk(2, -1)
return torch.cat((-b, a), -1)
class GQACPU(nn.Module):
    """Grouped-query causal self-attention with QK RMS-norm and rotate-half RoPE.

    `forward` handles a full (B, T, d_model) sequence; `init_state`/`step` implement
    a growing KV cache for token-by-token decoding.
    """

    def __init__(self, cfg, base=10000.0):
        super().__init__()
        d_model = cfg["d_model"]
        self.nh = cfg["n_heads"]
        self.nkv = cfg["n_kv_heads"]
        self.hd = d_model // self.nh
        self.rep = self.nh // self.nkv  # query heads sharing one KV head
        q_width = self.nh * self.hd
        kv_width = self.nkv * self.hd
        self.q_proj = nn.Linear(d_model, q_width, bias=False)
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.out_proj = nn.Linear(q_width, d_model, bias=False)
        self.qn = nn.Parameter(torch.ones(self.hd))
        self.kn = nn.Parameter(torch.ones(self.hd))
        exponents = torch.arange(0, self.hd, 2).float() / self.hd
        # Inverse RoPE frequencies; non-persistent, so they are not read from checkpoints.
        self.register_buffer("inv", 1.0 / (base ** exponents), persistent=False)

    @staticmethod
    def _apply_rope(x, emb):
        """Rotate x by the angle table emb (..., hd), broadcast over (B, heads); cast back to x.dtype."""
        cos = emb.cos()[None, None]
        sin = emb.sin()[None, None]
        return (x * cos + _rot_half(x) * sin).to(x.dtype)

    def _rope(self, x, T):
        """RoPE for positions 0..T-1 (full-sequence path)."""
        positions = torch.arange(T, device=x.device).float()
        freqs = torch.outer(positions, self.inv)  # (T, hd/2)
        return self._apply_rope(x, torch.cat((freqs, freqs), -1))

    def forward(self, x):
        """x: (B, T, d_model) -> (B, T, d_model) with causal attention."""
        B, T, _ = x.shape

        def heads(proj, n):
            # (B, T, n*hd) -> (B, n, T, hd)
            return proj(x).view(B, T, n, self.hd).transpose(1, 2)

        q = heads(self.q_proj, self.nh)
        k = heads(self.k_proj, self.nkv)
        v = heads(self.v_proj, self.nkv)
        q = rms_norm(q.float(), self.qn.float()).to(x.dtype)
        k = rms_norm(k.float(), self.kn.float()).to(x.dtype)
        q = self._rope(q, T)
        k = self._rope(k, T)
        # Expand KV heads so every query head has a matching key/value head.
        k = k.repeat_interleave(self.rep, 1)
        v = v.repeat_interleave(self.rep, 1)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn.transpose(1, 2).reshape(B, T, -1))

    # ───── decode cache (KV cache; the single query attends to all cached keys, which is causal) ─────
    def init_state(self, B, device=None, dtype=torch.float32):
        """Return an empty KV cache [k_cache, v_cache]; the arguments only mirror Mamba3CPU.init_state."""
        return [None, None]  # k_cache, v_cache, each later (B, nkv, t, hd)

    def _rope_at(self, x, pos):
        """RoPE for the single position `pos` (decode path)."""
        freqs = (self.inv * float(pos)).unsqueeze(0)  # (1, hd/2)
        return self._apply_rope(x, torch.cat((freqs, freqs), -1))

    def step(self, x, state):
        """x: (B, d_model), state=[k_cache, v_cache] -> (y: (B, d_model), new_state)."""
        B = x.shape[0]
        k_cache, v_cache = state
        pos = k_cache.shape[2] if k_cache is not None else 0
        xq = x.view(B, 1, -1)

        def heads(proj, n):
            # (B, 1, n*hd) -> (B, n, 1, hd)
            return proj(xq).view(B, 1, n, self.hd).transpose(1, 2)

        q = heads(self.q_proj, self.nh)
        k = heads(self.k_proj, self.nkv)
        v = heads(self.v_proj, self.nkv)
        q = rms_norm(q.float(), self.qn.float()).to(x.dtype)
        k = rms_norm(k.float(), self.kn.float()).to(x.dtype)
        q = self._rope_at(q, pos)
        k = self._rope_at(k, pos)
        k_cache = k if k_cache is None else torch.cat([k_cache, k], dim=2)
        v_cache = v if v_cache is None else torch.cat([v_cache, v], dim=2)
        keys = k_cache.repeat_interleave(self.rep, 1)
        values = v_cache.repeat_interleave(self.rep, 1)
        attn = F.scaled_dot_product_attention(q, keys, values, is_causal=False)
        return self.out_proj(attn.transpose(1, 2).reshape(B, -1)), [k_cache, v_cache]
class GatedMLP(nn.Module):
    """Gated MLP matching mamba_ssm's GatedMLP layout.

    fc1 emits [value | gate] (first half is the value, second half the gate) and the
    output is fc2(value * silu(gate)). The hidden width is cfg["d_intermediate"]
    rounded up to a multiple of 128 (e.g. 1500 -> 1536).
    """

    def __init__(self, cfg):
        super().__init__()
        d_model = cfg["d_model"]
        multiple = 128  # mamba_ssm rounds the hidden width up to a multiple of this
        hidden = -(-cfg["d_intermediate"] // multiple) * multiple  # ceil-divide, then scale back up
        self.fc1 = nn.Linear(d_model, 2 * hidden, bias=False)
        self.fc2 = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        value, gate = torch.chunk(self.fc1(x), 2, dim=-1)
        return self.fc2(value * F.silu(gate))
class Block(nn.Module):
    """Pre-norm residual block: x += mixer(rms_norm(x)); x += mlp(rms_norm(x)).

    The mixer is GQACPU when `is_attn` is true, otherwise Mamba3CPU.
    """

    def __init__(self, cfg, is_attn):
        super().__init__()
        width = cfg["d_model"]
        self.norm = nn.Parameter(torch.ones(width))
        self.norm2 = nn.Parameter(torch.ones(width))
        mixer_cls = GQACPU if is_attn else Mamba3CPU
        self.mixer = mixer_cls(cfg)
        self.mlp = GatedMLP(cfg)

    def forward(self, x):
        mixed = x + self.mixer(rms_norm(x, self.norm))
        return mixed + self.mlp(rms_norm(mixed, self.norm2))

    def init_state(self, B, **kw):
        """Delegate decode-state creation to the mixer."""
        return self.mixer.init_state(B, **kw)

    def step(self, x, mstate):
        """Single-token counterpart of forward; returns (x, new mixer state)."""
        mixer_out, new_state = self.mixer.step(rms_norm(x, self.norm), mstate)
        mixed = x + mixer_out
        return mixed + self.mlp(rms_norm(mixed, self.norm2)), new_state
class LambaCPU(nn.Module):
    """Hybrid Mamba-3 / GQA language model for CPU inference, with tied embedding and LM head."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embedding = nn.Embedding(cfg["vocab_size"], cfg["d_model"])

        def uses_attention(i):
            # Every attn_every-th layer (1-based) is attention, except the first and last layers.
            last = cfg["n_layers"] - 1
            return (i + 1) % cfg["attn_every"] == 0 and 0 < i < last

        self.layers = nn.ModuleList(Block(cfg, uses_attention(i)) for i in range(cfg["n_layers"]))
        self.norm_f = nn.Parameter(torch.ones(cfg["d_model"]))
        self.lm_head = nn.Linear(cfg["d_model"], cfg["vocab_size"], bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

    def forward(self, ids):
        """ids: (B, L) token ids -> logits (B, L, vocab_size)."""
        hidden = self.embedding(ids)
        for layer in self.layers:
            hidden = layer(hidden)
        return self.lm_head(rms_norm(hidden, self.norm_f))

    def init_states(self, B, **kw):
        """One decode state per layer (kwargs such as device are forwarded to each mixer)."""
        return [layer.init_state(B, **kw) for layer in self.layers]

    @torch.no_grad()
    def step_forward(self, ids_step, states):
        """ids_step: (B, 1) single token -> (logits (B, V), new_states). O(1) Mamba + O(t) GQA attention."""
        hidden = self.embedding(ids_step)[:, 0]  # (B, d)
        next_states = []
        for layer, layer_state in zip(self.layers, states):
            hidden, updated = layer.step(hidden, layer_state)
            next_states.append(updated)
        return self.lm_head(rms_norm(hidden, self.norm_f)), next_states
# ───────────── weight loader (fork ckpt → CPU model) ─────────────
def load_lamba(ckpt_path):
    """Load a fork checkpoint into a LambaCPU model on CPU.

    The checkpoint must be a dict with "cfg" (model config) and "model" (state dict).
    Fork-side norm parameters stored as ``<name>.weight`` are renamed to the bare
    Parameter names used here. All other keys (GQA q/k/v/out_proj, qn, kn, Mamba
    in_proj/out_proj, embedding, lm_head) already match one-to-one. Loading is
    non-strict; the number of missing/unexpected keys is printed, with a preview.

    Returns:
        (model in eval mode, cfg)
    """
    # SECURITY: torch.load unpickles the file (older PyTorch releases do not default to
    # weights_only=True) — only load checkpoints from trusted sources.
    st = torch.load(ckpt_path, map_location="cpu")
    cfg, sd = st["cfg"], st["model"]
    model = LambaCPU(cfg)
    # Applied in this order to every key; the fork keeps these norms as `.weight`
    # params of submodules, while this model stores them as plain Parameters.
    renames = (
        (".mixer.B_norm.weight", ".mixer.B_norm"),
        (".mixer.C_norm.weight", ".mixer.C_norm"),
        ("norm_f.weight", "norm_f"),
        (".norm.weight", ".norm"),
        (".norm2.weight", ".norm2"),
    )
    mp = {}
    for key, value in sd.items():
        for old, new in renames:
            key = key.replace(old, new)
        mp[key] = value
    miss, unexp = model.load_state_dict(mp, strict=False)
    print(f"[load] eksik={len(miss)} beklenmeyen={len(unexp)}", flush=True)
    if miss:
        print(" ilk eksikler:", miss[:8])
    if unexp:
        print(" ilk beklenmeyenler:", unexp[:8])
    model.eval()
    return model, cfg
@torch.no_grad()
def generate(model, sp, prompt, max_new=64, temperature=0.0, top_k=40, top_p=0.9,
             rep_penalty=1.2, device=None):
    """Generate a continuation of `prompt` with the decode cache (step_forward), O(L) overall.

    Signature-compatible with faz7_rag.generate.

    Args:
        model: exposes parameters(), init_states(B, device=...) and
            step_forward(ids (1, 1), states) -> (logits (1, V), states).
        sp: SentencePiece-style tokenizer (encode / decode / eos_id / bos_id).
        prompt: text to condition on. If it encodes to no tokens, generation is
            primed with BOS; ValueError is raised when the tokenizer has no BOS.
        max_new: maximum number of new tokens.
        temperature: <= 0 selects greedy decoding; otherwise logits are divided by it.
        top_k: keep only the k largest logits (a falsy value disables it).
        top_p: nucleus threshold; values >= 1.0 disable it.
        rep_penalty: applied to every token already seen (prompt + output): positive
            logits are divided by it, non-positive ones multiplied by it.
        device: device for ids/states; defaults to the model's first parameter's device.

    Returns:
        The decoded continuation (EOS excluded).
    """
    dev = device or next(model.parameters()).device
    ids = sp.encode(prompt, out_type=int); eos = sp.eos_id()
    if not ids:
        # Without at least one token there are no logits to sample from.
        bos = sp.bos_id()
        if bos < 0:
            raise ValueError("prompt encodes to no tokens and the tokenizer has no BOS id")
        ids = [bos]
    states = model.init_states(1, device=dev)
    logits = None
    for tid in ids:  # prefill: each prompt token goes through the cache exactly once (O(L))
        logits, states = model.step_forward(torch.tensor([[tid]], device=dev), states)
    seen = set(ids)  # tokens subject to the repetition penalty, maintained incrementally
    out = []
    for _ in range(max_new):
        lg = logits[0].float()
        if rep_penalty != 1.0 and seen:
            # Vectorized penalty over all seen tokens (same per-token rule as a scalar loop).
            idx = torch.tensor(sorted(seen), device=lg.device)
            vals = lg[idx]
            lg[idx] = torch.where(vals > 0, vals / rep_penalty, vals * rep_penalty)
        if temperature <= 0:
            nxt = int(lg.argmax())
        else:
            lg = lg / temperature
            if top_k:
                kth = torch.topk(lg, min(top_k, lg.numel())).values[-1]
                lg[lg < kth] = -float("inf")
            probs = F.softmax(lg, -1)
            if top_p < 1.0:
                s, si = torch.sort(probs, descending=True)
                cut = torch.cumsum(s, -1) > top_p
                # Shift right so the token that crosses top_p is kept (and the top token always is).
                cut[1:] = cut[:-1].clone(); cut[0] = False
                s[cut] = 0
                probs = torch.zeros_like(probs).scatter_(0, si, s)
                probs /= probs.sum()
            nxt = int(torch.multinomial(probs, 1))
        if nxt == eos:
            break
        out.append(nxt)
        seen.add(nxt)
        if len(out) == max_new:
            break  # budget exhausted: the next step's logits would never be used
        logits, states = model.step_forward(torch.tensor([[nxt]], device=dev), states)
    return sp.decode(out)
def quantize_int8(model):
    """Return a copy of `model` with every nn.Linear replaced by an int8 dynamic-quantized Linear.

    The author reports roughly 2x smaller memory (708 -> 325 MB), ~1.2x speed-up and
    near-identical greedy output. The SSM recurrence stays fp32 (it is the bottleneck);
    int8 only speeds up the Linear GEMMs.
    """
    # Prefer the torch.ao location and fall back to the legacy torch.quantization module.
    # (getattr(torch.ao.quantization, ...) cannot fall back: it raises AttributeError
    # itself when torch.ao is missing.)
    try:
        from torch.ao.quantization import quantize_dynamic
    except ImportError:
        from torch.quantization import quantize_dynamic
    return quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
def main():
    """CLI entry point: load checkpoint and tokenizer, optionally int8-quantize, answer one query."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--tokenizer", required=True)
    ap.add_argument("--query", default="Merhaba")
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--int8", action="store_true", help="int8 dynamic quant (~2× küçük bellek, ~1.2× hız, ~birebir)")
    args = ap.parse_args()
    # Imported here rather than at module top, so importing this module does not need sentencepiece.
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    model, cfg = load_lamba(args.ckpt)
    # Count before quantizing: dynamic-quantized Linear layers keep their packed int8
    # weights outside model.parameters(), so counting afterwards undercounts.
    n_params = sum(p.numel() for p in model.parameters())
    if args.int8:
        model = quantize_int8(model); print("[int8] dynamic quant uygulandı")
    print(f"[model] {'MIMO' if cfg.get('is_mimo') else 'SISO'} | CPU | {n_params/1e6:.0f}M")
    prompt = f"### Talimat:\n{args.query}\n\n### Yanıt:\n"
    print("CEVAP:", generate(model, sp, prompt, args.max_new))
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()