"""LAMBA V1.0 — pure-PyTorch CPU inference (no Triton / mamba_ssm).

[v1.1] A pure-PyTorch re-implementation of the fork's forward math.
It runs weights trained with the fork (mamba_ssm "norm-free" Mamba-3,
Triton kernel; ``lamba_v1.pt``) without a GPU.

Mamba-3 mixer math (derived from the vendored mamba3.py and the siso_step
kernel), per head:

    _A    = -softplus(dd_A)             (clamped to <= -1e-4)  # data-dependent A ("norm-free" difference)
    DT    = softplus(dd_dt + dt_bias)
    trap  = sigmoid(trap_proj)
    alpha = exp(_A * DT)                # base e; identical to the kernel's exp2 of a pre-scaled adt
    beta  = alpha * DT * (1 - trap);  gamma = trap * DT        # trapezoidal rule
    h     = alpha * h + beta * (x_prev ⊗ B_prev) + gamma * (x ⊗ B)  # B, x = K, V; partial RoPE on B
    y     = h @ C                       # C = Q; partial RoPE on C
    y    += D * x;  y *= silu(z)

B and C are shared (ngroups=1). They are broadcast to every head, and a
per-head bias is added.

RoPE settings: rope_fraction=0.5, i.e. the first d_state/2 = 64 dims,
giving 32 angles. Pairs are INTERLEAVED as (2j, 2j+1).

The mixer uses a token-by-token recurrence plus a decode cache (``step``).

Calibration completed (2026-06-27): full-logit fp32 max-diff against the fork
is 0.06–0.09, with top-5 and argmax identical. The critical fix was in the
MLP, not the mixer: GatedMLP = y * silu(gate), where the first half is the
value and the second half is the gate.

The decode cache (``step``) produces the same output as full recomputation
and is about 6x faster (O(L^2) -> O(L)).

Usage (CPU):
    python lamba_cpu.py --ckpt checkpoints/lamba_v1.pt --tokenizer tokenizer/tokenizer.model --query "..."
"""
import argparse
import math
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_num_threads(max(1, os.cpu_count() or 4))


def rms_norm(x, w, eps=1e-5):
    """RMS-normalize ``x`` over its last dimension, then scale by the weight ``w``."""
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w


