smartcore-v1 / code /kod /faz7_rag.py
kdirgul's picture
faz7_rag.py: --int8 dynamic quant (~2x kucuk bellek 708->325MB, ~1.2x hiz, ciktilar birebir)
99a2344 verified
Raw
History Blame Contribute Delete
18.7 kB
"""
Faz 7 — RAG (retrieval-augmented generation) — SmartCore V1. (Basamak 2)
Pipeline: doküman → chunk → embed → index | soru → embed → top-k ara →
bağlamı SFT şablonunun '### Girdi:' alanına enjekte → v1-instruct cevaplar.
Amaç: 177M'in olgusal halüsinasyonunu retrieval ile çöz (ezberden değil, getirerek).
Generator: HybridLM gömülü (model-bağımsız). Embedding: sentence-transformers (çok dilli TR+EN).
Index: numpy cosine (prototip; ölçek için FAISS/turbovec sonra). 2048 bağlam → bütçe-kırpma
(REFRAG/headroom sıkıştırma = sonraki sürüm).
Ortam: Colab GPU + mamba-og fork (wheel) + `pip install sentence-transformers`.
Kullanım:
HF_TOKEN=hf_xxx python faz7_rag.py --demo --query "Türkiye'nin başkenti neresi?"
HF_TOKEN=hf_xxx python faz7_rag.py --docs /content/dokumanlar # interaktif
"""
import os, sys, glob, re, argparse
import torch, torch.nn as nn, torch.nn.functional as F
from functools import partial
try: # fork (Triton) yalnız GPU/HybridLM yolu için
from mamba_ssm.modules.block import Block
from mamba_ssm.modules.mamba3 import Mamba3
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.ops.triton.layer_norm import RMSNorm
_FORK_OK, _FORK_ERR = True, None
except Exception as e: # CPU: fork yok → LambaCPU (--device cpu) kullanılır
Block = Mamba3 = GatedMLP = RMSNorm = None
_FORK_OK, _FORK_ERR = False, e
# ───────────── model (faz3_train ile birebir) ─────────────
def _rms(x, w, eps=1e-5):
return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)) * w
def _rot_half(x):
a, b = x.chunk(2, -1)
return torch.cat((-b, a), -1)
class GQAMixer(nn.Module):
def __init__(self, dim, n_heads=12, n_kv=3, base=10000.0, layer_idx=None, device=None, dtype=None):
super().__init__()
self.nh, self.nkv, self.hd = n_heads, n_kv, dim // n_heads
self.rep = n_heads // n_kv
fk = {"device": device, "dtype": dtype}
self.q_proj = nn.Linear(dim, n_heads * self.hd, bias=False, **fk)
self.k_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk)
self.v_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk)
self.out_proj = nn.Linear(n_heads * self.hd, dim, bias=False, **fk)
self.qn = nn.Parameter(torch.ones(self.hd, **fk))
self.kn = nn.Parameter(torch.ones(self.hd, **fk))
self.register_buffer(
"inv", 1.0 / (base ** (torch.arange(0, self.hd, 2, device=device).float() / self.hd)),
persistent=False)
def _rope(self, x, T):
f = torch.outer(torch.arange(T, device=x.device, dtype=torch.float32), self.inv)
e = torch.cat((f, f), -1)
return (x * e.cos()[None, None] + _rot_half(x) * e.sin()[None, None]).to(x.dtype)
def forward(self, x, **kw):
B, T, _ = x.shape
q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
q = _rms(q.float(), self.qn.float()).to(x.dtype)
k = _rms(k.float(), self.kn.float()).to(x.dtype)
q, k = self._rope(q, T), self._rope(k, T)
k = k.repeat_interleave(self.rep, 1)
v = v.repeat_interleave(self.rep, 1)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
return self.out_proj(y.transpose(1, 2).contiguous().view(B, T, -1))
class HybridLM(nn.Module):
def __init__(self, cfg, device=None, dtype=None):
super().__init__()
self.cfg = cfg
self.vocab = cfg["vocab_size"]
self.scaled_embed = cfg.get("scaled_embed", False)
d = cfg["d_model"]
self.embedding = nn.Embedding(self.vocab, d, device=device, dtype=dtype)
self.layers = nn.ModuleList()
for i in range(cfg["n_layers"]):
is_attn = ((i + 1) % cfg["attn_every"] == 0) and i != 0 and i != cfg["n_layers"] - 1
fk = {"device": device, "dtype": dtype}
if is_attn:
mixer_cls = partial(GQAMixer, n_heads=cfg["n_heads"], n_kv=cfg["n_kv_heads"], layer_idx=i, **fk)
else:
ssm = dict(d_state=cfg["d_state"], expand=cfg["expand"], headdim=cfg["head_dim"],
ngroups=cfg["ngroups"], rope_fraction=cfg["rope_fraction"],
is_outproj_norm=False, is_mimo=cfg["is_mimo"], mimo_rank=cfg["mimo_rank"],
chunk_size=cfg["chunk_size"])
mixer_cls = partial(Mamba3, layer_idx=i, **ssm, **fk)
blk = Block(d, mixer_cls,
partial(GatedMLP, hidden_features=cfg["d_intermediate"], out_features=d, **fk),
norm_cls=partial(RMSNorm, eps=1e-5, **fk), fused_add_norm=True, residual_in_fp32=True)
blk.layer_idx = i
self.layers.append(blk)
self.norm_f = RMSNorm(d, eps=1e-5, device=device, dtype=dtype)
self.lm_head = nn.Linear(d, self.vocab, bias=False, device=device, dtype=dtype)
self.lm_head.weight = self.embedding.weight
def forward(self, ids):
h = self.embedding(ids)
if self.scaled_embed:
h = h * (self.cfg["d_model"] ** 0.5)
res = None
for l in self.layers:
h, res = l(h, res)
h = self.norm_f((h + res) if res is not None else h)
return self.lm_head(h.to(self.lm_head.weight.dtype))
# ───────────── tokenizer + checkpoint ─────────────
def load_tok(token, local=None):
import sentencepiece as spm
if local and os.path.exists(local): # public/yerel: repo kökündeki tokenizer.model
return spm.SentencePieceProcessor(model_file=local)
from huggingface_hub import hf_hub_download
p = hf_hub_download("kdirgul/smartcore-v1", "tokenizer/tokenizer.model", repo_type="model", token=token)
return spm.SentencePieceProcessor(model_file=p)
def resolve_ckpt(spec, token):
if os.path.exists(spec):
return spec
from huggingface_hub import hf_hub_download
print(f"[ckpt] HF: {spec}", flush=True)
return hf_hub_download("kdirgul/smartcore-v1", spec, repo_type="model", token=token)
@torch.no_grad()
def generate(model, sp, prompt, max_new=160, temperature=0.3, top_k=40, top_p=0.9, rep_penalty=1.2, device="cuda"):
eos = sp.eos_id()
ids = sp.encode(prompt, out_type=int)
x = torch.tensor([ids], device=device); out = list(ids)
use_amp = (device == "cuda")
for _ in range(max_new):
if use_amp:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(x)[0, -1].float()
else:
logits = model(x)[0, -1].float() # CPU: saf fp32 (lamba_cpu, Triton'suz)
if rep_penalty != 1.0:
for t in set(out):
logits[t] = logits[t] / rep_penalty if logits[t] > 0 else logits[t] * rep_penalty
if temperature <= 0:
nxt = int(logits.argmax())
else:
logits = logits / temperature
if top_k:
kth = torch.topk(logits, min(top_k, logits.numel())).values[-1]; logits[logits < kth] = -float("inf")
probs = F.softmax(logits, -1)
if top_p < 1.0:
s, si = torch.sort(probs, descending=True); cut = torch.cumsum(s, -1) > top_p
cut[1:] = cut[:-1].clone(); cut[0] = False; s[cut] = 0
probs = torch.zeros_like(probs).scatter_(0, si, s); probs /= probs.sum()
nxt = int(torch.multinomial(probs, 1))
if nxt == eos:
break
out.append(nxt); x = torch.cat([x, torch.tensor([[nxt]], device=device)], 1)
if x.shape[1] >= 2048:
x = x[:, -2048:]
return sp.decode([t for t in out[len(ids):] if t != eos])
# ───────────── RAG (saf-python: yerelde test edilebilir) ─────────────
DEMO_DOCS = [
"Türkiye'nin başkenti Ankara'dır. Ankara ülkenin ikinci en kalabalık şehridir ve yaklaşık 5,7 milyon nüfusa sahiptir.",
"İstanbul, Türkiye'nin en kalabalık şehridir; nüfusu yaklaşık 15,5 milyondur. İstanbul iki kıtaya yayılır.",
"Fotosentez, bitkilerin güneş ışığı, su ve karbondioksiti kullanarak glikoz ve oksijen ürettiği biyokimyasal süreçtir.",
"SmartCore V1, sıfırdan eğitilmiş 177 milyon parametreli, Mamba-3 ve GQA hibrit mimarili, Türkçe ve İngilizce bir dil modelidir.",
"The capital of France is Paris, located on the river Seine. Paris is the most populous city in France.",
"Water boils at 100 degrees Celsius at sea level under one atmosphere of pressure.",
"The speed of light in vacuum is approximately 299,792 kilometers per second.",
"Mount Everest is the highest mountain on Earth, with its peak at 8,849 meters above sea level.",
"Mamba is a state space model (SSM) architecture for sequence modeling, proposed as a linear-time alternative to the Transformer.",
"Mimar Sinan, Osmanlı İmparatorluğu'nun baş mimarıdır; Süleymaniye ve Selimiye camilerini tasarlamıştır.",
]
# A/B için: her biri DEMO_DOCS'ta cevabı OLAN sorular (SFT sohbet testindeki halüsinasyonlarla eşleşir).
DEMO_QUERIES = [
"Türkiye'nin başkenti neresi?", # doc0 (SFT'de cevaplanamadı)
"Fotosentez nedir?", # doc2 (SFT'de yanlış)
"SmartCore V1 nedir?", # doc3
"What is the capital of France?", # doc4
"Su deniz seviyesinde kaç derecede kaynar?", # doc5
"How tall is Mount Everest?", # doc7
]
def chunk_text(text, size=600, overlap=100):
text = " ".join(text.split())
out, i = [], 0
step = max(1, size - overlap)
while i < len(text):
seg = text[i:i + size].strip()
if seg:
out.append(seg)
i += step
return out
# TR tespiti: özel karakter VEYA TR soru-kelimesi (faz7_prep_rag_sft.py ile BİREBİR aynı tutulmalı).
_TR_WORDS = {"nedir", "ne", "hangi", "neresi", "nerede", "nereye", "kim", "kimdir", "neden",
"niçin", "niye", "mıdır", "midir", "mudur", "müdür", "kaç", "kaçtır", "nasıl",
"mı", "mi", "mu", "mü", "için", "ile", "kaçıncı"}
def _is_tr(question):
ql = question.lower()
if any(ch in ql for ch in "çğıöşü"):
return True
words = set(ql.replace("?", " ").replace("'", " ").replace(".", " ").split())
return bool(words & _TR_WORDS)
def build_rag_prompt(question, hits, max_ctx_chars=2400):
ctx = ""
for c, _ in hits:
if len(ctx) + len(c) + 2 > max_ctx_chars:
break
ctx += c + "\n\n" # düz metin (madde yok) → RAG-SFT verisi ham paragrafla eşleşsin
# SFT formatı = Talimat/Girdi/Yanıt; `### Soru` SFT'de YOK → kaldırıldı, soru talimata gömüldü.
# Yönerge soruyla AYNI dilde (EN soruda TR'ye kaymayı önler — France vakası).
# "kısa/doğrudan + adım adım düşünme" → Magpie'den gelen <think> sızmasını ve uzun-jenerasyon halüsinasyonunu kırar.
tr = _is_tr(question)
instr = (f"Aşağıdaki bağlamı kullanarak soruyu kısa ve doğrudan yanıtla; adım adım düşünme. "
f"Cevap bağlamda yoksa \"bilmiyorum\" de.\nSoru: {question}"
if tr else
f"Answer the question using only the context below, briefly and directly; do not think step by step. "
f"If the answer is not in the context, say \"I don't know\".\nQuestion: {question}")
return f"### Talimat:\n{instr}\n\n### Girdi:\n{ctx.strip()}\n\n### Yanıt:\n"
_THINK = re.compile(r"<think>.*?</think>|<think>.*", re.S)
def strip_think(text):
"""Magpie R1 reasoning sızıntısını temizle (açık/kapalı <think> blokları)."""
return _THINK.sub("", text).strip()
def load_docs(path):
docs = []
files = ([path] if os.path.isfile(path)
else glob.glob(os.path.join(path, "**", "*.txt"), recursive=True)
+ glob.glob(os.path.join(path, "**", "*.md"), recursive=True))
for f in files:
with open(f, encoding="utf-8", errors="ignore") as fh:
docs.append(fh.read())
return docs
# ───────────── retrieval (embedder gerekir) ─────────────
def build_index(embedder, chunks):
return embedder.encode(chunks, normalize_embeddings=True, convert_to_numpy=True,
batch_size=64, show_progress_bar=False)
def retrieve(embedder, index, chunks, query, k):
q = embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
sims = index @ q
top = sims.argsort()[::-1][:k]
return [(chunks[i], float(sims[i])) for i in top]
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--docs", default=None, help="doküman dosyası/dizini (.txt/.md)")
ap.add_argument("--demo", action="store_true", help="gömülü TR+EN demo korpusu")
ap.add_argument("--query", default=None, help="boşsa interaktif")
ap.add_argument("--ckpt", default="sft/epoch_2/ckpt.pt", help="HF yolu | yerel .pt")
ap.add_argument("--tokenizer", default=None, help="yerel tokenizer.model yolu (public/offline; boşsa HF'den çeker)")
ap.add_argument("--embed_model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
ap.add_argument("--top_k", type=int, default=3)
ap.add_argument("--chunk_size", type=int, default=600)
ap.add_argument("--overlap", type=int, default=100)
ap.add_argument("--max_ctx_chars", type=int, default=2400)
ap.add_argument("--temperature", type=float, default=0.3)
ap.add_argument("--max_new", type=int, default=96, help="üretilecek max token (RAG extractive için kısa tut)")
ap.add_argument("--no_rag", action="store_true", help="RAG'sız (kıyas için: bağlam enjekte etme)")
ap.add_argument("--ab", action="store_true",
help="A/B: her soru için RAG vs ham cevabı yan yana üret (tek koşu, tek model yüklemesi)")
ap.add_argument("--device", default="auto", choices=["auto", "cuda", "cpu"],
help="cpu = saf-PyTorch lamba_cpu (GPU/Triton gerekmez); auto = cuda varsa cuda")
ap.add_argument("--int8", action="store_true",
help="CPU int8 dynamic quant (~2× küçük bellek, ~1.2× hız; sadece --device cpu)")
args = ap.parse_args()
device = ("cuda" if torch.cuda.is_available() else "cpu") if args.device == "auto" else args.device
if device == "cuda":
assert torch.cuda.is_available(), "CUDA yok — --device cpu kullan (saf-PyTorch lamba_cpu)."
torch.set_float32_matmul_precision("high")
print(f"[device] {device}", flush=True)
token = os.environ.get("HF_TOKEN")
try:
from huggingface_hub import get_token
token = token or get_token()
except Exception:
pass
# 1) korpus → chunk → index
docs = DEMO_DOCS if args.demo else (load_docs(args.docs) if args.docs else None)
assert docs, "doküman yok: --demo ya da --docs ver."
chunks = []
for d in docs:
chunks += chunk_text(d, args.chunk_size, args.overlap)
print(f"[rag] {len(docs)} doküman → {len(chunks)} chunk | embed: {args.embed_model}", flush=True)
from sentence_transformers import SentenceTransformer
embedder = SentenceTransformer(args.embed_model, device=device)
index = build_index(embedder, chunks)
# 2) generator
sp = load_tok(token, args.tokenizer)
ckpt_path = resolve_ckpt(args.ckpt, token)
st = torch.load(ckpt_path, map_location="cpu")
if device == "cpu": # saf-PyTorch (Triton/mamba_ssm YOK)
import lamba_cpu as LC
model, _ = LC.load_lamba(ckpt_path); gen = LC.generate # decode-cache'li O(L) üretim
if args.int8:
model = LC.quantize_int8(model); print("[int8] dynamic quant uygulandı", flush=True)
else:
assert _FORK_OK, f"GPU yolu mamba-og fork ister ({_FORK_ERR!r}). Wheel kur ya da --device cpu kullan."
model = HybridLM(st["cfg"], device="cuda", dtype=torch.bfloat16)
model.load_state_dict(st["model"], strict=False); model.eval(); gen = generate
tag = f"sft epoch={st.get('epoch')}" if st.get("sft") else f"base step={st.get('step','?')}"
print(f"[model] {tag} | {'MIMO' if st['cfg'].get('is_mimo') else 'SISO'} | {device}\n", flush=True)
def answer(q, use_rag):
hits = retrieve(embedder, index, chunks, q, args.top_k)
prompt = (build_rag_prompt(q, hits, args.max_ctx_chars) if use_rag
else f"### Talimat:\n{q}\n\n### Yanıt:\n")
return hits, strip_think(gen(model, sp, prompt, max_new=args.max_new,
temperature=args.temperature, device=device))
def show(q, use_rag):
hits, ans = answer(q, use_rag)
if use_rag:
print("[getirilen]")
for c, s in hits:
print(f" ({s:.2f}) {c[:90]}")
print(f"CEVAP: {ans}\n")
if args.ab: # A/B: aynı modelle her soruyu RAG ve ham olarak üret
queries = [args.query] if args.query else DEMO_QUERIES
for q in queries:
hits = retrieve(embedder, index, chunks, q, args.top_k)
print("=" * 72)
print(f"SORU: {q}")
print(f"[getirilen] ({hits[0][1]:.2f}) {hits[0][0][:90]}")
a_rag = strip_think(gen(model, sp, build_rag_prompt(q, hits, args.max_ctx_chars),
max_new=args.max_new, temperature=args.temperature, device=device))
a_raw = strip_think(gen(model, sp, f"### Talimat:\n{q}\n\n### Yanıt:\n",
max_new=args.max_new, temperature=args.temperature, device=device))
print(f" [A · RAG]: {a_rag}")
print(f" [B · ham]: {a_raw}")
return
if args.query:
show(args.query, not args.no_rag)
else:
print("İnteraktif RAG — soru yaz (boş/çık = quit).")
while True:
try:
q = input("\n> ").strip()
except (EOFError, KeyboardInterrupt):
break
if not q or q.lower() in ("quit", "exit", "çık"):
break
show(q, not args.no_rag)
if __name__ == "__main__":
main()