"""DeltaLens training script.
Usage:
python train.py --data_path /path/to/train.pt --val_path /path/to/val.pt
Data format: torch tensor of shape (num_sequences, seq_len) with token IDs.
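
Example (illustrative only; any tokenizer with IDs below vocab_size works):
    import torch
    ids = torch.randint(0, 32000, (1000, 2048), dtype=torch.int32)
    torch.save(ids, "train.pt")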
"""
import os, math, time, glob, argparse, signal
import torch
import wandb
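# Cooperative shutdown: cluster schedulers usually send SIGTERM before
# killing a job, so we set a flag here and checkpoint at the next step
# boundary instead of dying mid-update.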
_SHOULD_STOP = False
def _sigterm_handler(signum, frame):
    global _SHOULD_STOP
    print("\n[SIGTERM] Saving checkpoint and exiting...")
    _SHOULD_STOP = True
signal.signal(signal.SIGTERM, _sigterm_handler)
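# LR schedule: linear warmup to lr_max over warmup_steps, then cosine decay
# to lr_min: lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * progress)).
# E.g. with the defaults lr=3e-4 and lr_min=3e-5, the decay midpoint is 1.65e-4.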
def get_lr(step, total_steps, warmup_steps, lr_max, lr_min):
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
def save_checkpoint(model, optimizer, step, global_tokens, path):
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "step": step,
        "global_tokens": global_tokens,
    }, path)
    size_mb = os.path.getsize(path) / 1e6
    print(f" Checkpoint: {path} ({size_mb:.0f}MB, step={step})")
@torch.no_grad()
def evaluate(model, val_data, max_docs=200):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    for i in range(min(len(val_data), max_docs)):
        input_ids = val_data[i:i+1].long().cuda()
        out = model(input_ids=input_ids, labels=input_ids)
        n = input_ids.numel()
        total_loss += out.loss.item() * n
        total_tokens += n
    model.train()
    return math.exp(total_loss / total_tokens), total_loss / total_tokens
def main():
    global _SHOULD_STOP
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_id", default="DeltaLens-1.3B")
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--val_path", required=True)
    parser.add_argument("--ckpt_dir", default="./checkpoints")
    parser.add_argument("--total_tokens", type=int, default=1_000_000_000)
    parser.add_argument("--micro_bs", type=int, default=2)
    parser.add_argument("--grad_accum", type=int, default=256)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--lr_min", type=float, default=3e-5)
    parser.add_argument("--d_model", type=int, default=2048)
    parser.add_argument("--d_state", type=int, default=512)
    parser.add_argument("--n_layers", type=int, default=24)
    parser.add_argument("--n_heads", type=int, default=16)
    parser.add_argument("--vocab_size", type=int, default=32000)
    args = parser.parse_args()
    SEQ_LEN = 2048
    EFFECTIVE_BS = args.micro_bs * args.grad_accum
    TOKENS_PER_STEP = EFFECTIVE_BS * SEQ_LEN
    TOTAL_STEPS = args.total_tokens // TOKENS_PER_STEP
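    # With the defaults: 2 * 256 = 512 sequences/step, 512 * 2048 = 1,048,576
    # tokens/step, so 1B tokens is ~953 optimizer steps.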
    WARMUP_RATIO = 0.03
    os.makedirs(args.ckpt_dir, exist_ok=True)
    print(f"=== {args.exp_id} ===")
    print(f" Total: {args.total_tokens/1e9:.0f}B tokens, {TOTAL_STEPS} steps")
    print(f" Effective BS: {EFFECTIVE_BS}")
    from deltalens_layer import DeltaLensModel
    model = DeltaLensModel(
        vocab_size=args.vocab_size,
        d_model=args.d_model,
        n_layers=args.n_layers,
        d_state=args.d_state,
        n_heads=args.n_heads,
        max_seq_len=SEQ_LEN,
    ).to(torch.bfloat16).cuda()
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Params: {total_params:,} ({total_params*2/1e9:.2f}GB)")
print("\nLoading data...")
train_data = torch.load(args.data_path, mmap=True)
val_data = torch.load(args.val_path, mmap=True)
print(f" Train: {len(train_data):,}, Val: {len(val_data):,}")
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=0.01,
        betas=(0.9, 0.95), eps=1e-8,
    )
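    # Betas (0.9, 0.95) follow common LLM-pretraining practice. Since the
    # model was cast to bf16, AdamW's moment buffers are bf16 as well; fp32
    # master weights are a common upgrade if training proves unstable.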
    warmup_steps = max(1, int(TOTAL_STEPS * WARMUP_RATIO))
    # Resume
    start_step = 0
    global_tokens = 0
    ckpts = sorted(glob.glob(os.path.join(args.ckpt_dir, "ckpt_*.pt")))
    if ckpts:
        print(f"Resuming from {ckpts[-1]}")
        ckpt = torch.load(ckpts[-1], map_location="cpu")
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        start_step = ckpt["step"] + 1
        global_tokens = ckpt["global_tokens"]
        del ckpt
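    # The data index in the training loop is a pure function of the step
    # counter, so resuming at start_step replays exactly the sequences the
    # interrupted run would have consumed next.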
    wandb.init(project="deltalens", name=args.exp_id,
               config=vars(args), resume="allow")
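    # resume="allow" reattaches to an existing run only when the run id is
    # stable across restarts (e.g. via the WANDB_RUN_ID env var); otherwise
    # a fresh run is created.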
    model.train()
    # Evaluate roughly every 10% of the run; guard against a zero interval
    # on short smoke-test runs.
    EVAL_EVERY = max(1, args.total_tokens // 10 // TOKENS_PER_STEP)
    step_time_start = time.time()
    for step in range(start_step, TOTAL_STEPS):
        optimizer.zero_grad(set_to_none=True)
        step_loss = 0.0
        for micro in range(args.grad_accum):
            # Deterministic, shuffle-free data order derived from the step
            # counter; the modulo wraps around so slices are never empty if
            # the run consumes more sequences than the dataset holds.
            seq_idx = (step * EFFECTIVE_BS + micro * args.micro_bs) % max(1, len(train_data) - args.micro_bs + 1)
            input_ids = train_data[seq_idx : seq_idx + args.micro_bs].long().cuda()
            out = model(input_ids=input_ids, labels=input_ids)
            loss = out.loss / args.grad_accum
            loss.backward()
            step_loss += loss.item()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        # Set this step's scheduled LR before the optimizer applies the update.
        lr = get_lr(step, TOTAL_STEPS, warmup_steps, args.lr, args.lr_min)
        for pg in optimizer.param_groups:
            pg["lr"] = lr
        optimizer.step()
        global_tokens += TOKENS_PER_STEP
        if step % 10 == 0:
            elapsed = time.time() - step_time_start
            tps = (10 * TOKENS_PER_STEP) / max(elapsed, 1e-9) if step > start_step else 0
            wandb.log({"train/loss": step_loss, "train/lr": lr,
                       "train/tokens": global_tokens, "train/grad_norm": grad_norm,
                       "train/tokens_per_sec": tps, "step": step})
            print(f" step {step}/{TOTAL_STEPS} | loss {step_loss:.4f} | "
                  f"lr {lr:.2e} | gnorm {grad_norm:.3f} | {tps:.0f} tok/s", flush=True)
            step_time_start = time.time()
        if step > 0 and step % EVAL_EVERY == 0:
            ppl, eval_loss = evaluate(model, val_data)
            wandb.log({"eval/val_ppl": ppl, "eval/val_loss": eval_loss, "step": step})
            print(f" [EVAL] step {step} | val_ppl {ppl:.2f}", flush=True)
        if step > 0 and step % 100 == 0:
            ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{step:06d}.pt")
            save_checkpoint(model, optimizer, step, global_tokens, ckpt_path)
            # Keep only the two most recent checkpoints to bound disk usage.
            ckpts = sorted(glob.glob(os.path.join(args.ckpt_dir, "ckpt_*.pt")))
            for old in ckpts[:-2]:
                os.remove(old)
        if _SHOULD_STOP:
            ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{step:06d}.pt")
            save_checkpoint(model, optimizer, step, global_tokens, ckpt_path)
            wandb.finish()
            return
    # Final save
    print("\n=== Training complete! ===")
    ckpt_path = os.path.join(args.ckpt_dir, f"ckpt_s{TOTAL_STEPS:06d}_final.pt")
    save_checkpoint(model, optimizer, TOTAL_STEPS, global_tokens, ckpt_path)
    ppl, eval_loss = evaluate(model, val_data)
    wandb.log({"eval/val_ppl": ppl, "eval/val_loss": eval_loss, "step": TOTAL_STEPS})
    print(f"[FINAL] val_ppl {ppl:.2f}")
    wandb.finish()
if __name__ == "__main__":
    main()