# Source: smartcore-v1 / code / kod / faz3_train.py (captured from the Hugging Face file viewer)
# Author: kdirgul
# Last commit de5424b (verified): v1.5b: MIXES preset-aware (350m EN47/TR30 tc100b/code13/math10) + ShardStream mix param
# File size: 23.3 kB
"""
Faz 3 — SmartCore V1 tam pretraining (fork/Triton hibrit, Colab A100).
Model: Mamba-3 SISO (mamba-og Triton kernel) + her 6. katman GQA (torch SDPA, flash_attn'sız).
Veri: kdirgul/smartcore-v1-data parquet shard'ları (önce yerele indirilir → resumable, hızlı).
Eğitim: WSD (warmup→stable→decay), AdamW (2D-only wd), bf16 autocast, grad-accum ~0.5M token,
grad-clip 1.0, z-loss; checkpoint + async HF push + cross-session --resume.
ÖNKOŞUL (Colab): mamba-og fork kurulu (Faz 3a). HF_TOKEN env (private repo).
Kullanım (ilk):
HF_TOKEN=hf_xxx python faz3_train.py
Devam (yeni Colab session):
HF_TOKEN=hf_xxx python faz3_train.py --resume latest_hf
"""
import os, sys, time, math, glob, argparse, random, signal, threading
import torch, torch.nn as nn, torch.nn.functional as F
from functools import partial
from concurrent.futures import ThreadPoolExecutor
# ───────────────────────── model (fork hibrit) ─────────────────────────
# Fork importları: Colab'da başarılı; yerelde (fork yok) None → ShardStream/wsd_lr yine de test edilebilir.
# Optional fork imports: they resolve on Colab where the mamba-og fork is
# installed. Without the fork every name falls back to None, so the data
# stream and LR schedule can still be exercised locally.
try:
    from mamba_ssm.modules.block import Block
    from mamba_ssm.modules.mamba3 import Mamba3
    from mamba_ssm.modules.mlp import GatedMLP
    from mamba_ssm.ops.triton.layer_norm import RMSNorm
except ImportError:
    # Local run: the model cannot be built, data / LR code stays testable.
    Block = None
    Mamba3 = None
    GatedMLP = None
    RMSNorm = None
# Scale presets (v1.5b adds 350m). head_dim is 64 in both; attn_every, d_state
# and the remaining hyper-parameters are fixed further down in this file.
PRESETS = {
    "177m": {
        "d_model": 768,
        "n_layers": 20,
        "d_intermediate": 1500,
        "head_dim": 64,
        "n_heads": 12,
        "n_kv_heads": 3,
    },
    "350m": {
        "d_model": 1024,
        "n_layers": 24,
        "d_intermediate": 2048,
        "head_dim": 64,
        "n_heads": 16,
        "n_kv_heads": 4,
    },
}
def _rms(x, w, eps=1e-5):
    """RMS-normalize ``x`` over its last dimension and scale the result by ``w``.

    Args:
        x: Input tensor; the mean of squares is taken over the last dimension.
        w: Scale, broadcast against ``x``.
        eps: Added to the mean of squares before the reciprocal square root.

    Returns:
        ``x * rsqrt(mean(x**2) + eps) * w``.
    """
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    return (x * inv_rms) * w
def _rot_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``(-b, a)``.

    Used by the RoPE application in ``GQAMixer._rope``.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
