| """DeltaLens training script. |
| |
| Usage: |
| python train.py --data_path /path/to/train.pt --val_path /path/to/val.pt |
| |
| Data format: torch tensor of shape (num_sequences, seq_len) with token IDs. |
| """ |
| import sys, os, math, time, glob, argparse, signal |
| import torch |
| import wandb |
|
|
# Set by the SIGTERM handler; the training loop checks it once per optimizer
# step and, when set, writes a checkpoint and exits.
_SHOULD_STOP = False


def _sigterm_handler(signum, frame):
    """Request a graceful shutdown when SIGTERM is received.

    Only flips ``_SHOULD_STOP``; the checkpoint itself is written by the
    training loop at the end of the current step.

    The flag is set *before* any output so the request cannot be lost. The
    notice is written with ``os.write`` rather than ``print`` because Python
    runs signal handlers in the main thread between bytecodes. A ``print``
    here can therefore re-enter a buffered stdout that the interrupted code
    is already writing to, and raise ``RuntimeError`` (reentrant call). That
    would crash the run instead of checkpointing it.

    Args:
        signum: Signal number (unused).
        frame: Interrupted stack frame (unused).
    """
    global _SHOULD_STOP
    _SHOULD_STOP = True
    try:
        os.write(1, b"\n[SIGTERM] Saving checkpoint and exiting...\n")
    except OSError:
        # Best-effort notice only (e.g. stdout closed); the flag is already set.
        pass