# ───────────── Mamba-3 mixer (pure PyTorch, fork math) ─────────────
class Mamba3CPU(nn.Module):
    """Pure-PyTorch Mamba-3 (SISO) mixer that reproduces the fork's forward math.

    ``forward`` runs the token-by-token recurrence over a whole sequence.
    ``init_state`` and ``step`` run one iteration of the same recurrence for
    cached decoding.

    NOTE(review): the B/C broadcast uses ``view(..., 1, S)``, so only
    ``ngroups == 1`` is supported.
    """

    # Calibration flags, determined by a Colab grid search against the fork.
    ROPE_INTER = True   # fork rotate_pairwise=True => interleaved pairs (q[0::2], q[1::2])
    BETA_ALPHA = True   # CALIBRATED: include the alpha factor in beta (grid: True is better)
    BETA_SHIFT = False  # CALIBRATED: use shifted dt/trap in beta? (grid: False is better)
    EXP_E = True        # RESOLVED: alpha = exp(_A*DT), base e, not exp2 (mixer maxdiff 0.14 in bf16)

    def __init__(self, cfg):
        super().__init__()
        d = cfg["d_model"]
        self.d_inner = cfg["expand"] * d              # 1536
        self.headdim = cfg["head_dim"]                # 64
        self.nheads = self.d_inner // self.headdim    # 24
        self.d_state = cfg["d_state"]                 # 128
        self.ngroups = cfg.get("ngroups", 1)          # num_bc_heads = 1
        self.A_floor = 1e-4
        rope_fraction = cfg.get("rope_fraction", 0.5)
        self.rot = int(self.d_state * rope_fraction)  # 64 -> number of rotated dims
        if self.rot % 2:
            self.rot -= 1                             # rotation works on pairs, so keep it even
        self.n_ang = self.rot // 2                    # 32 angle pairs
        bc = self.d_state * self.ngroups              # 128 (SISO, rank=1)
        # in_proj packs z, x, B, C, dd_dt, dd_A, trap and the RoPE angles into one projection.
        d_in = 2 * self.d_inner + 2 * bc + 3 * self.nheads + self.n_ang  # 3432
        self.in_proj = nn.Linear(d, d_in, bias=False)
        self.out_proj = nn.Linear(self.d_inner, d, bias=False)
        self.dt_bias = nn.Parameter(torch.zeros(self.nheads))
        self.D = nn.Parameter(torch.ones(self.nheads))
        self.B_bias = nn.Parameter(torch.zeros(self.nheads, 1, self.d_state))
        self.C_bias = nn.Parameter(torch.zeros(self.nheads, 1, self.d_state))
        self.B_norm = nn.Parameter(torch.ones(bc))
        self.C_norm = nn.Parameter(torch.ones(bc))

    def _rope(self, t, cos, sin):
        """Apply partial RoPE to the first ``rot`` (= d_state/2) dims; the rest pass through.

        The pair layout is interleaved (2j, 2j+1) when ``ROPE_INTER`` is set.
        Otherwise the two halves (j, j+n) are rotated together.
        """
        rot, n = self.rot, self.n_ang
        rest = t[..., rot:]
        if Mamba3CPU.ROPE_INTER:
            head = t[..., :rot]
            x1, x2 = head[..., 0::2], head[..., 1::2]
            ra, rb = x1 * cos - x2 * sin, x1 * sin + x2 * cos
            # Re-interleave the rotated pairs back into (2j, 2j+1) order.
            out = torch.stack([ra, rb], dim=-1).flatten(-2)
            return torch.cat([out, rest], dim=-1)
        a, b = t[..., :n], t[..., n:rot]
        ra, rb = a * cos - b * sin, a * sin + b * cos
        return torch.cat([ra, rb, rest], dim=-1)

    def forward(self, u):
        """Run the mixer over a full sequence with a token-by-token recurrence.

        Args:
            u: (B, L, d_model) input.

        Returns:
            (B, L, d_model) output.
        """
        B, L, _ = u.shape
        H, P, S = self.nheads, self.headdim, self.d_state
        proj = self.in_proj(u)
        z, x, Bm, Cm, dd_dt, dd_A, trap, ang = torch.split(
            proj,
            [self.d_inner, self.d_inner, S * self.ngroups, S * self.ngroups, H, H, H, self.n_ang],
            dim=-1,
        )
        x = x.view(B, L, H, P)
        z = z.view(B, L, H, P)
        _A = (-F.softplus(dd_A.float())).clamp(max=-self.A_floor)  # (B,L,H)
        DT = F.softplus(dd_dt.float() + self.dt_bias)              # (B,L,H)
        trap = torch.sigmoid(trap.float())                         # (B,L,H)
        Bm = rms_norm(Bm.float(), self.B_norm)                     # (B,L,S), ngroups=1
        Cm = rms_norm(Cm.float(), self.C_norm)
        # Broadcast the shared B/C to every head and add the per-head bias.
        Bm = Bm.view(B, L, 1, S) + self.B_bias.view(1, 1, H, S)
        Cm = Cm.view(B, L, 1, S) + self.C_bias.view(1, 1, H, S)
        alpha = torch.exp(_A * DT) if Mamba3CPU.EXP_E else torch.exp2(_A * DT)  # base-e vs base-2 decay
        bDT = torch.roll(DT, 1, dims=1) if Mamba3CPU.BETA_SHIFT else DT  # shifted dt/trap in beta?
        btrap = torch.roll(trap, 1, dims=1) if Mamba3CPU.BETA_SHIFT else trap
        beta = (alpha if Mamba3CPU.BETA_ALPHA else 1.0) * bDT * (1 - btrap)
        gamma = trap * DT
        h = torch.zeros(B, H, P, S, device=u.device)  # ssm_state (V, QK) = (P, S)
        x_prev = torch.zeros(B, H, P, device=u.device)
        Bk_prev = torch.zeros(B, H, S, device=u.device)
        cum = torch.zeros(B, H, self.n_ang, device=u.device)
        ys = []
        for t in range(L):
            # Fork rule: angle = angle_state + tanh(angle_proj) * DT * pi
            # (reference: mamba3_mimo_rotary_step).
            inc = torch.tanh(ang[:, t].float()).unsqueeze(1) * DT[:, t].unsqueeze(-1) * math.pi
            cum = cum + inc  # (B,H,n_ang)
            cos, sin = torch.cos(cum), torch.sin(cum)
            Bk = self._rope(Bm[:, t], cos, sin)  # (B,H,S)
            Cq = self._rope(Cm[:, t], cos, sin)
            xt = x[:, t]                         # (B,H,P)
            a = alpha[:, t].view(B, H, 1, 1)
            # Trapezoidal update: previous-token term (beta) + current-token term (gamma).
            diff = (beta[:, t].view(B, H, 1, 1) * x_prev.unsqueeze(-1) * Bk_prev.unsqueeze(-2)
                    + gamma[:, t].view(B, H, 1, 1) * xt.unsqueeze(-1) * Bk.unsqueeze(-2))
            h = h * a + diff                     # (B,H,P,S)
            y = (h * Cq.unsqueeze(-2)).sum(-1)   # (B,H,P)
            y = y + self.D.view(1, H, 1) * xt
            y = y * F.silu(z[:, t])
            ys.append(y.reshape(B, 1, H * P))
            x_prev, Bk_prev = xt, Bk
        # fp32; out_proj is called as a module (its .weight is never touched),
        # so int8 dynamic quant keeps working.
        return self.out_proj(torch.cat(ys, dim=1))

    # ───── decode cache (single-token step; one iteration of forward, O(1)) ─────
    def init_state(self, B, device=None, dtype=torch.float32):
        """Return a zero decode state ``[h, x_prev, Bk_prev, cum]`` for batch size ``B``."""
        H, P, S = self.nheads, self.headdim, self.d_state
        z = lambda *s: torch.zeros(*s, device=device, dtype=dtype)
        return [z(B, H, P, S), z(B, H, P), z(B, H, S), z(B, H, self.n_ang)]  # h, x_prev, Bk_prev, cum

    def step(self, u, state):
        """Advance the recurrence by one token (one iteration of ``forward``'s loop).

        Args:
            u: (B, d_model) input for the current token.
            state: ``[h, x_prev, Bk_prev, cum]`` from ``init_state`` or a previous ``step``.

        Returns:
            ``(y, new_state)``, where ``y`` has shape (B, d_model).
        """
        B = u.shape[0]
        H, P, S = self.nheads, self.headdim, self.d_state
        h, x_prev, Bk_prev, cum = state
        z, x, Bm, Cm, dd_dt, dd_A, trap, ang = torch.split(
            self.in_proj(u),
            [self.d_inner, self.d_inner, S * self.ngroups, S * self.ngroups, H, H, H, self.n_ang],
            dim=-1,
        )
        x = x.view(B, H, P)
        z = z.view(B, H, P)
        _A = (-F.softplus(dd_A.float())).clamp(max=-self.A_floor)
        DT = F.softplus(dd_dt.float() + self.dt_bias)
        trap = torch.sigmoid(trap.float())
        Bm = rms_norm(Bm.float(), self.B_norm).view(B, 1, S) + self.B_bias.view(1, H, S)
        Cm = rms_norm(Cm.float(), self.C_norm).view(B, 1, S) + self.C_bias.view(1, H, S)
        alpha = torch.exp(_A * DT) if Mamba3CPU.EXP_E else torch.exp2(_A * DT)
        # beta uses the current token's dt/trap. This matches forward() only while
        # BETA_SHIFT=False, since the shifted variant needs the previous token's values.
        beta = (alpha if Mamba3CPU.BETA_ALPHA else 1.0) * DT * (1 - trap)
        gamma = trap * DT
        cum = cum + torch.tanh(ang.float()).unsqueeze(1) * DT.unsqueeze(-1) * math.pi
        cos, sin = torch.cos(cum), torch.sin(cum)
        Bk = self._rope(Bm, cos, sin)
        Cq = self._rope(Cm, cos, sin)
        diff = (beta.view(B, H, 1, 1) * x_prev.unsqueeze(-1) * Bk_prev.unsqueeze(-2)
                + gamma.view(B, H, 1, 1) * x.unsqueeze(-1) * Bk.unsqueeze(-2))
        h = h * alpha.view(B, H, 1, 1) + diff
        y = (h * Cq.unsqueeze(-2)).sum(-1) + self.D.view(1, H, 1) * x
        y = (y * F.silu(z)).reshape(B, H * P)
        return self.out_proj(y), [h, x, Bk, cum]  # fp32 (int8-quant compatible)