class GQAMixer(nn.Module):
    """Causal grouped-query attention (QK-norm + RoPE) built on torch SDPA.

    Only ``F.scaled_dot_product_attention`` is used, so flash_attn is not
    required. ``forward`` takes a ``(B, T, dim)`` tensor and returns a tensor,
    which (per the original author's note) is what the fork's ``Block``
    expects from a mixer. The output projection is deliberately named
    ``out_proj`` so that ``_init_weights`` applies its residual rescale to it.

    Args:
        dim: Model width (size of the last input dimension).
        n_heads: Number of query heads; the head size is ``dim // n_heads``.
        n_kv: Number of key/value heads; must divide ``n_heads``.
        base: RoPE frequency base.
        layer_idx: Accepted for constructor compatibility; not used here.
        device: Forwarded to the parameter/buffer constructors.
        dtype: Forwarded to the parameter constructors.

    Raises:
        ValueError: If ``n_heads`` is not a positive multiple of ``n_kv`` or
            the resulting head size is not a positive even number (RoPE
            rotates the two halves of each head).
    """

    def __init__(self, dim, n_heads=12, n_kv=3, base=10000.0, layer_idx=None, device=None, dtype=None):
        super().__init__()
        # Fail at construction time instead of with an opaque shape error
        # inside forward().
        if n_heads <= 0 or n_kv <= 0 or n_heads % n_kv != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be a positive multiple of n_kv ({n_kv})")
        head_dim = dim // n_heads
        if head_dim <= 0 or head_dim % 2 != 0:
            raise ValueError(
                f"head size dim // n_heads = {head_dim} must be a positive even number")

        self.nh, self.nkv, self.hd = n_heads, n_kv, head_dim
        self.rep = n_heads // n_kv  # query heads served by each K/V head
        fk = {"device": device, "dtype": dtype}
        self.q_proj = nn.Linear(dim, n_heads * self.hd, bias=False, **fk)
        self.k_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk)
        self.v_proj = nn.Linear(dim, n_kv * self.hd, bias=False, **fk)
        self.out_proj = nn.Linear(n_heads * self.hd, dim, bias=False, **fk)
        # Learnable per-channel scales for the QK RMS norm.
        self.qn = nn.Parameter(torch.ones(self.hd, **fk))
        self.kn = nn.Parameter(torch.ones(self.hd, **fk))
        for lin in (self.q_proj, self.k_proj, self.v_proj):
            nn.init.normal_(lin.weight, std=0.02)
        # Inverse RoPE frequencies, kept in float32 and excluded from the state dict.
        self.register_buffer(
            "inv", 1.0 / (base ** (torch.arange(0, self.hd, 2, device=device).float() / self.hd)),
            persistent=False)

    def _rope_tables(self, T, device):
        """Return the ``(cos, sin)`` tables for ``T`` positions, shaped ``(1, 1, T, hd)``."""
        f = torch.outer(torch.arange(T, device=device, dtype=torch.float32), self.inv)
        e = torch.cat((f, f), -1)
        return e.cos()[None, None], e.sin()[None, None]

    @staticmethod
    def _apply_rope(x, cos, sin):
        """Rotate ``x`` with precomputed tables and cast back to ``x``'s dtype."""
        return (x * cos + _rot_half(x) * sin).to(x.dtype)

    def _rope(self, x, T):
        """Apply RoPE to ``x`` (``(B, heads, T, hd)``) for positions ``0..T-1``."""
        cos, sin = self._rope_tables(T, x.device)
        return self._apply_rope(x, cos, sin)

    def forward(self, x, **kw):
        """Run causal attention over ``x``.

        Extra keyword arguments (e.g. ``inference_params`` passed by ``Block``)
        are ignored; this module is written for training only.
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        # QK-norm in float32, then back to the activation dtype.
        q = _rms(q.float(), self.qn.float()).to(x.dtype)
        k = _rms(k.float(), self.kn.float()).to(x.dtype)
        # Build the rotation tables once and reuse them for both q and k.
        cos, sin = self._rope_tables(T, x.device)
        q = self._apply_rope(q, cos, sin)
        k = self._apply_rope(k, cos, sin)
        # Expand K/V heads so every query head has a matching K/V head.
        k = k.repeat_interleave(self.rep, 1)
        v = v.repeat_interleave(self.rep, 1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(y.transpose(1, 2).contiguous().view(B, T, -1))
def _init_weights(m, n_layer):
    """Per-module initializer intended for ``nn.Module.apply``.

    * ``nn.Linear`` with a bias: the bias is zeroed.
    * ``nn.Embedding``: weights are drawn from N(0, 0.02).
    * Any module that directly owns a child named ``out_proj`` or ``fc2``:
      that child's weight is re-drawn with Kaiming-uniform and divided by
      ``sqrt(2 * n_layer)`` (residual rescale, GPT-2 / Mamba rule).

    Args:
        m: The module currently visited by ``apply``.
        n_layer: Number of layers used in the rescale factor.
    """
    if isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, std=0.02)
    elif isinstance(m, nn.Linear) and m.bias is not None:
        nn.init.zeros_(m.bias)

    # Names are relative to ``m``, so only the direct parent of the
    # projection matches and each projection is rescaled exactly once.
    residual_names = ("out_proj.weight", "fc2.weight")
    for param_name, param in m.named_parameters():
        if param_name not in residual_names:
            continue
        nn.init.kaiming_uniform_(param, a=math.sqrt(5))
        with torch.no_grad():
            param /= math.sqrt(2 * n_layer)
class HybridLM(nn.Module):
    """Hybrid language model: a stack of ``Block`` layers whose mixer is either
    ``Mamba3`` or ``GQAMixer``, with tied input/output embeddings.

    ``cfg`` is a plain dict; the keys read here are ``vocab_size``, ``d_model``,
    ``n_layers``, ``attn_every``, ``n_heads``, ``n_kv_heads``, ``d_state``,
    ``expand``, ``head_dim``, ``ngroups``, ``rope_fraction``, ``is_mimo``,
    ``mimo_rank``, ``chunk_size``, ``d_intermediate`` and the optional
    ``scaled_embed`` (default False) and ``z_loss`` (default 1e-4).
    """

    def __init__(self, cfg, device=None, dtype=None):
        super().__init__()
        self.cfg = cfg
        self.vocab = cfg["vocab_size"]
        self.scaled_embed = cfg.get("scaled_embed", False)
        self.z_loss = cfg.get("z_loss", 1e-4)
        width = cfg["d_model"]
        factory = {"device": device, "dtype": dtype}

        self.embedding = nn.Embedding(self.vocab, width, device=device, dtype=dtype)
        self.layers = nn.ModuleList()
        self.attn_idx = []  # indices of the layers that use attention
        for idx in range(cfg["n_layers"]):
            block = Block(
                width,
                self._mixer_factory(idx, factory),
                partial(GatedMLP, hidden_features=cfg["d_intermediate"], out_features=width, **factory),
                norm_cls=partial(RMSNorm, eps=1e-5, **factory),
                fused_add_norm=True,
                residual_in_fp32=True,
            )
            block.layer_idx = idx
            self.layers.append(block)

        self.norm_f = RMSNorm(width, eps=1e-5, device=device, dtype=dtype)
        self.lm_head = nn.Linear(width, self.vocab, bias=False, device=device, dtype=dtype)
        self.apply(partial(_init_weights, n_layer=cfg["n_layers"]))
        # Tie the output head to the embedding only after initialization.
        self.lm_head.weight = self.embedding.weight

    def _mixer_factory(self, idx, factory):
        """Return the mixer constructor for layer ``idx`` and record attention layers.

        Every ``attn_every``-th layer (counting from 1) uses attention, except
        that the first and the last layer never do.
        """
        cfg = self.cfg
        on_schedule = (idx + 1) % cfg["attn_every"] == 0
        is_edge = idx == 0 or idx == cfg["n_layers"] - 1
        if on_schedule and not is_edge:
            self.attn_idx.append(idx)
            return partial(GQAMixer, n_heads=cfg["n_heads"], n_kv=cfg["n_kv_heads"],
                           layer_idx=idx, **factory)
        ssm_kwargs = dict(
            d_state=cfg["d_state"], expand=cfg["expand"], headdim=cfg["head_dim"],
            ngroups=cfg["ngroups"], rope_fraction=cfg["rope_fraction"],
            is_outproj_norm=False, is_mimo=cfg["is_mimo"], mimo_rank=cfg["mimo_rank"],
            chunk_size=cfg["chunk_size"])
        return partial(Mamba3, layer_idx=idx, **ssm_kwargs, **factory)

    def forward(self, ids, labels=None):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            ids: Token ids fed to the embedding.
            labels: Optional targets with the same layout as ``ids``; position
                ``t`` of the logits is scored against ``labels[:, t + 1]``,
                and entries equal to -100 are ignored by the cross-entropy.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` without labels.
        """
        hidden = self.embedding(ids)
        if self.scaled_embed:
            hidden = hidden * (self.cfg["d_model"] ** 0.5)

        residual = None
        for layer in self.layers:
            hidden, residual = layer(hidden, residual)
        # Fold the residual carried by the last block back in before the final norm.
        if residual is not None:
            hidden = hidden + residual
        hidden = self.norm_f(hidden)
        logits = self.lm_head(hidden.to(self.lm_head.weight.dtype))
        if labels is None:
            return logits, None

        shifted_logits = logits[:, :-1].reshape(-1, self.vocab).float()
        shifted_labels = labels[:, 1:].reshape(-1)
        loss = F.cross_entropy(shifted_logits, shifted_labels, ignore_index=-100)
        if self.z_loss > 0:
            # z-loss: penalize the squared log-partition function.
            # NOTE(review): this mean also covers positions whose label is -100.
            log_z = torch.logsumexp(shifted_logits, dim=-1)
            loss = loss + self.z_loss * (log_z ** 2).mean()
        return logits, loss
