mythos-rdt / train.py
Raidone's picture
MYTHOS-RDT — Recurrent-Depth Transformer. بسم الله الرحمن الرحيم
4cf6c82 verified
Raw
History Blame Contribute Delete
3.99 kB
"""
MYTHOS-RDT Training — Recurrent-Depth Transformer
Usage: python3 train.py [--quick] [--epochs N]
"""
# Put the parent of this script's directory, and its "mythos-rdt" subdirectory,
# at the front of sys.path. These inserts must run before the project imports
# below (`raiai`, `shared`), which are presumably resolved from those paths.
import sys; from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent / "mythos-rdt"))
from raiai.train import get_device, RaidDataset as DS
import importlib.util
# Load the mythos-rdt model module directly from its file path: the directory
# name contains a hyphen, so it cannot be imported as a regular package.
# model.py is taken from the same directory as this script and registered
# under the module name "mythos_rdt_model".
spec = importlib.util.spec_from_file_location("mythos_rdt_model", Path(__file__).parent / "model.py")
mythos_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mythos_mod)
from shared.tokenizer import RaidTokenizer
import torch, argparse, math, time
# Short aliases for the model class and its config class from model.py.
M = mythos_mod.MythosRDTModel
C = mythos_mod.MythosRDTConfig
def _lr_scale(step, total_steps, warmup_steps):
    """Return the learning-rate multiplier for optimizer step ``step``.

    Linear warm-up over the first ``warmup_steps`` steps, then a half-cosine
    decay over the remaining steps, floored at 0.1 so the learning rate never
    falls below 10% of its peak value.
    """
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))


def _save_checkpoint(path, step, model, cfg):
    """Save the optimizer-step counter, model weights and config dict to ``path``."""
    torch.save({"step": step, "model": model.state_dict(), "config": cfg.__dict__}, str(path))


def train(args):
    """Train a MythosRDT model and write a checkpoint after every epoch.

    Args:
        args: parsed CLI namespace providing ``quick``, ``device``,
            ``batch_size``, ``epochs``, ``lr``, ``grad_accum`` and
            ``save_every``.

    With ``args.quick`` a much smaller model config is used, together with
    batch size 2, gradient accumulation 2, a single epoch and lr 1e-3; the
    corresponding CLI values are ignored in that mode.

    Checkpoints are written to ``mythos-rdt/checkpoints`` under the parent of
    this script's directory. Ctrl-C saves ``interrupted.pt`` there instead of
    propagating the KeyboardInterrupt.

    Raises:
        ValueError: if the selected dataset is empty.
    """
    dev = get_device(args.device)
    cfg = C()
    if args.quick:
        # Small configuration for a fast smoke test of the whole pipeline.
        cfg.dim = 384
        cfg.prelude_layers = 1
        cfg.coda_layers = 1
        cfg.max_loops = 4
        cfg.min_loops = 1
        cfg.n_heads = 8
        cfg.n_kv_heads = 2
        cfg.max_seq_len = 256
        cfg.vocab_size = 2048
        cfg.ffn_dim = 768
        cfg.expert_dim = 128
        cfg.n_experts = 4

    base = Path(__file__).parent.parent
    ds_path = base.parent / "datasets" / "enhanced" / "raiai.json"
    if not ds_path.exists():
        # Fall back to the orchestrator dataset when the enhanced one is absent.
        ds_path = base.parent / "datasets" / "raiai_0.1_orchestrator.json"
    tok = RaidTokenizer()
    tok.load(str(base / "shared" / "tokenizer.json"))
    ds = DS(str(ds_path), tok, cfg.max_seq_len)
    if len(ds) == 0:
        raise ValueError(f"dataset is empty: {ds_path}")

    bs = 2 if args.quick else args.batch_size
    # Guard against --grad_accum 0, which would divide by zero below.
    ga = max(1, 2 if args.quick else args.grad_accum)
    # Keep a single short batch when the dataset is smaller than one batch.
    dl = torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True,
                                     drop_last=len(ds) >= bs, num_workers=0)
    model = M(cfg).to(dev)
    model.train()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"MYTHOS-RDT 🌀 — {'QUICK' if args.quick else 'TRAINING'}")
    print(f" Device: {dev} | Params: {n_params/1e6:.1f}M | Loops: {cfg.min_loops}-{cfg.max_loops}")
    print(f" Dataset: {len(ds)} | Batch: {bs}×{ga} | MoE: {cfg.n_experts}e x{cfg.experts_per_tok}")

    epochs = 1 if args.quick else args.epochs
    lr = 1e-3 if args.quick else args.lr
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

    n_batches = len(dl)
    # The trailing partial accumulation window of an epoch also produces an
    # optimizer step, so an epoch has ceil(n_batches / ga) steps.
    steps_per_epoch = -(-n_batches // ga)
    total_steps = steps_per_epoch * epochs
    warmup_steps = total_steps // 10

    out = base / "mythos-rdt" / "checkpoints"
    out.mkdir(parents=True, exist_ok=True)
    step = 0
    t0 = time.time()
    try:
        for ep in range(epochs):
            epoch_loss = 0.0
            for i, (x, y) in enumerate(dl):
                x, y = x.to(dev), y.to(dev)
                # Number of micro-batches in the accumulation window containing
                # batch i; only the last window of an epoch can be shorter than ga.
                window = min(ga, n_batches - (i // ga) * ga)
                logits = model(x)["logits"]
                # NOTE(review): ignore_index=0 assumes token id 0 is padding — confirm
                # against RaidTokenizer / RaidDataset.
                loss = torch.nn.functional.cross_entropy(
                    logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=0)
                (loss / window).backward()
                epoch_loss += loss.item()
                # Step at the end of every window, including a short final one,
                # so no gradients leak into the next epoch.
                if (i + 1) % ga == 0 or i + 1 == n_batches:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    for pg in opt.param_groups:
                        pg["lr"] = lr * _lr_scale(step, total_steps, warmup_steps)
                    opt.step()
                    opt.zero_grad()
                    step += 1
                    # NOTE(review): despite its name, save_every only controls
                    # progress logging; checkpoints are written once per epoch.
                    if step % max(1, args.save_every) == 0:
                        print(f" Step {step:4d}/{total_steps} | Loss: {epoch_loss/(i+1):.4f} | "
                              f"{time.time()-t0:.0f}s")
            print(f" ✅ Epoch {ep+1}/{epochs} | Loss: {epoch_loss/max(1, n_batches):.4f}")
            _save_checkpoint(out / f"epoch_{ep+1}.pt", step, model, cfg)
        print(f"\n✅ Mythos-RDT training completo! {out}")
    except KeyboardInterrupt:
        _save_checkpoint(out / "interrupted.pt", step, model, cfg)
        print("\n✅ Salvato")
if __name__ == "__main__":
    # Command-line entry point. Note that --quick makes train() override the
    # batch size, accumulation, epoch count and learning rate given here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--device", default="auto")
    # Typed numeric options, registered in the same order as before.
    numeric_options = (
        ("--batch_size", int, 2),
        ("--epochs", int, 10),
        ("--lr", float, 3e-4),
        ("--grad_accum", int, 4),
        ("--save_every", int, 200),
    )
    for flag, kind, default in numeric_options:
        parser.add_argument(flag, type=kind, default=default)
    cli_args = parser.parse_args()
    train(cli_args)