""" Faz 3 — SmartCore V1 tam pretraining (fork/Triton hibrit, Colab A100). Model: Mamba-3 SISO (mamba-og Triton kernel) + her 6. katman GQA (torch SDPA, flash_attn'sız). Veri: kdirgul/smartcore-v1-data parquet shard'ları (önce yerele indirilir → resumable, hızlı). Eğitim: WSD (warmup→stable→decay), AdamW (2D-only wd), bf16 autocast, grad-accum ~0.5M token, grad-clip 1.0, z-loss; checkpoint + async HF push + cross-session --resume. ÖNKOŞUL (Colab): mamba-og fork kurulu (Faz 3a). HF_TOKEN env (private repo). Kullanım (ilk): HF_TOKEN=hf_xxx python faz3_train.py Devam (yeni Colab session): HF_TOKEN=hf_xxx python faz3_train.py --resume latest_hf """ import os, sys, time, math, glob, argparse, random, signal, threading import torch, torch.nn as nn, torch.nn.functional as F from functools import partial from concurrent.futures import ThreadPoolExecutor # ───────────────────────── model (fork hibrit) ───────────────────────── # Fork importları: Colab'da başarılı; yerelde (fork yok) None → ShardStream/wsd_lr yine de test edilebilir. try: from mamba_ssm.modules.block import Block from mamba_ssm.modules.mamba3 import Mamba3 from mamba_ssm.modules.mlp import GatedMLP from mamba_ssm.ops.triton.layer_norm import RMSNorm except ImportError: Block = Mamba3 = GatedMLP = RMSNorm = None # yerel: model kurulamaz, veri/LR test edilebilir # Ölçek presetleri (v1.5b: 350m). head_dim=64, attn_every, d_state vb. sabit/aşağıda. PRESETS = { "177m": dict(d_model=768, n_layers=20, d_intermediate=1500, head_dim=64, n_heads=12, n_kv_heads=3), "350m": dict(d_model=1024, n_layers=24, d_intermediate=2048, head_dim=64, n_heads=16, n_kv_heads=4), } def _rms(x, w, eps=1e-5): return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)) * w def _rot_half(x): a, b = x.chunk(2, -1) return torch.cat((-b, a), -1) class GQAMixer(nn.Module): """GQA attention (QK-norm + RoPE, causal) — torch SDPA, flash_attn YOK. Block'a uyumlu (x->tensor). Çıkış projeksiyonu 'out_proj' adıyla → _init_weights residual-rescale'i yakalar.""" def __init__(self, dim, n_heads=12, n_kv=3, base=10000.0, layer_idx=None, device=None, dtype=None): super().__init__() self.nh, self.nkv, self.hd = n_heads, n_kv, dim // n_heads self.rep = n_heads // n_kv fk = {"device": device, "dtype": dtype} self.q_proj = nn.Linear(dim, n_heads * self.hd, bias=False, **fk) self.k_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk) self.v_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk) self.out_proj = nn.Linear(n_heads * self.hd, dim, bias=False, **fk) self.qn = nn.Parameter(torch.ones(self.hd, **fk)) self.kn = nn.Parameter(torch.ones(self.hd, **fk)) for lin in (self.q_proj, self.k_proj, self.v_proj): nn.init.normal_(lin.weight, std=0.02) self.register_buffer( "inv", 1.0 / (base ** (torch.arange(0, self.hd, 2, device=device).float() / self.hd)), persistent=False) def _rope(self, x, T): f = torch.outer(torch.arange(T, device=x.device, dtype=torch.float32), self.inv) e = torch.cat((f, f), -1) return (x * e.cos()[None, None] + _rot_half(x) * e.sin()[None, None]).to(x.dtype) def forward(self, x, **kw): # kw: Block'tan gelen inference_params vb. yok sayılır (eğitim) B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2) k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2) v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2) q = _rms(q.float(), self.qn.float()).to(x.dtype) k = _rms(k.float(), self.kn.float()).to(x.dtype) q, k = self._rope(q, T), self._rope(k, T) k = k.repeat_interleave(self.rep, 1) v = v.repeat_interleave(self.rep, 1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.out_proj(y.transpose(1, 2).contiguous().view(B, T, -1)) def _init_weights(m, n_layer): if isinstance(m, nn.Linear) and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) for name, p in m.named_parameters(): if name in ("out_proj.weight", "fc2.weight"): # residual rescale (GPT-2/Mamba kuralı) nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(2 * n_layer) class HybridLM(nn.Module): def __init__(self, cfg, device=None, dtype=None): super().__init__() self.cfg = cfg self.vocab = cfg["vocab_size"] self.scaled_embed = cfg.get("scaled_embed", False) self.z_loss = cfg.get("z_loss", 1e-4) d = cfg["d_model"] self.embedding = nn.Embedding(self.vocab, d, device=device, dtype=dtype) self.layers = nn.ModuleList() self.attn_idx = [] for i in range(cfg["n_layers"]): is_attn = ((i + 1) % cfg["attn_every"] == 0) and i != 0 and i != cfg["n_layers"] - 1 fk = {"device": device, "dtype": dtype} if is_attn: mixer_cls = partial(GQAMixer, n_heads=cfg["n_heads"], n_kv=cfg["n_kv_heads"], layer_idx=i, **fk) self.attn_idx.append(i) else: ssm = dict(d_state=cfg["d_state"], expand=cfg["expand"], headdim=cfg["head_dim"], ngroups=cfg["ngroups"], rope_fraction=cfg["rope_fraction"], is_outproj_norm=False, is_mimo=cfg["is_mimo"], mimo_rank=cfg["mimo_rank"], chunk_size=cfg["chunk_size"]) mixer_cls = partial(Mamba3, layer_idx=i, **ssm, **fk) blk = Block(d, mixer_cls, partial(GatedMLP, hidden_features=cfg["d_intermediate"], out_features=d, **fk), norm_cls=partial(RMSNorm, eps=1e-5, **fk), fused_add_norm=True, residual_in_fp32=True) blk.layer_idx = i self.layers.append(blk) self.norm_f = RMSNorm(d, eps=1e-5, device=device, dtype=dtype) self.lm_head = nn.Linear(d, self.vocab, bias=False, device=device, dtype=dtype) self.apply(partial(_init_weights, n_layer=cfg["n_layers"])) self.lm_head.weight = self.embedding.weight # tied (init sonrası) def forward(self, ids, labels=None): h = self.embedding(ids) if self.scaled_embed: h = h * (self.cfg["d_model"] ** 0.5) res = None for l in self.layers: h, res = l(h, res) h = self.norm_f((h + res) if res is not None else h) logits = self.lm_head(h.to(self.lm_head.weight.dtype)) loss = None if labels is not None: sl = logits[:, :-1].reshape(-1, self.vocab).float() tl = labels[:, 1:].reshape(-1) loss = F.cross_entropy(sl, tl, ignore_index=-100) if self.z_loss > 0: z = torch.logsumexp(sl, dim=-1) loss = loss + self.z_loss * (z ** 2).mean() return logits, loss def n_params(m): seen, t = set(), 0 for p in m.parameters(): if id(p) in seen: continue seen.add(id(p)); t += p.numel() return t # ───────────────────────── veri (resumable shard) ───────────────────────── import pyarrow.parquet as pq MIXES = { "177m": {"en_fineweb_edu": 0.55, "tr_fineweb2_hq": 0.22, "code_codeparrot": 0.13, "math_openwebmath": 0.10}, "350m": {"en_fineweb_edu": 0.47, "tr_tc100b": 0.30, "code_codeparrot": 0.13, "math_openwebmath": 0.10}, # v1.5b: TR↑ TC-100B } def ensure_local_data(data, token): """HF repo ise yerele indir (resumable + hızlı), zaten yerel dizinse aynen döndür.""" if os.path.isdir(data): return data from huggingface_hub import snapshot_download print(f"[veri] {data} indiriliyor (snapshot)...", flush=True) p = snapshot_download(data, repo_type="dataset", token=token, allow_patterns=["*/shard_*.parquet"]) print(f"[veri] indirildi: {p}", flush=True) return p class ShardStream: """Yerel parquet'ten oranlı, DETERMİNİSTİK ve RESUMABLE okuma (cursor + RNG kaydedilir).""" def __init__(self, root, seq_len, mix, seed=42): self.names = list(mix); self.w = [mix[n] for n in self.names] self.seq_len = seq_len self.files = {n: sorted(glob.glob(os.path.join(root, n, "shard_*.parquet"))) for n in self.names} for n in self.names: assert self.files[n], f"shard yok: {root}/{n}" self.cursor = {n: [0, 0] for n in self.names} # [shard_idx, row_idx] self.cache = {} self.rng = random.Random(seed) def _rows(self, n): si = self.cursor[n][0] % len(self.files[n]) if self.cache.get(n, (None,))[0] != si: # arrow kolonu olduğu gibi tut (to_pylist YOK → ~1GB bellek/şard spike'ı önlenir) col = pq.read_table(self.files[n][si], columns=["input_ids"]).column("input_ids") self.cache[n] = (si, col) return self.cache[n][1] def _next(self, n): rows = self._rows(n) if self.cursor[n][1] >= len(rows): self.cursor[n][0] += 1; self.cursor[n][1] = 0; rows = self._rows(n) ri = self.cursor[n][1]; self.cursor[n][1] = ri + 1 return rows[ri].as_py()[:self.seq_len] # tek satırı listeye çevir def batch(self, bsz, device): rows = [self._next(self.rng.choices(self.names, weights=self.w, k=1)[0]) for _ in range(bsz)] return torch.tensor(rows, dtype=torch.long, device=device) def state(self): # derin kopya: cursor mutable; aksi halde sonraki batch() kaydı bozar return {"cursor": {k: list(v) for k, v in self.cursor.items()}, "rng": self.rng.getstate()} def load_state(self, s): self.cursor = {k: list(v) for k, v in s["cursor"].items()} self.rng.setstate(s["rng"]); self.cache = {} # ───────────────────────── WSD LR ───────────────────────── def wsd_lr(step, total, peak, floor, warmup, decay_frac=0.25): if step < warmup: return peak * (step + 1) / warmup dec = int(total * (1 - decay_frac)) if step < dec: return peak return peak - (peak - floor) * (step - dec) / max(1, total - dec) # ───────────────────────── checkpoint + async HF push ───────────────────────── class Ckpt: def __init__(self, local_dir, repo_id, token, keep=3): self.dir = local_dir; self.repo = repo_id; self.keep = keep os.makedirs(local_dir, exist_ok=True) self.api = None if repo_id and token: from huggingface_hub import HfApi self.api = HfApi(token=token) self.ex = ThreadPoolExecutor(max_workers=1); self.lock = threading.Lock() def save(self, step, model, opts, stream, extra): d = os.path.join(self.dir, f"step_{step:06d}"); os.makedirs(d, exist_ok=True) torch.save({"model": model.state_dict(), "opt": [o.state_dict() for o in opts], "step": step, "stream": stream.state(), "torch_rng": torch.get_rng_state(), "cuda_rng": torch.cuda.get_rng_state_all(), **extra}, os.path.join(d, "ckpt.pt")) self._rotate() if self.api: self.ex.submit(self._push, d, step) print(f"[ckpt] kaydedildi step {step} -> {d}", flush=True) def _push(self, d, step): try: with self.lock: self.api.upload_folder(folder_path=d, repo_id=self.repo, repo_type="model", path_in_repo=f"checkpoints/step_{step:06d}", commit_message=f"ckpt step {step}") print(f"[ckpt] HF push OK step {step}", flush=True) except Exception as e: print(f"[ckpt] HF push HATA step {step}: {repr(e)[:160]}", flush=True) def _rotate(self): ds = sorted(glob.glob(os.path.join(self.dir, "step_*"))) for old in ds[:-self.keep]: for f in glob.glob(os.path.join(old, "*")): os.remove(f) os.rmdir(old) def latest_local(self): ds = sorted(glob.glob(os.path.join(self.dir, "step_*", "ckpt.pt"))) return ds[-1] if ds else None def latest_hf(self): if not self.api: return None from huggingface_hub import hf_hub_download files = [f for f in self.api.list_repo_files(self.repo, repo_type="model") if f.startswith("checkpoints/step_") and f.endswith("ckpt.pt")] if not files: return None latest = max(files) # step_NNNNNN sıralı return hf_hub_download(self.repo, latest, repo_type="model") # ───────────────────────── train ───────────────────────── # ───────────────────────── Muon optimizer (v1.5: 2D-Linear ağırlıkları için) ───────────────────────── def _ns5(G, steps=5): """Newton-Schulz orthogonalizasyon (Muon çekirdeği) — G'yi yarı-ortogonale yaklaştır.""" a, b, c = 3.4445, -4.7750, 2.0315 X = G.bfloat16() t = G.size(-2) > G.size(-1) if t: X = X.mT X = X / (X.norm() + 1e-7) for _ in range(steps): A = X @ X.mT B = b * A + c * (A @ A) X = a * X + B @ X if t: X = X.mT return X.to(G.dtype) class Muon(torch.optim.Optimizer): """Momentum + Newton-Schulz ortogonalize güncelleme (Keller Jordan). 2D matris ağırlıkları için.""" def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, weight_decay=0.0): super().__init__(params, dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, weight_decay=weight_decay)) @torch.no_grad() def step(self): for grp in self.param_groups: lr, mom, nest = grp["lr"], grp["momentum"], grp["nesterov"] ns, wd = grp["ns_steps"], grp["weight_decay"] for p in grp["params"]: if p.grad is None: continue st = self.state[p] if "buf" not in st: st["buf"] = torch.zeros_like(p.grad) buf = st["buf"]; buf.lerp_(p.grad, 1 - mom) u = p.grad.lerp_(buf, mom) if nest else buf u = _ns5(u, ns) * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5) if wd: p.mul_(1 - lr * wd) p.add_(u, alpha=-lr) def build_optimizers(model, args): """args.muon ise [Muon(2D-Linear), AdamW(embed+3D / 1D)]; değilse [AdamW] (orijinal). Dönüş (opts, base_lrs) — base_lr WSD çarpanıyla (0→1→0.1) ölçeklenir.""" if args.muon: emb = {id(m.weight) for m in model.modules() if isinstance(m, nn.Embedding)} muon_p, adam_wd, adam_nod = [], [], [] for p in model.parameters(): if p.ndim == 2 and id(p) not in emb: muon_p.append(p) # gizli Linear ağırlıkları → Muon elif p.ndim >= 2: adam_wd.append(p) # embedding (2D, tied lm_head) + 3D bias (B/C_bias) else: adam_nod.append(p) # 1D (norm, dt_bias, D, A_log, qn/kn…) o_m = Muon(muon_p, lr=args.muon_lr, momentum=0.95, ns_steps=5, weight_decay=0.0) o_a = torch.optim.AdamW([{"params": adam_wd, "weight_decay": 0.1}, {"params": adam_nod, "weight_decay": 0.0}], lr=args.peak_lr, betas=(0.9, 0.95), eps=1e-8, fused=True) mp = sum(p.numel() for p in muon_p); ap_ = sum(p.numel() for p in adam_wd + adam_nod) print(f"[opt] MUON {mp/1e6:.1f}M (2D-Linear) + AdamW {ap_/1e6:.1f}M (embed+norm) | " f"muon_lr {args.muon_lr} | peak_lr {args.peak_lr}", flush=True) return [o_m, o_a], [args.muon_lr, args.peak_lr] decay = [p for p in model.parameters() if p.ndim >= 2] nod = [p for p in model.parameters() if p.ndim < 2] o_a = torch.optim.AdamW([{"params": decay, "weight_decay": 0.1}, {"params": nod, "weight_decay": 0.0}], lr=args.peak_lr, betas=(0.9, 0.95), eps=1e-8, fused=True) print(f"[opt] AdamW (tek) | peak_lr {args.peak_lr}", flush=True) return [o_a], [args.peak_lr] def main(): ap = argparse.ArgumentParser() ap.add_argument("--data", default="kdirgul/smartcore-v1-data") ap.add_argument("--ckpt_repo", default="kdirgul/smartcore-v1") ap.add_argument("--ckpt_dir", default="/content/ckpt") ap.add_argument("--resume", default=None, help="latest_local | latest_hf | ") ap.add_argument("--preset", default="177m", choices=["177m", "350m"]) ap.add_argument("--n_layers", type=int, default=None) # None → preset; override için ver ap.add_argument("--d_model", type=int, default=None) ap.add_argument("--d_intermediate", type=int, default=None) ap.add_argument("--n_heads", type=int, default=None) ap.add_argument("--n_kv_heads", type=int, default=None) ap.add_argument("--attn_every", type=int, default=6) ap.add_argument("--seq_len", type=int, default=2048) ap.add_argument("--micro_batch", type=int, default=4) ap.add_argument("--grad_accum", type=int, default=64) # 4*64*2048 = 524288 tok/step ap.add_argument("--total_tokens", type=float, default=12e9) ap.add_argument("--peak_lr", type=float, default=5e-4) ap.add_argument("--muon", action="store_true", help="v1.5: 2D-Linear ağırlıkları Muon (embed/norm AdamW)") ap.add_argument("--muon_lr", type=float, default=0.02) ap.add_argument("--warmup", type=int, default=600) ap.add_argument("--save_every", type=int, default=500) # opt-step ap.add_argument("--log_every", type=int, default=10) ap.add_argument("--seed", type=int, default=42) args = ap.parse_args() dev = torch.device("cuda") torch.manual_seed(args.seed); random.seed(args.seed) torch.set_float32_matmul_precision("high") token = os.environ.get("HF_TOKEN") P = dict(PRESETS[args.preset]) for k in ("n_layers", "d_model", "d_intermediate", "n_heads", "n_kv_heads"): if getattr(args, k) is not None: P[k] = getattr(args, k) cfg = dict(vocab_size=48000, d_model=P["d_model"], n_layers=P["n_layers"], d_state=128, expand=2, head_dim=P["head_dim"], ngroups=1, d_intermediate=P["d_intermediate"], attn_every=args.attn_every, n_heads=P["n_heads"], n_kv_heads=P["n_kv_heads"], rope_fraction=0.5, is_mimo=False, mimo_rank=1, chunk_size=128, scaled_embed=False, z_loss=1e-4) print(f"[cfg] preset={args.preset} | d_model={cfg['d_model']} n_layers={cfg['n_layers']} " f"d_int={cfg['d_intermediate']} {cfg['n_heads']}/{cfg['n_kv_heads']} GQA attn_every={cfg['attn_every']}", flush=True) model = HybridLM(cfg, device=dev, dtype=torch.bfloat16) print(f"[model] {n_params(model)/1e6:.1f}M | {cfg['n_layers']-len(model.attn_idx)} Mamba + " f"{len(model.attn_idx)} GQA (attn@{model.attn_idx})", flush=True) opts, base_lrs = build_optimizers(model, args) batch_tok = args.micro_batch * args.grad_accum * args.seq_len total_steps = int(args.total_tokens / batch_tok) print(f"[plan] {args.total_tokens/1e9:.0f}B token | {batch_tok} tok/step | {total_steps} step | " f"warmup {args.warmup} | peak {args.peak_lr}", flush=True) root = ensure_local_data(args.data, token) mix = MIXES.get(args.preset, MIXES["177m"]) # preset'e göre karışım (350m: TR↑ TC-100B) print(f"[mix] {args.preset}: " + " ".join(f"{k}={v:.0%}" for k, v in mix.items()), flush=True) stream = ShardStream(root, args.seq_len, mix, seed=args.seed) ckpt = Ckpt(args.ckpt_dir, args.ckpt_repo, token) start_step = 0 if args.resume: path = (ckpt.latest_local() if args.resume == "latest_local" else ckpt.latest_hf() if args.resume == "latest_hf" else args.resume) if path and os.path.exists(path): st = torch.load(path, map_location="cpu") model.load_state_dict(st["model"]) osd = st["opt"] if isinstance(st["opt"], list) else [st["opt"]] for o, s in zip(opts, osd): o.load_state_dict(s) stream.load_state(st["stream"]); start_step = st["step"] + 1 torch.set_rng_state(st["torch_rng"]); torch.cuda.set_rng_state_all(st["cuda_rng"]) print(f"[resume] {path} -> step {start_step}", flush=True) else: print(f"[resume] checkpoint bulunamadı ({args.resume}) — sıfırdan başlıyor", flush=True) # SIGTERM/SIGINT -> acil kayıt (Colab disconnect güvenlik ağı) cur = {"step": start_step} def emergency(signum, frame): ckpt.save(cur["step"], model, opts, stream, {"cfg": cfg}) sys.exit(0) signal.signal(signal.SIGTERM, emergency); signal.signal(signal.SIGINT, emergency) model.train() t0 = time.perf_counter(); seen = 0 for step in range(start_step, total_steps): cur["step"] = step frac = wsd_lr(step, total_steps, 1.0, 0.1, args.warmup) # 0→1→0.1 çarpan (her opt kendi base_lr'iyle) for o, b in zip(opts, base_lrs): for g in o.param_groups: g["lr"] = b * frac for o in opts: o.zero_grad(set_to_none=True) loss_acc = 0.0 for _ in range(args.grad_accum): batch = stream.batch(args.micro_batch, dev) with torch.autocast(device_type="cuda", dtype=torch.bfloat16): _, loss = model(batch, labels=batch) (loss / args.grad_accum).backward() loss_acc += loss.item() / args.grad_accum seen += batch.numel() gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) for o in opts: o.step() if step % args.log_every == 0: tok_s = seen / (time.perf_counter() - t0) print(f"step {step:6d}/{total_steps} | loss {loss_acc:.4f} | gnorm {gn:5.2f} | " f"lr {base_lrs[-1]*frac:.2e} | {tok_s/1e3:.1f}k tok/s | {seen/1e9:.3f}B tok", flush=True) if step > start_step and step % args.save_every == 0: ckpt.save(step, model, opts, stream, {"cfg": cfg}) ckpt.save(total_steps - 1, model, opts, stream, {"cfg": cfg, "final": True}) print("[bitti] pretraining tamamlandı.", flush=True) if __name__ == "__main__": main()