| """ |
| LAMBA V1.0 — saf-PyTorch CPU inference (Triton/mamba_ssm YOK). [v1.1] |
| |
| Fork (mamba_ssm "norm-free" Mamba-3, Triton kernel) ile EĞİTİLEN ağırlıkları (lamba_v1.pt) |
| GPU'suz çalıştırmak için fork forward matematiğinin saf-PyTorch reimplementasyonu. |
| |
| Mamba-3 mixer matematiği (vendor mamba3.py + siso_step kernelinden çıkarıldı), per-head: |
| _A = -softplus(dd_A) (clamp ≤ -1e-4) # ← data-dependent A (norm-free farkı) |
| DT = softplus(dd_dt + dt_bias) |
| trap = sigmoid(trap_proj) |
| α = exp(_A·DT) # base-e ✅ (kernel exp2(ön-ölçekli adt) ile birebir) |
| β = α·DT·(1-trap) ; γ = trap·DT # trapezoidal |
| h = α·h + β·(x_prev ⊗ B_prev) + γ·(x ⊗ B) # B,x = K,V ; B üzerinde partial-RoPE |
| y = h @ C # C = Q ; C üzerinde partial-RoPE |
| y += D·x ; y *= silu(z) |
| B,C paylaşımlı (ngroups=1) → head'lere broadcast + per-head bias. RoPE: rope_fraction=0.5 |
| (ilk d_state/2=64 boyut, 32 angle, INTERLEAVED [çift (2j,2j+1)]). Token-token recurrence + decode-cache (step). |
| |
| ✅ KALİBRE TAMAM (2026-06-27): fork'a full-logit fp32 maxdiff 0.06–0.09 (top-5/argmax birebir). |
| Kritik düzeltme MLP'deydi: GatedMLP = y·silu(gate) (1.yarı=değer, 2.yarı=gate; mixer değil). |
| Decode-cache (step) full-recompute ile birebir aynı çıktı, ~6× hızlı (O(L²)→O(L)). |
| |
| Kullanım (CPU): python lamba_cpu.py --ckpt checkpoints/lamba_v1.pt --tokenizer tokenizer/tokenizer.model --query "..." |
| """ |
| import os, sys, math, argparse |
| import torch, torch.nn as nn, torch.nn.functional as F |
|
|
# Size PyTorch's CPU thread pool from the core count the OS reports; when the
# count is unavailable (cpu_count() returns None) fall back to 4 threads.
torch.set_num_threads(max(os.cpu_count() or 4, 1))
|
|
|
|
def rms_norm(x, w, eps=1e-5):
    """Scale ``x`` by the reciprocal RMS of its last dimension, then by ``w``.

    ``eps`` is added to the mean square before the reciprocal square root is
    taken, so an all-zero row yields zeros instead of NaN.
    """
    mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return x * inv_rms * w
