"""DeltaLens training script. Usage: python train.py --data_path /path/to/train.pt --val_path /path/to/val.pt Data format: torch tensor of shape (num_sequences, seq_len) with token IDs. """ import sys, os, math, time, glob, argparse, signal import torch import wandb _SHOULD_STOP = False def _sigterm_handler(signum, frame): global _SHOULD_STOP print(f"\n[SIGTERM] Saving checkpoint and exiting...") _SHOULD_STOP = True signal.signal(signal.SIGTERM, _sigterm_handler) def get_lr(step, total_steps, warmup_steps, lr_max, lr_min): if step < warmup_steps: return lr_max * step / warmup_steps progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress)) def save_checkpoint(model, optimizer, step, global_tokens, path): torch.save({ "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "step": step, "global_tokens": global_tokens, }, path) size_mb = os.path.getsize(path) / 1e6 print(f" Checkpoint: {path} ({size_mb:.0f}MB, step={step})") @torch.no_grad() def evaluate(model, val_data, max_docs=200): model.eval() total_loss = 0.0 total_tokens = 0 for i in range(min(len(val_data), max_docs)): input_ids = val_data[i:i+1].long().cuda() out = model(input_ids=input_ids, labels=input_ids) n = input_ids.numel() total_loss += out.loss.item() * n total_tokens += n model.train() return math.exp(total_loss / total_tokens), total_loss / total_tokens def main(): global _SHOULD_STOP parser = argparse.ArgumentParser() parser.add_argument("--exp_id", default="DeltaLens-1.3B") parser.add_argument("--data_path", required=True) parser.add_argument("--val_path", required=True) parser.add_argument("--ckpt_dir", default="./checkpoints") parser.add_argument("--total_tokens", type=int, default=1_000_000_000) parser.add_argument("--micro_bs", type=int, default=2) parser.add_argument("--grad_accum", type=int, default=256) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--lr_min", type=float, default=3e-5) parser.add_argument("--d_model", type=int, default=2048) parser.add_argument("--d_state", type=int, default=512) parser.add_argument("--n_layers", type=int, default=24) parser.add_argument("--n_heads", type=int, default=16) parser.add_argument("--vocab_size", type=int, default=32000) args = parser.parse_args() SEQ_LEN = 2048 EFFECTIVE_BS = args.micro_bs * args.grad_accum TOKENS_PER_STEP = EFFECTIVE_BS * SEQ_LEN TOTAL_STEPS = args.total_tokens // TOKENS_PER_STEP WARMUP_RATIO = 0.03 os.makedirs(args.ckpt_dir, exist_ok=True) print(f"=== {args.exp_id} ===") print(f" Total: {args.total_tokens/1e9:.0f}B tokens, {TOTAL_STEPS} steps") print(f" Effective BS: {EFFECTIVE_BS}") from deltalens_layer import DeltaLensModel model = DeltaLensModel( vocab_size=args.vocab_size, d_model=args.d_model, n_layers=args.n_layers, d_state=args.d_state, n_heads=args.n_heads, max_seq_len=SEQ_LEN, ).to(torch.bfloat16).cuda() total_params = sum(p.numel() for p in model.parameters()) print(f" Params: {total_params:,} ({total_params*2/1e9:.2f}GB)") print("\nLoading data...") train_data = torch.load(args.data_path, mmap=True) val_data = torch.load(args.val_path, mmap=True) print(f" Train: {len(train_data):,}, Val: {len(val_data):,}") optimizer = torch.optim.AdamW( model.parameters(), lr=args.lr, weight_decay=0.01, betas=(0.9, 0.95), eps=1e-8, ) warmup_steps = max(1, int(TOTAL_STEPS * WARMUP_RATIO)) # Resume start_step = 0 global_tokens = 0 ckpts = sorted(glob.glob(os.path.join(args.ckpt_dir, "ckpt_*.pt"))) if ckpts: print(f"Resuming from {ckpts[-1]}") ckpt = torch.load(ckpts[-1], map_location="cpu") model.load_state_dict(ckpt["model_state"]) optimizer.load_state_dict(ckpt["optimizer_state"]) start_step = ckpt["step"] + 1 global_tokens = ckpt["global_tokens"] del ckpt wandb.init(project="deltalens", name=args.exp_id, config=vars(args), resume="allow") model.train() EVAL_EVERY = args.total_tokens // 10 // TOKENS_PER_STEP step_time_start = time.time() for step in range(start_step, TOTAL_STEPS): optimizer.zero_grad(set_to_none=True) step_loss = 0.0 for micro in range(args.grad_accum): seq_idx = step * EFFECTIVE_BS + micro * args.micro_bs input_ids = train_data[seq_idx : seq_idx + args.micro_bs].long().cuda() out = model(input_ids=input_ids, labels=input_ids) loss = out.loss / args.grad_accum loss.backward() step_loss += loss.item() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item() optimizer.step() lr = get_lr(step, TOTAL_STEPS, warmup_steps, args.lr, args.lr_min) for pg in optimizer.param_groups: pg["lr"] = lr global_tokens += TOKENS_PER_STEP if step % 10 == 0: elapsed = time.time() - step_time_start tps = (10 * TOKENS_PER_STEP) / max(elapsed, 1) if step > start_step else 0 wandb.log({"train/loss": step_loss, "train/lr": lr, "train/tokens": global_tokens, "train/grad_norm": grad_norm, "train/tokens_per_sec": tps, "step": step}) print(f" step {step}/{TOTAL_STEPS} | loss {step_loss:.4f} | " f"lr {lr:.2e} | gnorm {grad_norm:.3f} | {tps:.0f} tok/s", flush=True) step_time_start = time.time() if step > 0 and step % EVAL_EVERY == 0: ppl, eval_loss = evaluate(model, val_data) wandb.log({"eval/val_ppl": ppl, "eval/val_loss": eval_loss, "step": step}) print(f" [EVAL] step {step} | val_ppl {ppl:.2f}", flush=True) if step > 0 and step % 100 == 0: ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{step:06d}.pt") save_checkpoint(model, optimizer, step, global_tokens, ckpt_path) ckpts = sorted(glob.glob(os.path.join(args.ckpt_dir, "ckpt_*.pt"))) for old in ckpts[:-2]: os.remove(old) if _SHOULD_STOP: ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{step:06d}.pt") save_checkpoint(model, optimizer, step, global_tokens, ckpt_path) wandb.finish() return # Final save print("\n=== Training complete! ===") ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{TOTAL_STEPS:06d}_final.pt") save_checkpoint(model, optimizer, TOTAL_STEPS, global_tokens, ckpt_path) ppl, eval_loss = evaluate(model, val_data) wandb.log({"eval/val_ppl": ppl, "step": TOTAL_STEPS}) print(f"[FINAL] val_ppl {ppl:.2f}") wandb.finish() if __name__ == "__main__": main()