# Registered at import time, as in the original script.
signal.signal(signal.SIGTERM, _sigterm_handler)
|
|
|
|
def get_lr(step, total_steps, warmup_steps, lr_max, lr_min):
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    The schedule has two phases:

    - Warmup: the rate rises linearly from 0 at step 0 to ``lr_max`` at
      ``warmup_steps``.
    - Decay: it then follows a half cosine from ``lr_max`` down to ``lr_min``
      at ``total_steps``.

    Steps at or beyond ``total_steps`` are held at ``lr_min`` instead of
    letting the cosine climb back up.

    Args:
        step: Current optimizer step (0-based, non-negative).
        total_steps: Step at which the schedule reaches ``lr_min``.
        warmup_steps: Number of linear warmup steps; values <= 0 disable warmup.
        lr_max: Peak learning rate, reached at the end of warmup.
        lr_min: Floor learning rate, reached at ``total_steps``.

    Returns:
        The learning rate as a float.
    """
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Clamp so an overrun (e.g. a resumed run with a smaller total) stays at
    # lr_min; unclamped, cos() would swing the rate back toward lr_max.
    progress = min(progress, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
|
|
|
|
def save_checkpoint(model, optimizer, step, global_tokens, path):
    """Write a resumable training checkpoint to ``path`` atomically.

    The payload holds the model and optimizer state dicts, the step counter
    and the running token count.

    Data is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted save (e.g. the job being killed mid-write)
    therefore never leaves a truncated file under the real name.

    This matters because resume loads the lexicographically last
    ``ckpt_*.pt`` and rotation deletes older ones. The ``.tmp`` name does not
    match that glob pattern.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        step: Step counter stored for resuming.
        global_tokens: Total training tokens consumed so far.
        path: Destination file path (``str`` or ``os.PathLike``).
    """
    payload = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "step": step,
        "global_tokens": global_tokens,
    }
    tmp_path = os.fspath(path) + ".tmp"
    try:
        torch.save(payload, tmp_path)
        # Atomic rename: readers see either the old file or the complete new one.
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a partial temp file behind on failure or interrupt.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    size_mb = os.path.getsize(path) / 1e6
    print(f" Checkpoint: {path} ({size_mb:.0f}MB, step={step})")
|
|
|
|
@torch.no_grad()
def evaluate(model, val_data, max_docs=200, device="cuda"):
    """Compute token-weighted validation loss and perplexity.

    Runs the first ``min(len(val_data), max_docs)`` sequences through the
    model one at a time (batch of 1), with ``labels=input_ids``. It averages
    ``out.loss`` weighted by each sequence's element count.

    Args:
        model: Model called as ``model(input_ids=..., labels=...)``; its
            output must expose a scalar ``.loss``.
        val_data: Tensor of token IDs indexed along dim 0; each
            ``val_data[i:i+1]`` is one sequence (presumably shape
            (num_sequences, seq_len) — confirm against the data files).
        max_docs: Maximum number of sequences to evaluate.
        device: Device inputs are moved to (default ``"cuda"``, as before).

    Returns:
        ``(perplexity, mean_loss)``.

    Raises:
        ValueError: If no sequences would be evaluated (empty ``val_data``
            or ``max_docs <= 0``); previously this surfaced as an opaque
            ``ZeroDivisionError``.
    """
    num_docs = min(len(val_data), max_docs)
    if num_docs <= 0:
        raise ValueError(
            f"evaluate() needs at least one sequence "
            f"(len(val_data)={len(val_data)}, max_docs={max_docs})"
        )
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        for i in range(num_docs):
            input_ids = val_data[i:i + 1].long().to(device)
            out = model(input_ids=input_ids, labels=input_ids)
            # NOTE(review): if the model shifts labels internally, out.loss
            # averages over seq_len - 1 positions rather than numel(); with
            # equal-length sequences the weighting is uniform either way.
            n = input_ids.numel()
            total_loss += out.loss.item() * n
            total_tokens += n
    finally:
        # Restore the caller's train/eval mode even if a forward pass raised.
        model.train(was_training)
    mean_loss = total_loss / total_tokens
    return math.exp(mean_loss), mean_loss
|
|
|
|
def _parse_args():
    """Parse and validate the command-line options for a training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_id", default="DeltaLens-1.3B")
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--val_path", required=True)
    parser.add_argument("--ckpt_dir", default="./checkpoints")
    parser.add_argument("--total_tokens", type=int, default=1_000_000_000)
    parser.add_argument("--micro_bs", type=int, default=2)
    parser.add_argument("--grad_accum", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--lr_min", type=float, default=3e-5)
    parser.add_argument("--d_model", type=int, default=2048)
    parser.add_argument("--d_state", type=int, default=512)
    parser.add_argument("--n_layers", type=int, default=24)
    parser.add_argument("--n_heads", type=int, default=16)
    parser.add_argument("--vocab_size", type=int, default=32000)
    args = parser.parse_args()
    # Non-positive batch sizes make TOKENS_PER_STEP zero, which would only
    # fail later with a ZeroDivisionError; reject them up front.
    if args.micro_bs <= 0 or args.grad_accum <= 0:
        parser.error("--micro_bs and --grad_accum must be positive")
    return args


def _rotate_checkpoints(ckpt_dir):
    """Delete all but the two newest (by sorted name) ckpt_*.pt files in ckpt_dir."""
    ckpts = sorted(glob.glob(os.path.join(ckpt_dir, "ckpt_*.pt")))
    for old in ckpts[:-2]:
        os.remove(old)


def main():
    """Train a DeltaLens model, resuming from the newest checkpoint if present.

    The run:

    - loads pre-tokenized train/val tensors;
    - trains with AdamW under linear warmup + cosine decay until
      ``total_tokens`` tokens are consumed;
    - logs to Weights & Biases;
    - evaluates roughly ten times per run;
    - checkpoints every 100 steps, keeping the two newest files;
    - when SIGTERM sets ``_SHOULD_STOP``, checkpoints the current step and
      returns early.
    """
    args = _parse_args()

    SEQ_LEN = 2048
    EFFECTIVE_BS = args.micro_bs * args.grad_accum  # sequences per optimizer step
    TOKENS_PER_STEP = EFFECTIVE_BS * SEQ_LEN
    TOTAL_STEPS = args.total_tokens // TOKENS_PER_STEP
    WARMUP_RATIO = 0.03

    os.makedirs(args.ckpt_dir, exist_ok=True)

    print(f"=== {args.exp_id} ===")
    print(f" Total: {args.total_tokens/1e9:.0f}B tokens, {TOTAL_STEPS} steps")
    print(f" Effective BS: {EFFECTIVE_BS}")

    from deltalens_layer import DeltaLensModel

    # NOTE(review): parameters are cast to bf16 with no fp32 master copy, so
    # AdamW updates are applied directly to bf16 weights and very small
    # updates may be rounded away. Confirm this is intended.
    model = DeltaLensModel(
        vocab_size=args.vocab_size,
        d_model=args.d_model,
        n_layers=args.n_layers,
        d_state=args.d_state,
        n_heads=args.n_heads,
        max_seq_len=SEQ_LEN,
    ).to(torch.bfloat16).cuda()

    total_params = sum(p.numel() for p in model.parameters())
    print(f" Params: {total_params:,} ({total_params*2/1e9:.2f}GB)")

    print("\nLoading data...")
    train_data = torch.load(args.data_path, mmap=True)
    val_data = torch.load(args.val_path, mmap=True)
    print(f" Train: {len(train_data):,}, Val: {len(val_data):,}")

    # Batches are read sequentially by absolute step index with no
    # wrap-around, so the run needs TOTAL_STEPS * EFFECTIVE_BS sequences.
    # Fail now instead of hitting an empty/short slice deep into training.
    needed = TOTAL_STEPS * EFFECTIVE_BS
    if len(train_data) < needed:
        raise ValueError(
            f"train data has {len(train_data):,} sequences but {needed:,} are "
            f"needed for {TOTAL_STEPS} steps; lower --total_tokens or add data"
        )

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=0.01,
        betas=(0.9, 0.95), eps=1e-8,
    )
    warmup_steps = max(1, int(TOTAL_STEPS * WARMUP_RATIO))

    # Resume from the lexicographically last checkpoint. Step numbers are
    # zero-padded, so name order matches step order.
    start_step = 0
    global_tokens = 0
    ckpts = sorted(glob.glob(os.path.join(args.ckpt_dir, "ckpt_*.pt")))
    if ckpts:
        print(f"Resuming from {ckpts[-1]}")
        ckpt = torch.load(ckpts[-1], map_location="cpu")
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        start_step = ckpt["step"] + 1
        global_tokens = ckpt["global_tokens"]
        del ckpt

    # NOTE(review): with no explicit run `id` (and no WANDB_RUN_ID set),
    # resume="allow" starts a fresh W&B run on every restart instead of
    # continuing the previous one.
    wandb.init(project="deltalens", name=args.exp_id,
               config=vars(args), resume="allow")

    model.train()
    # Evaluate ~10 times per run; at least 1 so the modulo below never sees 0.
    EVAL_EVERY = max(1, args.total_tokens // 10 // TOKENS_PER_STEP)
    log_time = time.time()
    log_step = start_step - 1  # last step completed before the timing window

    for step in range(start_step, TOTAL_STEPS):
        # Set this step's LR *before* optimizer.step(). Previously it was set
        # afterwards, so each LR took effect one step late and step 0 ran at
        # the full args.lr, skipping warmup. Step 0 now uses lr 0.
        lr = get_lr(step, TOTAL_STEPS, warmup_steps, args.lr, args.lr_min)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        optimizer.zero_grad(set_to_none=True)
        loss_sum = 0.0
        for micro in range(args.grad_accum):
            # The data position depends only on (step, micro), so resuming
            # continues exactly where the previous run stopped.
            seq_idx = step * EFFECTIVE_BS + micro * args.micro_bs
            input_ids = train_data[seq_idx : seq_idx + args.micro_bs].long().cuda()
            out = model(input_ids=input_ids, labels=input_ids)
            loss = out.loss / args.grad_accum
            loss.backward()
            # Accumulate on device: one host sync per step, not one per micro-batch.
            loss_sum += loss.detach().float()
        step_loss = float(loss_sum)

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        optimizer.step()

        global_tokens += TOKENS_PER_STEP

        if step % 10 == 0:
            now = time.time()
            # Count the steps actually timed (the first window after a
            # resume may be shorter than 10), and avoid the old 1 s clamp.
            steps_timed = step - log_step
            tps = steps_timed * TOKENS_PER_STEP / max(now - log_time, 1e-6)
            wandb.log({"train/loss": step_loss, "train/lr": lr,
                       "train/tokens": global_tokens, "train/grad_norm": grad_norm,
                       "train/tokens_per_sec": tps, "step": step})
            print(f" step {step}/{TOTAL_STEPS} | loss {step_loss:.4f} | "
                  f"lr {lr:.2e} | gnorm {grad_norm:.3f} | {tps:.0f} tok/s", flush=True)
            log_time = time.time()
            log_step = step

        if step > 0 and step % EVAL_EVERY == 0:
            ppl, eval_loss = evaluate(model, val_data)
            wandb.log({"eval/val_ppl": ppl, "eval/val_loss": eval_loss, "step": step})
            print(f" [EVAL] step {step} | val_ppl {ppl:.2f}", flush=True)

        saved_this_step = False
        if step > 0 and step % 100 == 0:
            ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{step:06d}.pt")
            save_checkpoint(model, optimizer, step, global_tokens, ckpt_path)
            _rotate_checkpoints(args.ckpt_dir)
            saved_this_step = True

        if _SHOULD_STOP:
            # Don't rewrite the same file when SIGTERM lands on a save step.
            if not saved_this_step:
                ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{step:06d}.pt")
                save_checkpoint(model, optimizer, step, global_tokens, ckpt_path)
            wandb.finish()
            return

    print("\n=== Training complete! ===")
    ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{TOTAL_STEPS:06d}_final.pt")
    save_checkpoint(model, optimizer, TOTAL_STEPS, global_tokens, ckpt_path)
    ppl, eval_loss = evaluate(model, val_data)
    # Log the loss as well, matching the periodic eval entries.
    wandb.log({"eval/val_ppl": ppl, "eval/val_loss": eval_loss, "step": TOTAL_STEPS})
    print(f"[FINAL] val_ppl {ppl:.2f}")
    wandb.finish()


if __name__ == "__main__":
    main()
|
|