|
|
|
|
| |
class Mamba3CPU(nn.Module):
    """Pure-PyTorch Mamba-3 mixer, evaluated one token at a time.

    Per head and per token:
        A     = -softplus(dd_A), clamped to <= -A_floor   (data-dependent A)
        DT    = softplus(dd_dt + dt_bias)
        trap  = sigmoid(trap)
        alpha = exp(A * DT)
        beta  = alpha * DT * (1 - trap) ;  gamma = trap * DT
        h     = alpha * h + beta * (x_prev ⊗ B_prev) + gamma * (x ⊗ B)
        y     = (h @ C + D * x) * silu(z)

    B and C are RMS-normalised, broadcast to every head with a per-head bias,
    and rotated by a partial RoPE whose angles accumulate over time.
    ``forward`` (whole sequence) and ``step`` (decode cache) both run the same
    ``_advance`` recurrence, so the two paths cannot drift apart.
    """

    # Class-level switches, read through ``Mamba3CPU.<NAME>``.
    ROPE_INTER = True   # True: rotate interleaved pairs (2j, 2j+1); False: pairs (j, j+n_ang)
    BETA_ALPHA = True   # True: beta carries the alpha factor
    BETA_SHIFT = False  # True: forward() builds beta from the previous position's DT/trap
    EXP_E = True        # True: alpha = exp(A*DT); False: alpha = exp2(A*DT)

    def __init__(self, cfg):
        """Build the projections and per-head parameters from ``cfg``.

        Reads ``d_model``, ``expand``, ``head_dim``, ``d_state`` and the
        optional ``ngroups`` (default 1) and ``rope_fraction`` (default 0.5).

        Raises:
            ValueError: if ``ngroups`` is not 1. The B/C reshapes in
                ``forward``/``step`` assume a single shared group and would
                otherwise fail later with a shape error.
        """
        super().__init__()
        d = cfg["d_model"]
        self.d_inner = cfg["expand"] * d
        self.headdim = cfg["head_dim"]
        self.nheads = self.d_inner // self.headdim
        self.d_state = cfg["d_state"]
        self.ngroups = cfg.get("ngroups", 1)
        if self.ngroups != 1:
            raise ValueError(
                f"Mamba3CPU supports ngroups=1 only (got {self.ngroups}): "
                "B/C are reshaped as a single group shared by all heads")
        self.A_floor = 1e-4
        rope_fraction = cfg.get("rope_fraction", 0.5)
        # Number of leading state dims that get rotated; forced even so they pair up.
        self.rot = int(self.d_state * rope_fraction)
        if self.rot % 2:
            self.rot -= 1
        self.n_ang = self.rot // 2
        bc = self.d_state * self.ngroups
        # in_proj emits, in order: z, x, B, C, dd_dt, dd_A, trap, angles.
        d_in = 2 * self.d_inner + 2 * bc + 3 * self.nheads + self.n_ang
        self.in_proj = nn.Linear(d, d_in, bias=False)
        self.out_proj = nn.Linear(self.d_inner, d, bias=False)
        self.dt_bias = nn.Parameter(torch.zeros(self.nheads))
        self.D = nn.Parameter(torch.ones(self.nheads))
        self.B_bias = nn.Parameter(torch.zeros(self.nheads, 1, self.d_state))
        self.C_bias = nn.Parameter(torch.zeros(self.nheads, 1, self.d_state))
        self.B_norm = nn.Parameter(torch.ones(bc))
        self.C_norm = nn.Parameter(torch.ones(bc))

    def _rope(self, t, cos, sin):
        """Rotate the first ``rot`` entries of the last dim of ``t``; keep the rest.

        Interleaved mode pairs entries (2j, 2j+1); the other mode pairs
        (j, j + n_ang). ``cos``/``sin`` hold one value per rotated pair.
        """
        rot, n = self.rot, self.n_ang
        rest = t[..., rot:]
        if Mamba3CPU.ROPE_INTER:
            head = t[..., :rot]
            even, odd = head[..., 0::2], head[..., 1::2]
            r_even = even * cos - odd * sin
            r_odd = even * sin + odd * cos
            # stack + flatten re-interleaves the rotated pairs.
            rotated = torch.stack([r_even, r_odd], dim=-1).flatten(-2)
            return torch.cat([rotated, rest], dim=-1)
        lo, hi = t[..., :n], t[..., n:rot]
        return torch.cat([lo * cos - hi * sin, lo * sin + hi * cos, rest], dim=-1)

    def _split_in_proj(self, u):
        """Apply ``in_proj`` and split into (z, x, B, C, dd_dt, dd_A, trap, ang)."""
        H = self.nheads
        bc = self.d_state * self.ngroups
        sizes = [self.d_inner, self.d_inner, bc, bc, H, H, H, self.n_ang]
        return torch.split(self.in_proj(u), sizes, dim=-1)

    def _gates(self, dd_dt, dd_A, trap):
        """Return fp32 ``(alpha, DT, trap)`` from the raw per-head projections."""
        A = (-F.softplus(dd_A.float())).clamp(max=-self.A_floor)
        DT = F.softplus(dd_dt.float() + self.dt_bias)
        trap = torch.sigmoid(trap.float())
        alpha = torch.exp(A * DT) if Mamba3CPU.EXP_E else torch.exp2(A * DT)
        return alpha, DT, trap

    def _advance(self, state, x, z, Bm, Cm, ang, DT, alpha, beta, gamma):
        """Run the recurrence for one token.

        Args:
            state: ``[h, x_prev, Bk_prev, cum]`` as produced by ``init_state``.
            x, z: (B, H, P) value and gate for this token.
            Bm, Cm: (B, H, S) normalised, biased B and C before RoPE.
            ang: (B, n_ang) raw angle projection.
            DT, alpha, beta, gamma: (B, H) per-head scalars.

        Returns:
            ``(y, new_state)`` with ``y`` of shape (B, H, P).
        """
        h, x_prev, Bk_prev, cum = state
        H = self.nheads
        # RoPE angles accumulate: each token adds tanh(ang) * DT * pi per head.
        cum = cum + torch.tanh(ang.float()).unsqueeze(1) * DT.unsqueeze(-1) * math.pi
        cos, sin = torch.cos(cum), torch.sin(cum)
        Bk = self._rope(Bm, cos, sin)
        Cq = self._rope(Cm, cos, sin)
        # Trapezoidal input: previous outer product weighted by beta, current by gamma.
        diff = (beta[:, :, None, None] * x_prev.unsqueeze(-1) * Bk_prev.unsqueeze(-2)
                + gamma[:, :, None, None] * x.unsqueeze(-1) * Bk.unsqueeze(-2))
        h = h * alpha[:, :, None, None] + diff
        y = (h * Cq.unsqueeze(-2)).sum(-1) + self.D.view(1, H, 1) * x
        y = y * F.silu(z)
        return y, [h, x, Bk, cum]

    def forward(self, u):
        """Map ``u`` (B, L, d_model) to (B, L, d_model) via a token-by-token recurrence.

        A zero-length sequence returns an empty (B, 0, d_model) tensor.
        """
        B, L, _ = u.shape
        H, P, S = self.nheads, self.headdim, self.d_state
        z, x, Bm, Cm, dd_dt, dd_A, trap, ang = self._split_in_proj(u)
        x = x.view(B, L, H, P)
        z = z.view(B, L, H, P)
        alpha, DT, trap = self._gates(dd_dt, dd_A, trap)
        # B/C are shared by all heads: broadcast, then add the per-head bias.
        Bm = rms_norm(Bm.float(), self.B_norm).view(B, L, 1, S) + self.B_bias.view(1, 1, H, S)
        Cm = rms_norm(Cm.float(), self.C_norm).view(B, L, 1, S) + self.C_bias.view(1, 1, H, S)
        if Mamba3CPU.BETA_SHIFT:
            # roll wraps the last position to t=0; that term multiplies the
            # all-zero initial x_prev/Bk_prev, so it contributes nothing.
            beta_DT, beta_trap = torch.roll(DT, 1, dims=1), torch.roll(trap, 1, dims=1)
        else:
            beta_DT, beta_trap = DT, trap
        beta = (alpha if Mamba3CPU.BETA_ALPHA else 1.0) * beta_DT * (1 - beta_trap)
        gamma = trap * DT

        state = [
            torch.zeros(B, H, P, S, device=u.device),
            torch.zeros(B, H, P, device=u.device),
            torch.zeros(B, H, S, device=u.device),
            torch.zeros(B, H, self.n_ang, device=u.device),
        ]
        ys = []
        for t in range(L):
            y, state = self._advance(state, x[:, t], z[:, t], Bm[:, t], Cm[:, t], ang[:, t],
                                     DT[:, t], alpha[:, t], beta[:, t], gamma[:, t])
            ys.append(y.reshape(B, 1, H * P))
        if not ys:
            # torch.cat rejects an empty list; produce an empty output instead.
            return self.out_proj(u.new_zeros(B, 0, H * P))
        return self.out_proj(torch.cat(ys, dim=1))

    def init_state(self, B, device=None, dtype=torch.float32):
        """Return the zeroed decode cache ``[h, x_prev, Bk_prev, cum]`` for batch size ``B``."""
        H, P, S = self.nheads, self.headdim, self.d_state
        shapes = [(B, H, P, S), (B, H, P), (B, H, S), (B, H, self.n_ang)]
        return [torch.zeros(*shape, device=device, dtype=dtype) for shape in shapes]

    def step(self, u, state):
        """Advance the decode cache by one token.

        Args:
            u: (B, d_model) input for the new token.
            state: ``[h, x_prev, Bk_prev, cum]`` from ``init_state`` or a previous call.

        Returns:
            ``(y, new_state)`` with ``y`` of shape (B, d_model).

        NOTE: ``BETA_SHIFT`` is not modelled here; beta always uses the
        current token's DT/trap, which matches ``forward`` only while
        ``BETA_SHIFT`` is False.
        """
        B = u.shape[0]
        H, P, S = self.nheads, self.headdim, self.d_state
        z, x, Bm, Cm, dd_dt, dd_A, trap, ang = self._split_in_proj(u)
        x = x.view(B, H, P)
        z = z.view(B, H, P)
        alpha, DT, trap = self._gates(dd_dt, dd_A, trap)
        Bm = rms_norm(Bm.float(), self.B_norm).view(B, 1, S) + self.B_bias.view(1, H, S)
        Cm = rms_norm(Cm.float(), self.C_norm).view(B, 1, S) + self.C_bias.view(1, H, S)
        beta = (alpha if Mamba3CPU.BETA_ALPHA else 1.0) * DT * (1 - trap)
        gamma = trap * DT
        y, new_state = self._advance(state, x, z, Bm, Cm, ang, DT, alpha, beta, gamma)
        return self.out_proj(y.reshape(B, H * P)), new_state