def n_params(m):
    """Return the number of parameter elements in ``m``.

    A tensor reachable under several names (e.g. tied embedding / lm_head) is
    counted once, keyed by object identity.
    """
    sizes_by_identity = {id(param): param.numel() for param in m.parameters()}
    return sum(sizes_by_identity.values())
# ───────────────────────── veri (resumable shard) ─────────────────────────
import pyarrow.parquet as pq
# Sampling weights per dataset sub-directory, selected by model preset.
MIXES = {
    "177m": {
        "en_fineweb_edu": 0.55,
        "tr_fineweb2_hq": 0.22,
        "code_codeparrot": 0.13,
        "math_openwebmath": 0.10,
    },
    # v1.5b: larger Turkish share, sourced from TC-100B.
    "350m": {
        "en_fineweb_edu": 0.47,
        "tr_tc100b": 0.30,
        "code_codeparrot": 0.13,
        "math_openwebmath": 0.10,
    },
}
def ensure_local_data(data, token):
    """Return a local directory that holds the dataset.

    Args:
        data: Either an existing local directory (returned unchanged) or a
            Hugging Face dataset repo id.
        token: HF access token passed to the download; only used for repos.

    Returns:
        ``data`` itself when it is a directory, otherwise the snapshot path
        returned by ``snapshot_download`` (restricted to ``*/shard_*.parquet``).
    """
    already_local = os.path.isdir(data)
    if already_local:
        return data

    # Imported lazily so purely local runs do not need huggingface_hub.
    from huggingface_hub import snapshot_download

    print(f"[veri] {data} indiriliyor (snapshot)...", flush=True)
    snapshot_dir = snapshot_download(
        data,
        repo_type="dataset",
        token=token,
        allow_patterns=["*/shard_*.parquet"],
    )
    print(f"[veri] indirildi: {snapshot_dir}", flush=True)
    return snapshot_dir
