"""MYTHOS-RDT training script for the Recurrent-Depth Transformer.

Usage:
    python3 train.py [--quick] [--epochs N]

Other flags (defined in the argparse block at the bottom): --device
(default "auto"), --batch_size (2), --lr (3e-4), --grad_accum (4),
--save_every (200). --quick shrinks the model config and forces batch 2,
grad_accum 2, 1 epoch and lr 1e-3 for a fast smoke run.
"""
# NOTE(review): this file reached review collapsed onto two physical lines, and
# part of train() is MISSING (see the CORRUPTED note inside train()). As
# received it is not valid Python. Restore the original from version control
# rather than hand-reconstructing the lost training loop. Apart from line
# breaks, indentation, comments and docstrings, the code below is exactly what
# was received.
import sys; from pathlib import Path
# Prepend the parent of this script's directory (and its "mythos-rdt" subdir)
# to the import path, presumably so the `raiai` and `shared` imports below
# resolve.
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent / "mythos-rdt"))
from raiai.train import get_device, RaidDataset as DS
import importlib.util
# A directory named "mythos-rdt" (hyphen) cannot be used as a package name, so
# load model.py, which sits next to this script, by file path instead.
spec = importlib.util.spec_from_file_location("mythos_rdt_model", Path(__file__).parent / "model.py")
mythos_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mythos_mod)
from shared.tokenizer import RaidTokenizer
import torch, argparse, math, time
# Short aliases for the model class and its config class from model.py.
M = mythos_mod.MythosRDTModel
C = mythos_mod.MythosRDTConfig
def train(args):
    """Build the MYTHOS-RDT model, data pipeline and optimizer, then train.

    Args:
        args: the namespace produced by the ``__main__`` argparse block. Reads
            ``quick``, ``device``, ``batch_size``, ``grad_accum``, ``epochs``,
            ``lr`` and ``save_every``.

    Side effects (visible part only):
        Prints a run summary, periodic step logs and per-epoch losses.
        Writes ``epoch_<N>.pt`` checkpoints, apparently one per epoch, and
        writes ``interrupted.pt`` on KeyboardInterrupt. Each checkpoint is a
        dict with keys "step", "model" (state_dict) and "config"
        (``cfg.__dict__``). Both go into the directory ``out``, which is
        created in the lost section.
    """
    dev = get_device(args.device)
    cfg = C()
    if args.quick:
        # Smoke-test config: narrow model, 1 prelude/coda layer, 1-4 loops,
        # 4 experts, 256-token context.
        # NOTE(review): vocab_size is forced to 2048. This presumably assumes
        # every token id from RaidTokenizer is < 2048; confirm, otherwise an
        # out-of-range embedding index is likely.
        cfg.dim=384; cfg.prelude_layers=1; cfg.coda_layers=1; cfg.max_loops=4; cfg.min_loops=1
        cfg.n_heads=8; cfg.n_kv_heads=2; cfg.max_seq_len=256; cfg.vocab_size=2048
        cfg.ffn_dim=768; cfg.expert_dim=128; cfg.n_experts=4
    base = Path(__file__).parent.parent
    # Prefer the "enhanced" dataset; fall back to the v0.1 orchestrator set.
    ds_path = base.parent/"datasets"/"enhanced"/"raiai.json"
    if not ds_path.exists(): ds_path = base.parent/"datasets"/"raiai_0.1_orchestrator.json"
    tok = RaidTokenizer(); tok.load(str(base/"shared"/"tokenizer.json"))
    ds = DS(str(ds_path), tok, cfg.max_seq_len)
    # --quick pins batch size and gradient-accumulation steps to 2.
    bs=2 if args.quick else args.batch_size; ga=2 if args.quick else args.grad_accum
    # Drop the ragged last batch only when at least one full batch exists, so a
    # dataset smaller than `bs` still yields a single partial batch.
    dl = torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=len(ds)>=bs, num_workers=0)
    model = M(cfg).to(dev); model.train()
    n = sum(p.numel() for p in model.parameters())
    # NOTE(review): the emoji / multiplication-sign / check-mark characters in
    # this and later messages look mojibake'd (UTF-8 decoded through a legacy
    # code page). Fix the file encoding rather than these literals by hand.
    print(f"MYTHOS-RDT šŸŒ€ — {'QUICK' if args.quick else 'TRAINING'}")
    print(f" Device: {dev} | Params: {n/1e6:.1f}M | Loops: {cfg.min_loops}-{cfg.max_loops}")
    # experts_per_tok is not set above; it comes from the MythosRDTConfig default.
    print(f" Dataset: {len(ds)} | Batch: {bs}Ɨ{ga} | MoE: {cfg.n_experts}e x{cfg.experts_per_tok}")
    epochs=1 if args.quick else args.epochs; lr=1e-3 if args.quick else args.lr
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9,0.95), weight_decay=0.1)
    # sp: optimizer steps per epoch (one per `ga` micro-batches, floored).
    # tot: total steps.
    # wu: warmup length, 10% of the total.
    sp = len(dl)//max(1,ga); tot = sp*epochs; wu = tot//10
    # NOTE(review): CORRUPTED. Text between `if s` and `0 and step%...` on the
    # next line was lost, probably a `<...>` span stripped as if it were an
    # HTML tag; `s0` is an artifact, not a real name. The visible start is a
    # linear warmup factor (s+1)/wu. The lost span also held:
    #   - the rest of this lambda;
    #   - the `try:` that `except KeyboardInterrupt` below needs;
    #   - the definitions of `out`, `step`, `t0`, `el`, `ep` and `i`, which are
    #     used below but never defined in the visible code;
    #   - the training loop itself.
    # The indentation of the following lines is a best guess.
    # NOTE(review): in the visible code --save_every only gates this progress
    # print. Whether the lost code also checkpointed here is unknown.
    glr = lambda s: (s+1)/max(1,wu) if s0 and step%max(1,args.save_every)==0: print(f" Step {step:4d}/{tot} | Loss: {el/(i+1):.4f} | {time.time()-t0:.0f}s")
            # NOTE(review): divides by len(dl). With an empty dataset the
            # loader has no batches, so this would raise ZeroDivisionError.
            print(f" āœ… Epoch {ep+1}/{epochs} | Loss: {el/len(dl):.4f}")
            torch.save({"step":step,"model":model.state_dict(),"config":cfg.__dict__}, str(out/f"epoch_{ep+1}.pt"))
        # (Italian message: "training complete")
        print(f"\nāœ… Mythos-RDT training completo! {out}")
    except KeyboardInterrupt:
        # On Ctrl-C, checkpoint the current weights so the run is not lost.
        torch.save({"step":step,"model":model.state_dict(),"config":cfg.__dict__}, str(out/"interrupted.pt"))
        # (Italian message: "saved")
        print("\nāœ… Salvato")
if __name__=="__main__":
    # Command-line interface. In --quick mode, train() ignores batch_size,
    # grad_accum, epochs and lr and uses its own smoke-test values.
    p=argparse.ArgumentParser()
    p.add_argument("--quick",action="store_true"); p.add_argument("--device",default="auto")
    p.add_argument("--batch_size",type=int,default=2); p.add_argument("--epochs",type=int,default=10)
    p.add_argument("--lr",type=float,default=3e-4); p.add_argument("--grad_accum",type=int,default=4)
    p.add_argument("--save_every",type=int,default=200)
    train(p.parse_args())