|
|
|
|
| |
| def _rot_half(x): |
| a, b = x.chunk(2, -1) |
| return torch.cat((-b, a), -1) |
|
|
|
|
class GQACPU(nn.Module):
    """Grouped-query causal self-attention with RMS-normed q/k and rotary embeddings.

    ``forward`` processes a whole sequence; ``step`` processes one token
    against a growing key/value cache.
    """

    def __init__(self, cfg, base=10000.0):
        """Read ``d_model``, ``n_heads`` and ``n_kv_heads`` from ``cfg``; ``base`` is the RoPE base."""
        super().__init__()
        d_model = cfg["d_model"]
        self.nh = cfg["n_heads"]
        self.nkv = cfg["n_kv_heads"]
        self.hd = d_model // self.nh
        # Each key/value head serves ``rep`` query heads.
        self.rep = self.nh // self.nkv
        q_width = self.nh * self.hd
        kv_width = self.nkv * self.hd
        self.q_proj = nn.Linear(d_model, q_width, bias=False)
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.out_proj = nn.Linear(q_width, d_model, bias=False)
        self.qn = nn.Parameter(torch.ones(self.hd))
        self.kn = nn.Parameter(torch.ones(self.hd))
        exponents = torch.arange(0, self.hd, 2).float() / self.hd
        # Non-persistent: recomputed at construction, never read from a checkpoint.
        self.register_buffer("inv", 1.0 / (base ** exponents), persistent=False)

    @staticmethod
    def _rotate(x, freqs):
        """Apply the rotary embedding for angle table ``freqs`` (..., hd/2) to ``x``."""
        table = torch.cat((freqs, freqs), -1)
        rotated = x * table.cos()[None, None] + _rot_half(x) * table.sin()[None, None]
        return rotated.to(x.dtype)

    def _qkv(self, x):
        """Project ``x`` (B, T, d_model) to per-head q, k, v of shape (B, heads, T, hd); q/k are RMS-normed."""
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        q = rms_norm(q.float(), self.qn.float()).to(x.dtype)
        k = rms_norm(k.float(), self.kn.float()).to(x.dtype)
        return q, k, v

    def _rope(self, x, T):
        """Rotate ``x`` (B, heads, T, hd) using positions ``0..T-1``."""
        positions = torch.arange(T, device=x.device).float()
        return self._rotate(x, torch.outer(positions, self.inv))

    def forward(self, x):
        """Causal attention over ``x`` (B, T, d_model); returns (B, T, d_model)."""
        B, T, _ = x.shape
        q, k, v = self._qkv(x)
        q = self._rope(q, T)
        k = self._rope(k, T)
        # Expand the kv heads so every query head has a matching key/value head.
        k = k.repeat_interleave(self.rep, 1)
        v = v.repeat_interleave(self.rep, 1)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn.transpose(1, 2).reshape(B, T, -1))

    def init_state(self, B, device=None, dtype=torch.float32):
        """Return an empty ``[k_cache, v_cache]``; the caches are created on the first ``step``."""
        return [None, None]

    def _rope_at(self, x, pos):
        """Rotate a single-position ``x`` (B, heads, 1, hd) for absolute position ``pos``."""
        return self._rotate(x, (self.inv * float(pos)).unsqueeze(0))

    def step(self, x, state):
        """Attend one new token against the cache.

        Args:
            x: (B, d_model) input for the new token.
            state: ``[k_cache, v_cache]``; both ``None`` before the first token.

        Returns:
            ``(y, [k_cache, v_cache])`` with ``y`` of shape (B, d_model).
        """
        B = x.shape[0]
        k_cache, v_cache = state
        # The cache's time axis is dim 2, so its length is this token's position.
        pos = 0 if k_cache is None else k_cache.shape[2]
        q, k, v = self._qkv(x.view(B, 1, -1))
        q = self._rope_at(q, pos)
        k = self._rope_at(k, pos)
        if k_cache is not None:
            k = torch.cat([k_cache, k], dim=2)
        if v_cache is not None:
            v = torch.cat([v_cache, v], dim=2)
        keys = k.repeat_interleave(self.rep, 1)
        values = v.repeat_interleave(self.rep, 1)
        # No causal mask: the single query may see every cached position.
        attn = F.scaled_dot_product_attention(q, keys, values, is_causal=False)
        return self.out_proj(attn.transpose(1, 2).reshape(B, -1)), [k, v]
