File size: 11,266 Bytes
b80607a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e5c273
 
 
 
 
b80607a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e5c273
b80607a
 
 
 
 
 
 
 
0e5c273
b80607a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e5c273
 
b80607a
 
 
 
 
 
 
0e5c273
b80607a
0e5c273
 
 
b80607a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
SmartCore V1 — SIFIRDAN Colab test (standalone / kendine yeten).

Final base modeli (kdirgul/smartcore-v1 son checkpoint) HF'den çeker, metin üretir.
Model tanımı faz3_train.py ile BİREBİR aynı (içine gömülü) → state_dict tam oturar,
faz3_train.py'ye bağımlılık YOK.

Ortam: Colab GPU + mamba-og fork (Faz 3a kurulu). CUDA şart (Triton kernel).
NOT: Bu bir BASE model (instruction yok) → soru-cevap DEĞİL, metin TAMAMLAMA yapar.

Kullanım:
  HF_TOKEN=hf_xxx python test_smartcore.py --prompt "Türkiye'nin başkenti"
  HF_TOKEN=hf_xxx python test_smartcore.py                    # interaktif REPL
  python test_smartcore.py --ckpt /content/ck/step_022887/ckpt.pt   # yerel .pt
"""
import os, sys, math, argparse
import torch, torch.nn as nn, torch.nn.functional as F
from functools import partial

try:
    from mamba_ssm.modules.block import Block
    from mamba_ssm.modules.mamba3 import Mamba3
    from mamba_ssm.modules.mlp import GatedMLP
    from mamba_ssm.ops.triton.layer_norm import RMSNorm
except Exception as e:
    sys.exit(f"[hata] mamba-og fork import edilemedi ({e!r}). Önce Faz 3a kurulum hücresini çalıştır (CUDA gerekir).")


# ───────────── model (faz3_train.py ile BİREBİR AYNI) ─────────────
def _rms(x, w, eps=1e-5):
    return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)) * w


def _rot_half(x):
    a, b = x.chunk(2, -1)
    return torch.cat((-b, a), -1)


class GQAMixer(nn.Module):
    """Causal grouped-query self-attention mixer with QK-RMSNorm and RoPE.

    Used as the ``mixer_cls`` of a mamba ``Block`` on the attention layers of
    ``HybridLM``. It must stay identical to the training definition: the
    attribute names below are the checkpoint's state_dict keys.

    Args:
        dim: model width; head size is ``dim // n_heads`` (assumes exact division).
        n_heads: number of query heads.
        n_kv: number of key/value heads; each is shared by ``n_heads // n_kv``
            query heads (assumes exact division).
        base: RoPE frequency base.
        layer_idx: accepted but not used by this class; presumably kept so the
            signature matches what ``HybridLM`` passes — confirm.
        device, dtype: forwarded to the created parameters.
    """
    def __init__(self, dim, n_heads=12, n_kv=3, base=10000.0, layer_idx=None, device=None, dtype=None):
        super().__init__()
        self.nh, self.nkv, self.hd = n_heads, n_kv, dim // n_heads   # head counts and per-head size
        self.rep = n_heads // n_kv   # number of query heads sharing each K/V head
        fk = {"device": device, "dtype": dtype}
        self.q_proj = nn.Linear(dim, n_heads * self.hd, bias=False, **fk)
        self.k_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk)
        self.v_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk)
        self.out_proj = nn.Linear(n_heads * self.hd, dim, bias=False, **fk)
        # Learnable RMSNorm gains (size head_dim) applied to every head of q and k (QK-norm).
        self.qn = nn.Parameter(torch.ones(self.hd, **fk))
        self.kn = nn.Parameter(torch.ones(self.hd, **fk))
        # RoPE inverse frequencies base**(-2i/hd), kept in fp32. persistent=False keeps
        # them out of the state_dict; they are rebuilt at construction time instead.
        self.register_buffer(
            "inv", 1.0 / (base ** (torch.arange(0, self.hd, 2, device=device).float() / self.hd)),
            persistent=False)

    def _rope(self, x, T):
        """Apply rotary embedding for positions ``0..T-1`` to ``x``.

        ``x``'s last two dims must be ``(T, head_dim)``; the cos/sin tables are
        broadcast over the leading (batch, head) dims. The rotation is computed
        in fp32 and the result is cast back to ``x.dtype``.
        """
        f = torch.outer(torch.arange(T, device=x.device, dtype=torch.float32), self.inv)
        e = torch.cat((f, f), -1)   # (T, head_dim): each frequency repeated for both halves
        return (x * e.cos()[None, None] + _rot_half(x) * e.sin()[None, None]).to(x.dtype)

    def forward(self, x, **kw):
        """Causal self-attention over ``x`` of shape ``(B, T, dim)``; returns ``(B, T, dim)``.

        Extra keyword arguments are ignored (presumably e.g. ``inference_params``
        passed by ``Block`` — confirm). There is no KV cache: every call attends
        over all ``T`` positions, with RoPE positions starting at 0.
        """
        B, T, _ = x.shape
        # Project and split into heads: (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        # QK-norm in fp32, then back to the activation dtype.
        q = _rms(q.float(), self.qn.float()).to(x.dtype)
        k = _rms(k.float(), self.kn.float()).to(x.dtype)
        q, k = self._rope(q, T), self._rope(k, T)
        # Expand K/V heads so each query head has a matching K/V head (GQA).
        k = k.repeat_interleave(self.rep, 1)
        v = v.repeat_interleave(self.rep, 1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back to (B, T, n_heads * head_dim) and project to dim.
        return self.out_proj(y.transpose(1, 2).contiguous().view(B, T, -1))


class HybridLM(nn.Module):
    """Hybrid Mamba3 / GQA-attention causal language model (SmartCore V1).

    Layer ``i`` is an attention layer (``GQAMixer``) when ``(i + 1) % attn_every == 0``
    and it is neither the first nor the last layer; every other layer uses a
    ``Mamba3`` SSM mixer. Each layer is a mamba ``Block`` with a ``GatedMLP`` and
    RMSNorm (``fused_add_norm=True``, ``residual_in_fp32=True``). The LM head
    shares its weight with the token embedding.

    Attribute names and the layer layout define the state_dict keys and must
    stay identical to the training script.

    Args:
        cfg: config dict; reads ``vocab_size``, ``d_model``, ``n_layers``,
            ``attn_every``, ``n_heads``, ``n_kv_heads``, ``d_state``, ``expand``,
            ``head_dim``, ``ngroups``, ``rope_fraction``, ``is_mimo``,
            ``mimo_rank``, ``chunk_size``, ``d_intermediate`` and, optionally,
            ``scaled_embed`` (defaults to False).
        device, dtype: forwarded to every submodule.
    """
    def __init__(self, cfg, device=None, dtype=None):
        super().__init__()
        self.cfg = cfg
        self.vocab = cfg["vocab_size"]
        self.scaled_embed = cfg.get("scaled_embed", False)
        d = cfg["d_model"]
        self.embedding = nn.Embedding(self.vocab, d, device=device, dtype=dtype)
        self.layers = nn.ModuleList()
        self.attn_idx = []   # indices of the attention layers (recorded; not read elsewhere in this script)
        for i in range(cfg["n_layers"]):
            # Attention every `attn_every` layers, never on the first or last layer.
            is_attn = ((i + 1) % cfg["attn_every"] == 0) and i != 0 and i != cfg["n_layers"] - 1
            fk = {"device": device, "dtype": dtype}
            if is_attn:
                mixer_cls = partial(GQAMixer, n_heads=cfg["n_heads"], n_kv=cfg["n_kv_heads"],
                                    layer_idx=i, **fk)
                self.attn_idx.append(i)
            else:
                ssm = dict(d_state=cfg["d_state"], expand=cfg["expand"], headdim=cfg["head_dim"],
                           ngroups=cfg["ngroups"], rope_fraction=cfg["rope_fraction"],
                           is_outproj_norm=False, is_mimo=cfg["is_mimo"], mimo_rank=cfg["mimo_rank"],
                           chunk_size=cfg["chunk_size"])
                mixer_cls = partial(Mamba3, layer_idx=i, **ssm, **fk)
            blk = Block(d, mixer_cls,
                        partial(GatedMLP, hidden_features=cfg["d_intermediate"], out_features=d, **fk),
                        norm_cls=partial(RMSNorm, eps=1e-5, **fk),
                        fused_add_norm=True, residual_in_fp32=True)
            blk.layer_idx = i   # presumably used by mamba's Block / inference caching — confirm
            self.layers.append(blk)
        self.norm_f = RMSNorm(d, eps=1e-5, device=device, dtype=dtype)
        self.lm_head = nn.Linear(d, self.vocab, bias=False, device=device, dtype=dtype)
        self.lm_head.weight = self.embedding.weight   # tied

    def forward(self, ids):
        """Return next-token logits for integer token ids.

        ``ids`` is presumably ``(B, T)`` — confirm. The output's last dimension
        is ``vocab_size``. Each block is called as ``block(hidden, residual)``
        and returns the new ``(hidden, residual)`` pair; the first call gets
        ``residual=None``.
        """
        h = self.embedding(ids)
        if self.scaled_embed:
            h = h * (self.cfg["d_model"] ** 0.5)   # optional sqrt(d_model) embedding scaling
        res = None
        for l in self.layers:
            h, res = l(h, res)
        # Add the last block's output onto the residual stream, then apply the final norm.
        h = self.norm_f((h + res) if res is not None else h)
        # Cast to the head's dtype (the sum above may be fp32 because the residual is kept in fp32).
        return self.lm_head(h.to(self.lm_head.weight.dtype))


# ───────────── tokenizer + checkpoint ─────────────
def load_tok(path, token, repo="kdirgul/smartcore-v1"):
    """Load the SentencePiece tokenizer from a local file or the Hugging Face Hub.

    Args:
        path: local ``tokenizer.model`` path. If it is ``None``/empty or does not
            exist, ``tokenizer/tokenizer.model`` is downloaded from ``repo``
            (with a warning when an explicit path was given but is missing).
        token: HF access token, or ``None``.
        repo: HF model repo to fetch the tokenizer from when ``path`` is not a
            usable local file. Defaults to the SmartCore V1 repo.

    Returns:
        A ``sentencepiece.SentencePieceProcessor``.
    """
    import sentencepiece as spm
    if path and not os.path.exists(path):
        # An explicit path that does not exist is almost always a typo; say so
        # instead of silently substituting the Hub tokenizer.
        print(f"[uyarı] tokenizer bulunamadı: {path} → HF'den indiriliyor ({repo})", flush=True)
    if not (path and os.path.exists(path)):
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo, "tokenizer/tokenizer.model",
                               repo_type="model", token=token)
    sp = spm.SentencePieceProcessor(model_file=path)
    print(f"[tok] vocab={sp.get_piece_size()} eos={sp.eos_id()}", flush=True)
    return sp


def _ckpt_step(fname):
    """Return the integer step of a ``checkpoints/step_<N>...`` repo path, or -1 if absent."""
    import re
    m = re.match(r"checkpoints/step_(\d+)", fname)
    return int(m.group(1)) if m else -1


def resolve_ckpt(spec, repo, token):
    """Resolve a checkpoint spec to a local file path.

    Args:
        spec: ``"latest_hf"`` (or empty/``None``) selects the checkpoint with the
            highest step among ``checkpoints/step_*...ckpt.pt`` files in ``repo``;
            an existing local path is returned unchanged; anything else is
            treated as a file path inside ``repo`` and downloaded.
        repo: Hugging Face model repo id.
        token: HF access token, or ``None``.

    Returns:
        Local filesystem path of the checkpoint file.

    Raises:
        SystemExit: if ``repo`` contains no step checkpoints.
    """
    if spec and spec != "latest_hf":
        if os.path.exists(spec):
            return spec
        from huggingface_hub import hf_hub_download
        print(f"[ckpt] HF: {spec}", flush=True)
        return hf_hub_download(repo, spec, repo_type="model", token=token)
    from huggingface_hub import HfApi, hf_hub_download
    api = HfApi(token=token)
    files = [f for f in api.list_repo_files(repo, repo_type="model")
             if f.startswith("checkpoints/step_") and f.endswith("ckpt.pt")]
    if not files:
        sys.exit("[hata] HF'de checkpoint yok.")
    # Compare steps numerically: a plain string max() picks step_999 over
    # step_1000 whenever step numbers are not zero-padded to the same width.
    # The filename tie-break keeps the old behaviour for unparsable names.
    latest = max(files, key=lambda f: (_ckpt_step(f), f))
    print(f"[ckpt] HF'den indiriliyor: {latest}", flush=True)
    return hf_hub_download(repo, latest, repo_type="model", token=token)


# ───────────── üretim ─────────────
@torch.no_grad()
def generate(model, sp, prompt, max_new=120, temperature=0.7, top_k=40, top_p=0.95,
             rep_penalty=1.3, dev="cuda", seed=None, max_ctx=2048):
    """Sample a continuation of ``prompt`` token by token.

    The whole (possibly truncated) sequence is re-run through ``model`` at every
    step; there is no KV/SSM state cache, so cost grows with length. That is fine
    for a smoke test.

    Args:
        model: callable mapping a ``(1, T)`` long tensor to logits whose
            ``[0, -1]`` slice scores the next token over the vocabulary.
        sp: SentencePiece processor (``encode``/``decode``/``eos_id``/``bos_id``).
        prompt: text to continue.
        max_new: maximum number of tokens to generate.
        temperature: softmax temperature; ``<= 0`` selects greedy argmax.
        top_k: keep only the ``top_k`` highest logits; ``0``/``None``/negative disables.
        top_p: nucleus threshold; ``0``/``None``/``>= 1`` disables.
        rep_penalty: applied to every token already in the sequence (prompt
            included). Positive logits are divided by it and non-positive ones
            multiplied. ``1.0``/``0``/``None`` disables.
        dev: device for the token tensors; its type also selects the autocast device.
        seed: if given, ``torch.manual_seed(seed)`` is called first.
        max_ctx: the model is fed at most this many trailing tokens.

    Returns:
        Decoded text of the generated tokens only (prompt excluded). Generation
        stops early when the tokenizer's EOS id is sampled.

    Raises:
        ValueError: if ``prompt`` encodes to no tokens and the tokenizer has no BOS id.
    """
    if seed is not None:
        torch.manual_seed(seed)
    eos = sp.eos_id()
    ids = sp.encode(prompt, out_type=int)
    if not ids:
        # The model cannot run on a zero-length sequence; start from BOS instead.
        # SentencePiece reports a disabled special id as -1.
        bos = sp.bos_id()
        if bos < 0:
            raise ValueError("prompt encodes to zero tokens and the tokenizer has no BOS id")
        ids = [bos]
    amp_device = torch.device(dev).type
    x = torch.tensor([ids], dtype=torch.long, device=dev)
    out = list(ids)
    seen = set(ids)
    for _ in range(max_new):
        with torch.autocast(device_type=amp_device, dtype=torch.bfloat16):
            logits = model(x)[0, -1].float()
        if rep_penalty and rep_penalty != 1.0:
            # Vectorised over all seen ids: one gather/scatter instead of a Python
            # loop where every `logits[t] > 0` forced a device sync.
            idx = torch.tensor(sorted(seen), dtype=torch.long, device=logits.device)
            sel = logits[idx]
            logits[idx] = torch.where(sel > 0, sel / rep_penalty, sel * rep_penalty)
        if temperature <= 0:
            nxt = int(logits.argmax())
        else:
            logits = logits / temperature
            if top_k and top_k > 0:
                kth = torch.topk(logits, min(top_k, logits.numel())).values[-1]
                logits[logits < kth] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            if top_p and top_p < 1.0:
                sorted_p, order = torch.sort(probs, descending=True)
                cut = torch.cumsum(sorted_p, dim=-1) > top_p
                # Shift right so the token that crosses top_p is kept; always
                # keep the most likely token.
                cut[1:] = cut[:-1].clone()
                cut[0] = False
                sorted_p[cut] = 0.0
                probs = torch.zeros_like(probs).scatter_(0, order, sorted_p)
                probs /= probs.sum()
            nxt = int(torch.multinomial(probs, 1))
        if nxt == eos:
            break
        out.append(nxt)
        seen.add(nxt)
        x = torch.cat([x, torch.tensor([[nxt]], dtype=torch.long, device=dev)], dim=1)
        if x.shape[1] > max_ctx:
            x = x[:, -max_ctx:]   # sliding window over the most recent tokens
    return sp.decode(out[len(ids):])   # only the generated part (EOS is never appended)


def main():
    """CLI entry point: load tokenizer and checkpoint, then run one prompt or a REPL.

    Requires CUDA (exits otherwise). The HF token comes from ``$HF_TOKEN`` or,
    failing that, ``huggingface_hub.get_token()``; if both fail it stays ``None``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="latest_hf", help="latest_hf | yerel .pt yolu")
    ap.add_argument("--ckpt_repo", default="kdirgul/smartcore-v1")
    ap.add_argument("--tokenizer", default=None)
    ap.add_argument("--prompt", default=None, help="boşsa interaktif REPL")
    ap.add_argument("--chat", action="store_true", help="SFT prompt şablonuyla sar (### Talimat/### Yanıt) — SFT modelleri için")
    ap.add_argument("--max_new", type=int, default=120)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--top_k", type=int, default=40)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--rep_penalty", type=float, default=1.3)
    ap.add_argument("--seed", type=int, default=None)
    args = ap.parse_args()

    if not torch.cuda.is_available():
        sys.exit("[hata] CUDA yok — Colab GPU gerekir (Triton kernel).")
    dev = "cuda"
    torch.set_float32_matmul_precision("high")
    token = os.environ.get("HF_TOKEN")
    if not token:
        try:
            from huggingface_hub import get_token
            token = get_token()
        except Exception:
            # Best effort: fall back to anonymous access.
            token = None

    sp = load_tok(args.tokenizer, token)
    path = resolve_ckpt(args.ckpt, args.ckpt_repo, token)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Newer PyTorch defaults to weights_only=True, which may reject non-tensor
    # objects stored in the checkpoint; confirm against the installed version.
    st = torch.load(path, map_location="cpu")
    cfg = st["cfg"]
    tag = f"sft epoch={st.get('epoch')}" if st.get("sft") else f"base step={st.get('step','?')}"
    print(f"[model] {tag} | {'MIMO' if cfg.get('is_mimo') else 'SISO'} | "
          f"n_layers={cfg['n_layers']} | vocab={cfg['vocab_size']}", flush=True)
    n_pieces = sp.get_piece_size()
    if n_pieces > cfg["vocab_size"]:
        # Token ids >= vocab_size would index past the embedding table.
        print(f"[uyarı] tokenizer vocab={n_pieces} > model vocab={cfg['vocab_size']} "
              f"— yanlış tokenizer olabilir.", flush=True)

    model = HybridLM(cfg, device=dev, dtype=torch.bfloat16)
    miss, unexp = model.load_state_dict(st["model"], strict=False)
    if miss or unexp:
        # load_state_dict never reports non-persistent buffers (e.g. GQAMixer.inv),
        # so every entry here is a genuinely mismatched weight.
        print(f"[uyarı] eksik={len(miss)} beklenmeyen={len(unexp)} — ağırlıklar tam oturmadı, "
              f"çıktı bozuk olabilir. örnek eksik={list(miss)[:3]} beklenmeyen={list(unexp)[:3]}",
              flush=True)
    del st   # weights are now copied into `model`; release the CPU checkpoint copy
    model.eval()
    print(f"[hazır] {'SFT (chat şablonu)' if args.chat else 'BASE (tamamlama)'} modu.\n", flush=True)

    def wrap(p):
        """Wrap ``p`` in the SFT instruction template when ``--chat`` is set."""
        return f"### Talimat:\n{p}\n\n### Yanıt:\n" if args.chat else p

    def g(p):
        """Generate a completion for user prompt ``p`` with the CLI sampling settings."""
        return generate(model, sp, wrap(p), args.max_new, args.temperature, args.top_k,
                        args.top_p, args.rep_penalty, dev, args.seed)

    if args.prompt is not None:
        print(f"PROMPT: {args.prompt}\nÇIKTI : {g(args.prompt)}")
    else:
        print("İnteraktif — prompt yaz (boş/çık = quit).")
        while True:
            try:
                p = input("\n> ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if not p or p.lower() in ("quit", "exit", "çık"):
                break
            print(g(p))


# Script entry point: run the CLI only when the file is executed directly, not on import.
if __name__ == "__main__":
    main()