class ShardStream:
    """Weighted, deterministic and resumable reader over local parquet shards.

    Each key of ``mix`` names a sub-directory of ``root`` containing
    ``shard_*.parquet`` files with an ``input_ids`` column. Every sample is
    drawn from one source, chosen with the ``mix`` weights by a private
    ``random.Random``. Both the per-source cursor and the RNG state are
    exposed through ``state()`` / ``load_state()`` so a run can be resumed.

    Args:
        root: Directory containing one sub-directory per source.
        seq_len: Each row is truncated to this many tokens.
        mix: Mapping ``source name -> sampling weight``.
        seed: Seed of the source-selection RNG.

    Raises:
        FileNotFoundError: If a source has no ``shard_*.parquet`` file.
    """

    def __init__(self, root, seq_len, mix, seed=42):
        self.names = list(mix)
        self.w = [mix[n] for n in self.names]
        self.seq_len = seq_len
        self.files = {n: sorted(glob.glob(os.path.join(root, n, "shard_*.parquet"))) for n in self.names}
        for n in self.names:
            # Explicit raise instead of assert: asserts are stripped under -O.
            if not self.files[n]:
                raise FileNotFoundError(f"shard yok: {root}/{n}")
        self.cursor = {n: [0, 0] for n in self.names}  # [shard_idx, row_idx]
        self.cache = {}  # name -> (shard position, arrow column) of the open shard
        self.rng = random.Random(seed)

    def _rows(self, n):
        """Return the ``input_ids`` column of the shard the cursor of ``n`` points at."""
        si = self.cursor[n][0] % len(self.files[n])  # shard index wraps around
        if self.cache.get(n, (None,))[0] != si:
            # Keep the arrow column as-is (no to_pylist), which avoids a
            # ~1 GB per-shard memory spike.
            col = pq.read_table(self.files[n][si], columns=["input_ids"]).column("input_ids")
            self.cache[n] = (si, col)
        return self.cache[n][1]

    def _next(self, n):
        """Return the next row of source ``n`` truncated to ``seq_len`` tokens.

        Raises:
            RuntimeError: If every shard of ``n`` is empty.
        """
        rows = self._rows(n)
        skipped = 0
        # Advance past the exhausted shard and past any empty shards after it.
        while self.cursor[n][1] >= len(rows):
            if skipped >= len(self.files[n]):
                raise RuntimeError(f"all shards are empty for source {n!r}")
            self.cursor[n][0] += 1
            self.cursor[n][1] = 0
            rows = self._rows(n)
            skipped += 1
        ri = self.cursor[n][1]
        self.cursor[n][1] = ri + 1
        return rows[ri].as_py()[:self.seq_len]  # convert only this row to a list

    def batch(self, bsz, device):
        """Return a ``(bsz, seq_len)`` long tensor of token ids on ``device``.

        Assumes every stored row has at least ``seq_len`` tokens; shorter rows
        would make the batch ragged — TODO confirm against the data pipeline.
        """
        rows = [self._next(self.rng.choices(self.names, weights=self.w, k=1)[0]) for _ in range(bsz)]
        return torch.tensor(rows, dtype=torch.long, device=device)

    def state(self):
        """Return a snapshot (cursor + RNG state) suitable for checkpointing."""
        # Deep copy: the cursor lists are mutated by later batch() calls,
        # which would otherwise corrupt an already-taken snapshot.
        return {"cursor": {k: list(v) for k, v in self.cursor.items()}, "rng": self.rng.getstate()}

    def load_state(self, s):
        """Restore a snapshot produced by ``state()`` and drop the shard cache."""
        self.cursor = {k: list(v) for k, v in s["cursor"].items()}
        self.rng.setstate(s["rng"])
        self.cache = {}