|
|
|
|
class GatedMLP(nn.Module):
    """Gated feed-forward block: ``fc2(value * silu(gate))``.

    ``fc1`` produces both halves at once; the first half is the value and the
    second half is the gate.
    """

    def __init__(self, cfg):
        """Read ``d_model`` and ``d_intermediate``; the hidden width is rounded up to a multiple of 128."""
        super().__init__()
        d_model = cfg["d_model"]
        multiple = 128
        full_blocks, remainder = divmod(cfg["d_intermediate"], multiple)
        hidden = (full_blocks + (1 if remainder else 0)) * multiple
        self.fc1 = nn.Linear(d_model, 2 * hidden, bias=False)
        self.fc2 = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        """Apply the gated MLP along the last dimension of ``x``."""
        value, gate = torch.chunk(self.fc1(x), 2, dim=-1)
        return self.fc2(value * F.silu(gate))
|
|
|
|
class Block(nn.Module):
    """Pre-norm residual layer: a mixer (attention or Mamba-3) followed by a gated MLP."""

    def __init__(self, cfg, is_attn):
        """Create the two RMS-norm weights, the mixer chosen by ``is_attn`` and the MLP."""
        super().__init__()
        d_model = cfg["d_model"]
        self.norm = nn.Parameter(torch.ones(d_model))
        self.norm2 = nn.Parameter(torch.ones(d_model))
        if is_attn:
            self.mixer = GQACPU(cfg)
        else:
            self.mixer = Mamba3CPU(cfg)
        self.mlp = GatedMLP(cfg)

    def _mlp_residual(self, x):
        """Add the MLP branch, computed on the ``norm2``-normalised input, to ``x``."""
        return x + self.mlp(rms_norm(x, self.norm2))

    def forward(self, x):
        """Full-sequence pass: mixer residual, then MLP residual."""
        mixed = x + self.mixer(rms_norm(x, self.norm))
        return self._mlp_residual(mixed)

    def init_state(self, B, **kw):
        """Return the mixer's fresh decode cache for batch size ``B``."""
        return self.mixer.init_state(B, **kw)

    def step(self, x, mstate):
        """Single-token pass; returns the new hidden vector and the updated mixer cache."""
        mixer_out, mstate = self.mixer.step(rms_norm(x, self.norm), mstate)
        return self._mlp_residual(x + mixer_out), mstate