# ───────────── GQA mixer (pure PyTorch; same as hybrid_mamba3) ─────────────
def _rot_half(x):
    """Rotate-half helper for RoPE: (a, b) -> (-b, a) over the last dimension."""
    a, b = x.chunk(2, -1)
    return torch.cat((-b, a), -1)


class GQACPU(nn.Module):
    """Grouped-query attention with QK RMS-norm and rotate-half RoPE."""

    def __init__(self, cfg, base=10000.0):
        super().__init__()
        d = cfg["d_model"]
        self.nh = cfg["n_heads"]
        self.nkv = cfg["n_kv_heads"]
        self.hd = d // self.nh
        self.rep = self.nh // self.nkv  # query heads per KV head
        self.q_proj = nn.Linear(d, self.nh * self.hd, bias=False)
        self.k_proj = nn.Linear(d, self.nkv * self.hd, bias=False)
        self.v_proj = nn.Linear(d, self.nkv * self.hd, bias=False)
        self.out_proj = nn.Linear(self.nh * self.hd, d, bias=False)
        self.qn = nn.Parameter(torch.ones(self.hd))
        self.kn = nn.Parameter(torch.ones(self.hd))
        self.register_buffer(
            "inv",
            1.0 / (base ** (torch.arange(0, self.hd, 2).float() / self.hd)),
            persistent=False,
        )

    def _rope(self, x, T):
        """Apply RoPE for positions 0..T-1 to x of shape (B, heads, T, hd)."""
        f = torch.outer(torch.arange(T, device=x.device).float(), self.inv)
        e = torch.cat((f, f), -1)
        return (x * e.cos()[None, None] + _rot_half(x) * e.sin()[None, None]).to(x.dtype)

    def forward(self, x):
        """Causal self-attention over a full sequence: (B, T, d_model) -> (B, T, d_model)."""
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        q = rms_norm(q.float(), self.qn.float()).to(x.dtype)
        k = rms_norm(k.float(), self.kn.float()).to(x.dtype)
        q, k = self._rope(q, T), self._rope(k, T)
        # Expand the KV heads to match the query heads.
        k = k.repeat_interleave(self.rep, 1)
        v = v.repeat_interleave(self.rep, 1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(y.transpose(1, 2).reshape(B, T, -1))

    # ───── decode cache (KV cache; a single query vs. the whole past = causal) ─────
    def init_state(self, B, device=None, dtype=torch.float32):
        """Return an empty KV cache ``[k_cache, v_cache]``; each later has shape (B, nkv, t, hd)."""
        return [None, None]

    def _rope_at(self, x, pos):
        """Apply RoPE for the single position ``pos``."""
        f = (self.inv * float(pos)).unsqueeze(0)  # (1, hd/2)
        e = torch.cat((f, f), -1)
        return (x * e.cos()[None, None] + _rot_half(x) * e.sin()[None, None]).to(x.dtype)

    def step(self, x, state):
        """Attend one new token to the cached past.

        Args:
            x: (B, d_model) input for the current token.
            state: ``[k_cache, v_cache]``.

        Returns:
            ``(y, new_state)``, where ``y`` has shape (B, d_model).
        """
        B = x.shape[0]
        kc, vc = state
        pos = 0 if kc is None else kc.shape[2]
        xq = x.view(B, 1, -1)
        q = self.q_proj(xq).view(B, 1, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(xq).view(B, 1, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(xq).view(B, 1, self.nkv, self.hd).transpose(1, 2)
        q = rms_norm(q.float(), self.qn.float()).to(x.dtype)
        k = rms_norm(k.float(), self.kn.float()).to(x.dtype)
        q = self._rope_at(q, pos)
        k = self._rope_at(k, pos)
        kc = k if kc is None else torch.cat([kc, k], dim=2)
        vc = v if vc is None else torch.cat([vc, v], dim=2)
        kk = kc.repeat_interleave(self.rep, 1)
        vv = vc.repeat_interleave(self.rep, 1)
        # No causal mask is needed: the cache holds only past (and current) positions.
        y = F.scaled_dot_product_attention(q, kk, vv, is_causal=False)
        return self.out_proj(y.transpose(1, 2).reshape(B, -1)), [kc, vc]


class GatedMLP(nn.Module):
    """Gated MLP: ``fc2(y * silu(gate))``, where ``fc1`` yields (value, gate) halves."""

    def __init__(self, cfg):
        super().__init__()
        d = cfg["d_model"]
        mult = 128  # mamba_ssm GatedMLP rounds the hidden size up to a multiple of 128
        hidden = ((cfg["d_intermediate"] + mult - 1) // mult) * mult  # 1500 -> 1536
        self.fc1 = nn.Linear(d, 2 * hidden, bias=False)
        self.fc2 = nn.Linear(hidden, d, bias=False)

    def forward(self, x):
        # mamba_ssm GatedMLP: first half = value, second half = gate -> y * silu(gate)
        y, gate = self.fc1(x).chunk(2, -1)
        return self.fc2(y * F.silu(gate))


class Block(nn.Module):
    """Pre-norm residual block: mixer (GQA or Mamba-3) followed by a GatedMLP."""

    def __init__(self, cfg, is_attn):
        super().__init__()
        self.norm = nn.Parameter(torch.ones(cfg["d_model"]))
        self.norm2 = nn.Parameter(torch.ones(cfg["d_model"]))
        self.mixer = GQACPU(cfg) if is_attn else Mamba3CPU(cfg)
        self.mlp = GatedMLP(cfg)

    def forward(self, x):
        x = x + self.mixer(rms_norm(x, self.norm))
        x = x + self.mlp(rms_norm(x, self.norm2))
        return x

    def init_state(self, B, **kw):
        """Delegate decode-state creation to the mixer."""
        return self.mixer.init_state(B, **kw)

    def step(self, x, mstate):
        """Single-token version of ``forward``; returns ``(x, new_mixer_state)``."""
        m_out, mstate = self.mixer.step(rms_norm(x, self.norm), mstate)
        x = x + m_out
        x = x + self.mlp(rms_norm(x, self.norm2))
        return x, mstate


class LambaCPU(nn.Module):
    """Hybrid Mamba-3 / GQA language model with tied input/output embeddings."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embedding = nn.Embedding(cfg["vocab_size"], cfg["d_model"])
        self.layers = nn.ModuleList()
        for i in range(cfg["n_layers"]):
            # Every attn_every-th layer is attention, except the first and last layers.
            is_attn = ((i + 1) % cfg["attn_every"] == 0) and i != 0 and i != cfg["n_layers"] - 1
            self.layers.append(Block(cfg, is_attn))
        self.norm_f = nn.Parameter(torch.ones(cfg["d_model"]))
        self.lm_head = nn.Linear(cfg["d_model"], cfg["vocab_size"], bias=False)
        self.lm_head.weight = self.embedding.weight  # weight tying

    def forward(self, ids):
        """Full-sequence logits: ids (B, L) -> (B, L, vocab_size)."""
        h = self.embedding(ids)
        for l in self.layers:
            h = l(h)
        return self.lm_head(rms_norm(h, self.norm_f))

    def init_states(self, B, **kw):
        """Return one decode state per layer."""
        return [l.init_state(B, **kw) for l in self.layers]

    @torch.no_grad()
    def step_forward(self, ids_step, states):
        """Decode a single token: ids_step (B, 1) -> (logits (B, V), new_states).

        Cost per token: O(1) for Mamba layers, O(t) for GQA attention layers.
        """
        h = self.embedding(ids_step)[:, 0]  # (B, d)
        new_states = []
        for l, stt in zip(self.layers, states):
            h, ns = l.step(h, stt)
            new_states.append(ns)
        return self.lm_head(rms_norm(h, self.norm_f)), new_states


# ───────────── weight loader (fork checkpoint -> CPU model) ─────────────
def load_lamba(ckpt_path):
    """Load a fork checkpoint into a ``LambaCPU`` on CPU.

    The checkpoint must be a dict with ``"cfg"`` (model config) and ``"model"``
    (state dict). Fork ``*.weight`` names of the norm vectors (``B_norm``,
    ``C_norm``, ``norm_f``, ``norm``, ``norm2``) are mapped to the bare
    Parameters used here. GQA names already match one-to-one. Loading is
    non-strict; missing and unexpected keys are reported on stdout.

    Returns:
        ``(model, cfg)``, with the model in eval mode.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    st = torch.load(ckpt_path, map_location="cpu")
    cfg, sd = st["cfg"], st["model"]
    model = LambaCPU(cfg)
    mp = {}
    for k, v in sd.items():
        # GQA fork->CPU names already match (q_proj/k_proj/v_proj/out_proj/qn/kn).
        # Mamba B_norm.weight/C_norm.weight -> B_norm/C_norm (bare Parameters).
        nk = k.replace(".mixer.B_norm.weight", ".mixer.B_norm").replace(".mixer.C_norm.weight", ".mixer.C_norm")
        nk = nk.replace("norm_f.weight", "norm_f")
        nk = nk.replace(".norm.weight", ".norm").replace(".norm2.weight", ".norm2")
        mp[nk] = v  # norm_f / lm_head / embedding pass through
    miss, unexp = model.load_state_dict(mp, strict=False)
    print(f"[load] eksik={len(miss)} beklenmeyen={len(unexp)}", flush=True)
    if miss:
        print(" ilk eksikler:", miss[:8])
    if unexp:
        print(" ilk beklenmeyenler:", unexp[:8])
    model.eval()
    return model, cfg


@torch.no_grad()
def generate(model, sp, prompt, max_new=64, temperature=0.0, top_k=40, top_p=0.9,
             rep_penalty=1.2, device=None):
    """Generate a continuation of ``prompt`` with the decode cache (``step_forward``), O(L).

    Signature-compatible with ``faz7_rag.generate``.

    Args:
        model: a ``LambaCPU`` (optionally int8-quantized).
        sp: tokenizer exposing ``encode(text, out_type=int)``, ``eos_id()`` and ``decode(ids)``.
        prompt: input text; must encode to at least one token.
        max_new: maximum number of tokens to generate.
        temperature: values <= 0 select greedy decoding (top_k/top_p are then unused).
        top_k: keep only the k highest logits (0/None disables).
        top_p: nucleus-sampling threshold; values >= 1.0 disable it.
        rep_penalty: for every token already in the prompt or output, positive logits
            are divided and negative logits multiplied by this factor (1.0 disables).
        device: device for the input ids; defaults to the model's parameter device.

    Returns:
        ``sp.decode`` of the generated ids (EOS is not included).

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    dev = device or next(model.parameters()).device
    ids = sp.encode(prompt, out_type=int)
    eos = sp.eos_id()
    if not ids:
        # Without at least one prefill token there are no logits to pick the first token from.
        raise ValueError("prompt encodes to zero tokens; nothing to condition on")
    states = model.init_states(1, device=dev)
    logits = None
    for tid in ids:  # prefill: each prompt token goes through the cache once (O(L))
        logits, states = model.step_forward(torch.tensor([[tid]], device=dev), states)
    out = []
    for _ in range(max_new):
        lg = logits[0].float()
        if rep_penalty != 1.0:
            for t in set(ids + out):
                lg[t] = lg[t] / rep_penalty if lg[t] > 0 else lg[t] * rep_penalty
        if temperature <= 0:
            nxt = int(lg.argmax())
        else:
            lg = lg / temperature
            if top_k:
                kth = torch.topk(lg, min(top_k, lg.numel())).values[-1]
                lg[lg < kth] = -float("inf")
            probs = F.softmax(lg, -1)
            if top_p < 1.0:
                s, si = torch.sort(probs, descending=True)
                cut = torch.cumsum(s, -1) > top_p
                # Shift right so the token that crosses the threshold is kept.
                cut[1:] = cut[:-1].clone()
                cut[0] = False
                s[cut] = 0
                probs = torch.zeros_like(probs).scatter_(0, si, s)
                probs /= probs.sum()
            nxt = int(torch.multinomial(probs, 1))
        if nxt == eos:
            break
        out.append(nxt)
        logits, states = model.step_forward(torch.tensor([[nxt]], device=dev), states)
    return sp.decode(out)


def quantize_int8(model):
    """Apply CPU int8 dynamic quantization to every ``nn.Linear``.

    Reported effect: about 2x smaller memory (708 -> 325 MB), about 1.2x faster,
    and near-identical greedy output. The SSM recurrence stays fp32; it is the
    bottleneck, and int8 only speeds up the GEMMs.
    """
    qd = getattr(torch.ao.quantization, "quantize_dynamic", None) or torch.quantization.quantize_dynamic
    return qd(model, {nn.Linear}, dtype=torch.qint8)


def main():
    """CLI entry point: load checkpoint and tokenizer, then answer ``--query``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--tokenizer", required=True)
    ap.add_argument("--query", default="Merhaba")
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--int8", action="store_true",
                    help="int8 dynamic quant (~2× küçük bellek, ~1.2× hız, ~birebir)")
    args = ap.parse_args()
    import sentencepiece as spm  # imported lazily: only the CLI needs the tokenizer
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    model, cfg = load_lamba(args.ckpt)
    # Count parameters before quantizing. Dynamic-quant Linear modules keep their
    # packed weights outside parameters(), so a post-quant count under-reports.
    n_params = sum(p.numel() for p in model.parameters())
    if args.int8:
        model = quantize_int8(model)
        print("[int8] dynamic quant uygulandı")
    print(f"[model] {'MIMO' if cfg.get('is_mimo') else 'SISO'} | CPU | {n_params/1e6:.0f}M")
    prompt = f"### Talimat:\n{args.query}\n\n### Yanıt:\n"
    print("CEVAP:", generate(model, sp, prompt, args.max_new))


if __name__ == "__main__":
    main()