# ───────────────────────── WSD LR ─────────────────────────
def wsd_lr(step, total, peak, floor, warmup, decay_frac=0.25):
    """Warmup-Stable-Decay learning-rate schedule.

    Args:
        step: Zero-based optimizer step.
        total: Total number of optimizer steps in the run.
        peak: Value held during the stable phase.
        floor: Value reached at the end of the linear decay.
        warmup: Number of linear warmup steps (``peak * (step + 1) / warmup``).
        decay_frac: Fraction of ``total`` spent in the final linear decay.

    Returns:
        The schedule value for ``step``. Steps at or beyond the end of the
        decay window return ``floor`` instead of extrapolating below it.
    """
    if step < warmup:
        return peak * (step + 1) / warmup
    decay_start = int(total * (1 - decay_frac))
    if step < decay_start:
        return peak
    span = max(1, total - decay_start)
    # Clamp the progress so overshooting ``total`` can never push the value
    # below ``floor`` (or negative); in-range steps are unaffected.
    into_decay = min(step - decay_start, span)
    return peak - (peak - floor) * into_decay / span
# ───────────────────────── checkpoint + async HF push ─────────────────────────
class Ckpt:
    """Local checkpoint writer with rotation and optional asynchronous HF upload.

    Checkpoints are written to ``<local_dir>/step_NNNNNN/ckpt.pt``. When both
    ``repo_id`` and ``token`` are given, every saved directory is also
    uploaded to ``checkpoints/step_NNNNNN`` of that model repo from a
    single-worker background thread.

    Args:
        local_dir: Directory for local checkpoints (created if missing).
        repo_id: HF model repo to push to, or a falsy value to disable pushes.
        token: HF access token, or a falsy value to disable pushes.
        keep: Number of most recent local checkpoint directories to keep;
            ``keep <= 0`` disables rotation.
    """

    def __init__(self, local_dir, repo_id, token, keep=3):
        self.dir = local_dir
        self.repo = repo_id
        self.keep = keep
        os.makedirs(local_dir, exist_ok=True)
        self.api = None
        if repo_id and token:
            from huggingface_hub import HfApi
            self.api = HfApi(token=token)
        self.ex = ThreadPoolExecutor(max_workers=1)
        self.lock = threading.Lock()

    def save(self, step, model, opts, stream, extra):
        """Write a checkpoint for ``step``, rotate old ones and queue the upload.

        Args:
            step: Step number used for the directory name and stored in the file.
            model: Module whose ``state_dict()`` is saved.
            opts: Iterable of optimizers; their state dicts are saved as a list.
            stream: Object with a ``state()`` method (data-stream position).
            extra: Extra entries merged into the payload (may override keys).
        """
        d = os.path.join(self.dir, f"step_{step:06d}")
        os.makedirs(d, exist_ok=True)
        payload = {"model": model.state_dict(), "opt": [o.state_dict() for o in opts], "step": step,
                   "stream": stream.state(), "torch_rng": torch.get_rng_state(),
                   "cuda_rng": torch.cuda.get_rng_state_all(), **extra}
        final_path = os.path.join(d, "ckpt.pt")
        tmp_path = final_path + ".tmp"
        # Write to a temp file and rename: an interrupted write must never
        # leave a truncated ckpt.pt that latest_local() would then pick up.
        try:
            torch.save(payload, tmp_path)
            os.replace(tmp_path, final_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        self._rotate()
        if self.api:
            self.ex.submit(self._push, d, step)
        print(f"[ckpt] kaydedildi step {step} -> {d}", flush=True)

    def _push(self, d, step):
        """Upload directory ``d`` to the HF repo; failures are logged, not raised."""
        # Best effort by design: a failed upload must not stop training.
        try:
            with self.lock:
                self.api.upload_folder(folder_path=d, repo_id=self.repo, repo_type="model",
                                       path_in_repo=f"checkpoints/step_{step:06d}",
                                       commit_message=f"ckpt step {step}")
            print(f"[ckpt] HF push OK step {step}", flush=True)
        except Exception as e:
            print(f"[ckpt] HF push HATA step {step}: {repr(e)[:160]}", flush=True)

    def _rotate(self):
        """Delete all but the ``keep`` newest local ``step_*`` directories."""
        if self.keep <= 0:
            return
        import shutil  # local import: not part of the file's top-level imports
        ds = sorted(glob.glob(os.path.join(self.dir, "step_*")))
        # NOTE(review): a directory whose upload is still queued can be removed
        # here; that upload then fails and is only logged by _push().
        for old in ds[:-self.keep]:
            shutil.rmtree(old)

    def latest_local(self):
        """Return the path of the newest local ``ckpt.pt``, or ``None``."""
        ds = sorted(glob.glob(os.path.join(self.dir, "step_*", "ckpt.pt")))
        return ds[-1] if ds else None

    def latest_hf(self):
        """Download the newest checkpoint from the HF repo and return its local path.

        Returns ``None`` when pushes are disabled or the repo has no checkpoint.
        """
        if not self.api:
            return None
        from huggingface_hub import hf_hub_download
        files = [f for f in self.api.list_repo_files(self.repo, repo_type="model")
                 if f.startswith("checkpoints/step_") and f.endswith("ckpt.pt")]
        if not files:
            return None
        latest = max(files)  # step_NNNNNN is zero-padded, so string order == step order
        # Use the token this object was configured with rather than relying on
        # whatever credentials happen to be in the environment.
        return hf_hub_download(self.repo, latest, repo_type="model", token=self.api.token)
# ───────────────────────── train ─────────────────────────
# ───────────────────────── Muon optimizer (v1.5: 2D-Linear ağırlıkları için) ─────────────────────────
def _ns5(G, steps=5):
    """Newton-Schulz orthogonalization (the core of Muon).

    Runs ``steps`` iterations of a quintic polynomial on ``G`` in bfloat16,
    pushing it towards a semi-orthogonal matrix, and returns the result cast
    back to ``G``'s dtype with ``G``'s shape.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    # Iterate on the orientation with the smaller leading dimension so the
    # Gram matrix X @ X^T is the smaller of the two possible ones.
    tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    if tall:
        X = X.mT
    X = X / (X.norm() + 1e-7)
    for _ in range(steps):
        gram = X @ X.mT
        poly = b * gram + c * (gram @ gram)
        X = a * X + poly @ X
    if tall:
        X = X.mT
    return X.to(G.dtype)
class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalized update (after Keller Jordan).

    Meant for matrix-shaped weights: the update reads ``p.size(-2)`` and
    ``p.size(-1)``, so every parameter must have at least two dimensions.

    Args:
        params: Parameters or parameter groups to optimize.
        lr: Learning rate.
        momentum: Coefficient of the momentum buffer.
        nesterov: Use the Nesterov-style blend of gradient and buffer.
        ns_steps: Number of Newton-Schulz iterations per update.
        weight_decay: Decoupled weight decay (applied as ``p *= 1 - lr * wd``).
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, weight_decay=0.0):
        super().__init__(params, dict(lr=lr, momentum=momentum, nesterov=nesterov,
                                      ns_steps=ns_steps, weight_decay=weight_decay))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, as in the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs gradients enabled.
            with torch.enable_grad():
                loss = closure()

        for grp in self.param_groups:
            lr, mom, nest = grp["lr"], grp["momentum"], grp["nesterov"]
            ns, wd = grp["ns_steps"], grp["weight_decay"]
            for p in grp["params"]:
                if p.grad is None:
                    continue
                st = self.state[p]
                if "buf" not in st:
                    st["buf"] = torch.zeros_like(p.grad)
                buf = st["buf"]
                buf.lerp_(p.grad, 1 - mom)
                # Nesterov: blend the gradient towards the buffer. This
                # overwrites p.grad in place.
                u = p.grad.lerp_(buf, mom) if nest else buf
                # Scale tall matrices by sqrt(rows / cols); others are left at 1.
                u = _ns5(u, ns) * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                if wd:
                    p.mul_(1 - lr * wd)
                p.add_(u, alpha=-lr)
        return loss
def build_optimizers(model, args):
    """Create the optimizer(s) for ``model``.

    With ``args.muon`` the result is ``[Muon, AdamW]``: Muon takes every 2-D
    parameter that is not an embedding weight, AdamW takes the rest. Otherwise
    a single AdamW is returned (the original setup). In both cases AdamW
    applies weight decay 0.1 to parameters with ``ndim >= 2`` and none to
    the others.

    Returns:
        ``(opts, base_lrs)``: the optimizers and, per optimizer, the base
        learning rate that the training loop multiplies by the WSD factor
        (0 -> 1 -> 0.1).
    """

    def make_adamw(decayed, undecayed):
        groups = [{"params": decayed, "weight_decay": 0.1},
                  {"params": undecayed, "weight_decay": 0.0}]
        return torch.optim.AdamW(groups, lr=args.peak_lr, betas=(0.9, 0.95), eps=1e-8, fused=True)

    if not args.muon:
        decayed, undecayed = [], []
        for param in model.parameters():
            (decayed if param.ndim >= 2 else undecayed).append(param)
        adamw = make_adamw(decayed, undecayed)
        print(f"[opt] AdamW (tek) | peak_lr {args.peak_lr}", flush=True)
        return [adamw], [args.peak_lr]

    embedding_ids = {id(mod.weight) for mod in model.modules() if isinstance(mod, nn.Embedding)}
    for_muon, adam_decayed, adam_undecayed = [], [], []
    for param in model.parameters():
        if param.ndim == 2 and id(param) not in embedding_ids:
            for_muon.append(param)        # hidden Linear weights -> Muon
        elif param.ndim >= 2:
            adam_decayed.append(param)    # embedding (2-D, tied lm_head) + 3-D bias (B/C_bias)
        else:
            adam_undecayed.append(param)  # 1-D (norm, dt_bias, D, A_log, qn/kn, ...)
    muon = Muon(for_muon, lr=args.muon_lr, momentum=0.95, ns_steps=5, weight_decay=0.0)
    adamw = make_adamw(adam_decayed, adam_undecayed)
    muon_count = sum(param.numel() for param in for_muon)
    adam_count = sum(param.numel() for param in adam_decayed + adam_undecayed)
    print(f"[opt] MUON {muon_count/1e6:.1f}M (2D-Linear) + AdamW {adam_count/1e6:.1f}M (embed+norm) | "
          f"muon_lr {args.muon_lr} | peak_lr {args.peak_lr}", flush=True)
    return [muon, adamw], [args.muon_lr, args.peak_lr]
def main():
    """Parse CLI arguments, build model / optimizers / data stream and run pretraining.

    Flow: resolve the preset (plus CLI overrides) into a config, build the
    model and optimizer(s), make the dataset available locally, optionally
    resume from a checkpoint, then run the training loop with gradient
    accumulation, gradient clipping and periodic checkpoints. SIGTERM/SIGINT
    trigger an emergency checkpoint before exiting.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", default="kdirgul/smartcore-v1-data")
    ap.add_argument("--ckpt_repo", default="kdirgul/smartcore-v1")
    ap.add_argument("--ckpt_dir", default="/content/ckpt")
    ap.add_argument("--resume", default=None, help="latest_local | latest_hf | <path>")
    ap.add_argument("--preset", default="177m", choices=["177m", "350m"])
    ap.add_argument("--n_layers", type=int, default=None)  # None -> preset value; pass to override
    ap.add_argument("--d_model", type=int, default=None)
    ap.add_argument("--d_intermediate", type=int, default=None)
    ap.add_argument("--n_heads", type=int, default=None)
    ap.add_argument("--n_kv_heads", type=int, default=None)
    ap.add_argument("--attn_every", type=int, default=6)
    ap.add_argument("--seq_len", type=int, default=2048)
    ap.add_argument("--micro_batch", type=int, default=4)
    ap.add_argument("--grad_accum", type=int, default=64)  # 4*64*2048 = 524288 tok/step
    ap.add_argument("--total_tokens", type=float, default=12e9)
    ap.add_argument("--peak_lr", type=float, default=5e-4)
    ap.add_argument("--muon", action="store_true", help="v1.5: 2D-Linear ağırlıkları Muon (embed/norm AdamW)")
    ap.add_argument("--muon_lr", type=float, default=0.02)
    ap.add_argument("--warmup", type=int, default=600)
    ap.add_argument("--save_every", type=int, default=500)  # in optimizer steps
    ap.add_argument("--log_every", type=int, default=10)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    dev = torch.device("cuda")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.set_float32_matmul_precision("high")
    token = os.environ.get("HF_TOKEN")

    # Start from the preset, then apply any explicit CLI override.
    P = dict(PRESETS[args.preset])
    for k in ("n_layers", "d_model", "d_intermediate", "n_heads", "n_kv_heads"):
        if getattr(args, k) is not None:
            P[k] = getattr(args, k)
    cfg = dict(vocab_size=48000, d_model=P["d_model"], n_layers=P["n_layers"], d_state=128, expand=2,
               head_dim=P["head_dim"], ngroups=1, d_intermediate=P["d_intermediate"], attn_every=args.attn_every,
               n_heads=P["n_heads"], n_kv_heads=P["n_kv_heads"], rope_fraction=0.5, is_mimo=False, mimo_rank=1,
               chunk_size=128, scaled_embed=False, z_loss=1e-4)
    print(f"[cfg] preset={args.preset} | d_model={cfg['d_model']} n_layers={cfg['n_layers']} "
          f"d_int={cfg['d_intermediate']} {cfg['n_heads']}/{cfg['n_kv_heads']} GQA attn_every={cfg['attn_every']}", flush=True)

    model = HybridLM(cfg, device=dev, dtype=torch.bfloat16)
    print(f"[model] {n_params(model)/1e6:.1f}M | {cfg['n_layers']-len(model.attn_idx)} Mamba + "
          f"{len(model.attn_idx)} GQA (attn@{model.attn_idx})", flush=True)
    opts, base_lrs = build_optimizers(model, args)

    batch_tok = args.micro_batch * args.grad_accum * args.seq_len
    total_steps = int(args.total_tokens / batch_tok)
    print(f"[plan] {args.total_tokens/1e9:.0f}B token | {batch_tok} tok/step | {total_steps} step | "
          f"warmup {args.warmup} | peak {args.peak_lr}", flush=True)

    root = ensure_local_data(args.data, token)
    mix = MIXES.get(args.preset, MIXES["177m"])  # data mix follows the preset (350m: more Turkish, TC-100B)
    print(f"[mix] {args.preset}: " + " ".join(f"{k}={v:.0%}" for k, v in mix.items()), flush=True)
    stream = ShardStream(root, args.seq_len, mix, seed=args.seed)
    ckpt = Ckpt(args.ckpt_dir, args.ckpt_repo, token)

    start_step = 0
    if args.resume:
        path = (ckpt.latest_local() if args.resume == "latest_local" else
                ckpt.latest_hf() if args.resume == "latest_hf" else args.resume)
        if path and os.path.exists(path):
            st = torch.load(path, map_location="cpu")
            model.load_state_dict(st["model"])
            # Older checkpoints may store a single optimizer state instead of a list.
            osd = st["opt"] if isinstance(st["opt"], list) else [st["opt"]]
            for o, s in zip(opts, osd):
                o.load_state_dict(s)
            stream.load_state(st["stream"])
            start_step = st["step"] + 1  # the stored step is already completed
            torch.set_rng_state(st["torch_rng"])
            torch.cuda.set_rng_state_all(st["cuda_rng"])
            print(f"[resume] {path} -> step {start_step}", flush=True)
        else:
            print(f"[resume] checkpoint bulunamadı ({args.resume}) — sıfırdan başlıyor", flush=True)

    # SIGTERM/SIGINT -> emergency save (safety net for Colab disconnects).
    # Track the last COMPLETED step: resume restarts at stored step + 1, so
    # labelling the checkpoint with a step whose optimizer update has not run
    # yet would silently skip that step.
    progress = {"last_done": start_step - 1}

    def emergency(signum, frame):
        if progress["last_done"] >= 0:
            ckpt.save(progress["last_done"], model, opts, stream, {"cfg": cfg})
        else:
            print("[ckpt] henüz tamamlanan step yok — acil kayıt atlandı", flush=True)
        sys.exit(0)

    signal.signal(signal.SIGTERM, emergency)
    signal.signal(signal.SIGINT, emergency)

    model.train()
    t0 = time.perf_counter()
    seen = 0
    for step in range(start_step, total_steps):
        # WSD multiplier 0 -> 1 -> 0.1; every optimizer scales its own base LR.
        frac = wsd_lr(step, total_steps, 1.0, 0.1, args.warmup)
        for o, b in zip(opts, base_lrs):
            for g in o.param_groups:
                g["lr"] = b * frac
        for o in opts:
            o.zero_grad(set_to_none=True)

        loss_acc = 0.0
        for _ in range(args.grad_accum):
            batch = stream.batch(args.micro_batch, dev)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                _, loss = model(batch, labels=batch)
            (loss / args.grad_accum).backward()
            loss_acc += loss.item() / args.grad_accum
            seen += batch.numel()

        gn = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        for o in opts:
            o.step()
        progress["last_done"] = step

        if step % args.log_every == 0:
            tok_s = seen / (time.perf_counter() - t0)
            print(f"step {step:6d}/{total_steps} | loss {loss_acc:.4f} | gnorm {gn:5.2f} | "
                  f"lr {base_lrs[-1]*frac:.2e} | {tok_s/1e3:.1f}k tok/s | {seen/1e9:.3f}B tok", flush=True)
        if step > start_step and step % args.save_every == 0:
            ckpt.save(step, model, opts, stream, {"cfg": cfg})

    ckpt.save(total_steps - 1, model, opts, stream, {"cfg": cfg, "final": True})
    print("[bitti] pretraining tamamlandı.", flush=True)
# Run the training entry point only when executed as a script.
if __name__ == "__main__":
    main()