|
|
|
|
class LambaCPU(nn.Module):
    """Token embedding, a stack of ``Block`` layers, a final RMS norm and a tied LM head."""

    def __init__(self, cfg):
        """Build the model from ``cfg`` (``vocab_size``, ``d_model``, ``n_layers``, ``attn_every``, ...)."""
        super().__init__()
        self.cfg = cfg
        d_model, vocab = cfg["d_model"], cfg["vocab_size"]
        n_layers = cfg["n_layers"]
        self.embedding = nn.Embedding(vocab, d_model)
        self.layers = nn.ModuleList()
        last = n_layers - 1
        for idx in range(n_layers):
            # Every attn_every-th layer is attention, except the first and the last layer.
            use_attn = (idx + 1) % cfg["attn_every"] == 0 and 0 < idx < last
            self.layers.append(Block(cfg, use_attn))
        self.norm_f = nn.Parameter(torch.ones(d_model))
        self.lm_head = nn.Linear(d_model, vocab, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embedding.weight

    def _logits(self, h):
        """Apply the final RMS norm and the LM head."""
        return self.lm_head(rms_norm(h, self.norm_f))

    def forward(self, ids):
        """Full-sequence pass: token ids in, logits for every position out."""
        hidden = self.embedding(ids)
        for layer in self.layers:
            hidden = layer(hidden)
        return self._logits(hidden)

    def init_states(self, B, **kw):
        """Return one fresh decode cache per layer for batch size ``B``."""
        return [layer.init_state(B, **kw) for layer in self.layers]

    @torch.no_grad()
    def step_forward(self, ids_step, states):
        """Decode one token.

        Args:
            ids_step: (B, 1) ids of the new token.
            states: per-layer caches from ``init_states`` or a previous call.

        Returns:
            ``(logits, new_states)`` with ``logits`` of shape (B, vocab).
        """
        hidden = self.embedding(ids_step)[:, 0]
        updated = []
        for layer, layer_state in zip(self.layers, states):
            hidden, layer_state = layer.step(hidden, layer_state)
            updated.append(layer_state)
        return self._logits(hidden), updated
|
|
|
|
| |
def load_lamba(ckpt_path):
    """Load a checkpoint into a ``LambaCPU`` model in eval mode.

    The checkpoint must be a mapping with ``"cfg"`` (model config) and
    ``"model"`` (state dict). Norm weights stored as ``<name>.weight`` are
    renamed to the bare parameter names this file uses. Loading is
    non-strict: missing and unexpected keys are counted and the first eight
    of each are printed, but they do not raise.

    Returns:
        ``(model, cfg)``.
    """
    # NOTE(security): torch.load is pickle-based and, depending on the
    # installed PyTorch version/defaults, may execute code embedded in the
    # file. Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    cfg, source_sd = ckpt["cfg"], ckpt["model"]
    model = LambaCPU(cfg)

    # Applied in this order to every key (substring replacement).
    renames = (
        (".mixer.B_norm.weight", ".mixer.B_norm"),
        (".mixer.C_norm.weight", ".mixer.C_norm"),
        ("norm_f.weight", "norm_f"),
        (".norm.weight", ".norm"),
        (".norm2.weight", ".norm2"),
    )
    remapped = {}
    for key, tensor in source_sd.items():
        for old, new in renames:
            key = key.replace(old, new)
        remapped[key] = tensor

    miss, unexp = model.load_state_dict(remapped, strict=False)
    print(f"[load] eksik={len(miss)} beklenmeyen={len(unexp)}", flush=True)
    if miss:
        print("  ilk eksikler:", miss[:8])
    if unexp:
        print("  ilk beklenmeyenler:", unexp[:8])
    model.eval()
    return model, cfg
|
|
|
|
@torch.no_grad()
def generate(model, sp, prompt, max_new=64, temperature=0.0, top_k=40, top_p=0.9,
             rep_penalty=1.2, device=None):
    """Generate up to ``max_new`` tokens after ``prompt`` using the decode cache.

    The prompt is fed token by token through ``model.step_forward``; each new
    token then costs one more step. Per the original note, the signature is
    kept compatible with ``faz7_rag.generate``.

    Args:
        model: object exposing ``parameters()``, ``init_states(B, device=...)``
            and ``step_forward(ids, states)``.
        sp: tokenizer exposing ``encode(text, out_type=int)``, ``eos_id()``
            and ``decode(ids)``.
        prompt: text to condition on.
        max_new: maximum number of tokens to generate.
        temperature: ``<= 0`` selects greedy argmax; otherwise logits are
            divided by it before sampling.
        top_k: keep only the ``top_k`` largest logits when sampling (falsy disables).
        top_p: nucleus threshold when sampling (``>= 1.0`` disables).
        rep_penalty: tokens already in the prompt or output have positive
            logits divided by it and non-positive logits multiplied by it
            (``1.0`` disables).
        device: device for the token tensors; defaults to the model's.

    Returns:
        ``sp.decode`` of the generated ids (the prompt and EOS are excluded).

    Raises:
        ValueError: if the prompt encodes to zero tokens while ``max_new > 0``
            (there would be no logits to sample the first token from).
    """
    dev = device or next(model.parameters()).device
    ids = sp.encode(prompt, out_type=int)
    eos = sp.eos_id()
    if not ids and max_new > 0:
        raise ValueError("prompt encoded to zero tokens; nothing to condition generation on")

    states = model.init_states(1, device=dev)
    logits = None
    for tid in ids:
        logits, states = model.step_forward(torch.tensor([[tid]], device=dev), states)

    seen = set(ids)  # every token id that is subject to the repetition penalty
    out = []
    for _ in range(max_new):
        lg = logits[0].float()
        if rep_penalty != 1.0:
            idx = torch.tensor(sorted(seen), dtype=torch.long, device=lg.device)
            vals = lg[idx]
            lg[idx] = torch.where(vals > 0, vals / rep_penalty, vals * rep_penalty)

        if temperature <= 0:
            nxt = int(lg.argmax())
        else:
            lg = lg / temperature
            if top_k:
                kth = torch.topk(lg, min(top_k, lg.numel())).values[-1]
                lg[lg < kth] = -float("inf")
            probs = F.softmax(lg, -1)
            if top_p < 1.0:
                sorted_p, order = torch.sort(probs, descending=True)
                cut = torch.cumsum(sorted_p, -1) > top_p
                # Shift right so the token that crosses the threshold is kept.
                cut[1:] = cut[:-1].clone()
                cut[0] = False
                sorted_p[cut] = 0
                probs = torch.zeros_like(probs).scatter_(0, order, sorted_p)
                probs /= probs.sum()
            nxt = int(torch.multinomial(probs, 1))

        if nxt == eos:
            break
        out.append(nxt)
        seen.add(nxt)
        # The logits after the final token are never used, so skip that step.
        if len(out) < max_new:
            logits, states = model.step_forward(torch.tensor([[nxt]], device=dev), states)
    return sp.decode(out)
|
|
|
|
def quantize_int8(model):
    """Return a copy of ``model`` with every ``nn.Linear`` dynamically quantised to int8.

    Only ``nn.Linear`` modules are replaced; everything else (including the
    SSM recurrence) keeps running in floating point. The original notes
    report roughly 2x smaller memory (708 -> 325 MB), about 1.2x speed and
    near-identical greedy output for this model.

    ``torch.ao.quantization.quantize_dynamic`` is preferred; builds without
    ``torch.ao`` fall back to ``torch.quantization.quantize_dynamic``.
    """
    # Look the entry point up defensively: touching ``torch.ao`` directly
    # raises AttributeError on builds that lack it, which would make the
    # fallback unreachable.
    ao_quant = getattr(getattr(torch, "ao", None), "quantization", None)
    quantize_dynamic = getattr(ao_quant, "quantize_dynamic", None)
    if quantize_dynamic is None:
        quantize_dynamic = torch.quantization.quantize_dynamic
    return quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
|
|
|
|
def main():
    """CLI entry point: load the checkpoint and tokenizer, then answer ``--query``.

    The query is wrapped in the instruction/response prompt template and the
    generated answer is printed. ``--int8`` applies dynamic int8 quantisation
    before generating.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--tokenizer", required=True)
    ap.add_argument("--query", default="Merhaba")
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--int8", action="store_true", help="int8 dynamic quant (~2× küçük bellek, ~1.2× hız, ~birebir)")
    args = ap.parse_args()

    # Imported here so the module can be imported without sentencepiece installed.
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    model, cfg = load_lamba(args.ckpt)

    # Count before quantising: dynamically quantised Linear layers no longer
    # expose their weights through model.parameters(), which would make the
    # reported size far too small.
    n_params = sum(p.numel() for p in model.parameters())
    if args.int8:
        model = quantize_int8(model)
        print("[int8] dynamic quant uygulandı")
    print(f"[model] {'MIMO' if cfg.get('is_mimo') else 'SISO'} | CPU | {n_params/1e6:.0f}M")

    prompt = f"### Talimat:\n{args.query}\n\n### Yanıt:\n"
    print("CEVAP:", generate(model, sp, prompt, args.max_new))


if __name__ == "__main__":
    main()
|
|