# sllm / finetune / sft_train.py
# Hosting-page residue (not part of the program): uploader "geeteshcodes",
# commit message "Initial commit", revision 7f974df (verified).
"""
finetune/sft_train.py

Full Supervised Fine-Tuning (SFT) of SLLM-150M → Chat Model.

Starts from the pretrained base checkpoint, resizes the token embedding
for 2 new ChatML special tokens, then trains with masked CrossEntropy
so only assistant response tokens contribute to the loss.

Usage (first run):
    python finetune/sft_train.py \\
        --base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
        --run_dir runs/sllm_150m_chat \\
        --max_steps 2000 \\
        --batch_size 4 --grad_accum 8 \\
        --grad_checkpoint

Resume:
    python finetune/sft_train.py \\
        --resume --run_dir runs/sllm_150m_chat \\
        --extra_steps 1000
"""
import os
import sys
import json
import math
import time
import signal
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from transformers import PreTrainedTokenizerFast
from tqdm import tqdm
# ------------------------------------------------------------------ #
# Resolve project root so model/ is importable
# ------------------------------------------------------------------ #
# Directory that contains this script.
SCRIPT_DIR = Path(__file__).resolve().parent
# Parent of the script directory; the `model` package is imported from here.
PROJECT_ROOT = SCRIPT_DIR.parent
# Default location of the SFT data (used as the default for --data_dir).
DATA_DIR = SCRIPT_DIR / "data"
# Both directories are pushed to the front of sys.path, so they take
# precedence over any installed package with the same name.
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(SCRIPT_DIR)) # so we can import sft_dataset
# These imports must stay below the sys.path edits above: they rely on them.
from model.config import SLLM_150M
from model.model import SLLM
from sft_dataset import build_sft_dataloader
# ------------------------------------------------------------------ #
# ARG PARSING
# ------------------------------------------------------------------ #
def parse_args():
    """
    Build the SFT command-line interface and parse the process arguments.

    Options are declared in a table and registered in order, so ``--help``
    lists them in the same sequence as they appear below.

    Returns:
        argparse.Namespace holding checkpoint, data, optimisation, memory
        and logging settings.
    """
    parser = argparse.ArgumentParser(description="SLLM-150M SFT Training")
    default_base_ckpt = str(PROJECT_ROOT / "runs" / "sllm_150m" / "ckpt_0011500.pt")

    option_specs = [
        # Checkpoints
        ("--base_ckpt", dict(type=str, default=default_base_ckpt,
                             help="Path to pretrained base checkpoint (.pt)")),
        ("--run_dir", dict(type=str, default="runs/sllm_150m_chat",
                           help="Output directory for SFT checkpoints and logs")),
        ("--resume", dict(action="store_true",
                          help="Resume from latest SFT checkpoint in --run_dir")),
        ("--max_steps", dict(type=int, default=2000,
                             help="Absolute step target for this run")),
        ("--extra_steps", dict(type=int, default=None,
                               help="Run N more steps from current checkpoint (relative)")),
        # Data
        ("--data_dir", dict(type=str, default=str(DATA_DIR),
                            help="Directory with train_sft.pt, val_sft.pt, and tokenizer files")),
        ("--num_workers", dict(type=int, default=0)),
        # Optimisation — note: much lower LR than pretraining
        ("--batch_size", dict(type=int, default=4)),
        ("--grad_accum", dict(type=int, default=8)),
        ("--max_lr", dict(type=float, default=1e-5,
                          help="Peak LR (10x lower than pretraining)")),
        ("--min_lr", dict(type=float, default=1e-6)),
        ("--warmup_steps", dict(type=int, default=30)),
        ("--weight_decay", dict(type=float, default=0.1)),
        ("--grad_clip", dict(type=float, default=1.0)),
        ("--dropout", dict(type=float, default=0.1,
                           help="Dropout rate during SFT (0.0 in pretraining)")),
        # Memory
        ("--grad_checkpoint", dict(action="store_true",
                                   help="Enable gradient checkpointing (saves VRAM)")),
        ("--dtype", dict(type=str, default="bf16",
                         choices=["fp32", "fp16", "bf16"])),
        # Logging
        ("--log_every", dict(type=int, default=10)),
        ("--save_every", dict(type=int, default=500)),
        ("--val_every", dict(type=int, default=250)),
        ("--val_steps", dict(type=int, default=20)),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
# ------------------------------------------------------------------ #
# VOCAB RESIZE
# ------------------------------------------------------------------ #
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """
    Grow ``model.token_emb`` from ``model.config.vocab_size`` rows to
    ``new_vocab_size`` rows.

    Existing rows are copied unchanged. New rows are initialised to the mean
    of the existing embeddings so training starts from a stable point rather
    than random noise. The LM head weight is re-tied to the new embedding
    matrix and ``model.config.vocab_size`` is updated to match.

    Args:
        model: Model exposing ``token_emb`` (an embedding with a ``weight``),
            ``lm_head`` (whose ``weight`` is tied to the embedding) and
            ``config.vocab_size`` / ``config.d_model``.
        new_vocab_size: Target number of embedding rows.

    Raises:
        ValueError: If ``new_vocab_size`` is smaller than the current vocab.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size < old_size:
        raise ValueError(f"Cannot shrink vocab ({old_size} → {new_vocab_size})")

    d_model = model.config.d_model
    old_weight = model.token_emb.weight
    n_new = new_vocab_size - old_size

    # Allocate the larger table on the same device/dtype, then overwrite its
    # random init: old rows verbatim, new rows with the mean embedding.
    new_emb = nn.Embedding(new_vocab_size, d_model).to(
        device=old_weight.device, dtype=old_weight.dtype
    )
    with torch.no_grad():
        new_emb.weight[:old_size].copy_(old_weight)
        # A (d_model,) vector broadcasts across all n_new rows.
        new_emb.weight[old_size:].copy_(old_weight.mean(dim=0))
    model.token_emb = new_emb

    # Re-tie the LM head to the (now larger) embedding
    model.lm_head.weight = model.token_emb.weight
    # Keep the head's size metadata in sync when it has any (e.g. nn.Linear),
    # otherwise its repr still advertises the old vocab size.
    if hasattr(model.lm_head, "out_features"):
        model.lm_head.out_features = new_vocab_size

    # Keep config consistent
    model.config.vocab_size = new_vocab_size
    print(f" Vocab resized: {old_size:,} → {new_vocab_size:,} (+{n_new} tokens, init=mean)")
# ------------------------------------------------------------------ #
# DROPOUT
# ------------------------------------------------------------------ #
def set_dropout(model: SLLM, rate: float):
    """
    Set ``p = rate`` on every ``nn.Dropout`` submodule of ``model``.

    Prints a one-line summary when at least one layer was updated; prints
    nothing when the model contains no ``nn.Dropout`` modules.
    """
    dropout_layers = [mod for mod in model.modules() if isinstance(mod, nn.Dropout)]
    for layer in dropout_layers:
        layer.p = rate
    if dropout_layers:
        print(f" Dropout set to {rate} on {len(dropout_layers)} layer(s)")
# ------------------------------------------------------------------ #
# LR SCHEDULE (cosine with linear warmup, same shape as train.py)
# ------------------------------------------------------------------ #
def get_lr(step: int, warmup_steps: int, total_steps: int,
           max_lr: float, min_lr: float) -> float:
    """
    Learning rate at ``step``: linear warmup followed by cosine decay.

    During warmup (``step < warmup_steps``) the rate ramps linearly and hits
    ``max_lr`` on the last warmup step. Afterwards it follows a half cosine
    from ``max_lr`` down to ``min_lr`` at ``total_steps`` and stays at
    ``min_lr`` from then on. A falsy ``total_steps`` (``None`` or ``0``)
    falls back to a 5,000-step horizon.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return max_lr * (step + 1) / warmup_steps

    horizon = total_steps or 5_000
    if step >= horizon:
        return min_lr

    # Fraction of the post-warmup span that has elapsed, in [0, 1).
    span = max(1, horizon - warmup_steps)
    fraction = (step - warmup_steps) / span
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * fraction))
    return min_lr + cosine_weight * (max_lr - min_lr)
# ------------------------------------------------------------------ #
# OPTIMIZER (mirrors train.py — AdamW selective decay)
# ------------------------------------------------------------------ #
def build_optimizer(model: SLLM, lr: float, weight_decay: float):
    """
    Create an AdamW optimizer with selective weight decay.

    Trainable parameters with two or more dimensions get ``weight_decay``;
    trainable parameters with fewer dimensions get no decay. Frozen
    parameters are left out entirely.

    Returns:
        torch.optim.AdamW over the two parameter groups (decay first).
    """
    trainable = [
        weight for _, weight in model.named_parameters() if weight.requires_grad
    ]
    decayed = [weight for weight in trainable if weight.dim() >= 2]
    undecayed = [weight for weight in trainable if weight.dim() < 2]

    decayed_count = sum(weight.numel() for weight in decayed)
    undecayed_count = sum(weight.numel() for weight in undecayed)
    print(f" Optimizer: {decayed_count/1e6:.1f}M decay | {undecayed_count/1e6:.1f}M no-decay | lr={lr:.2e}")

    param_groups = [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    # Note: no fused=True here — new embedding rows need correct grad flow
    return torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)
# ------------------------------------------------------------------ #
# CHECKPOINT SAVE / LOAD
# ------------------------------------------------------------------ #
def save_checkpoint(path: str, model: SLLM, optimizer, step: int,
                    loss: float, vocab_size: int):
    """
    Write a training checkpoint to ``path`` atomically.

    The payload holds ``step``, ``model_state_dict``, ``optimizer_state_dict``,
    ``loss`` and ``vocab_size``. It is first written to ``<path>.tmp`` and then
    moved into place with ``os.replace``, so a crash or a full disk mid-write
    never leaves a truncated ``.pt`` file that a later resume would pick up as
    the latest checkpoint.

    Args:
        path: Destination file; its parent directory is created if needed.
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        step: Training step recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        vocab_size: Vocabulary size recorded in the checkpoint.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "vocab_size": vocab_size,
    }
    # The ".tmp" suffix keeps partial files out of the "*.pt" checkpoint scan.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"\n [CKPT] Saved: {path} (step={step}, loss={loss:.4f})")
def load_sft_checkpoint(run_dir: str, model: SLLM, optimizer, device):
    """
    Restore model and optimizer state from the newest SFT checkpoint.

    "Newest" is the lexicographically greatest ``ckpt_sft_*.pt`` file name in
    ``run_dir``.

    Returns:
        ``(step, vocab_size)`` from the checkpoint; ``vocab_size`` falls back
        to ``model.config.vocab_size`` when the checkpoint does not store it.

    Raises:
        FileNotFoundError: If ``run_dir`` has no matching checkpoint file.
    """
    candidates = []
    for fname in os.listdir(run_dir):
        if fname.startswith("ckpt_sft_") and fname.endswith(".pt"):
            candidates.append(fname)
    if not candidates:
        raise FileNotFoundError(f"No SFT checkpoints found in {run_dir}")

    path = os.path.join(run_dir, max(candidates))
    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    state = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])

    resumed_step = state["step"]
    saved_vocab = state.get("vocab_size", model.config.vocab_size)
    last_loss = state.get("loss", float("nan"))
    print(f" [CKPT] Resumed from: {path} (step={resumed_step}, loss={last_loss:.4f})")
    return resumed_step, saved_vocab
# ------------------------------------------------------------------ #
# VALIDATION (uses ignore_index=-100 like training)
# ------------------------------------------------------------------ #
@torch.no_grad()
def estimate_val_loss(model: SLLM, val_loader, val_steps: int,
                      device, dtype_ctx) -> float:
    """
    Average masked next-token loss over up to ``val_steps`` validation batches.

    Labels equal to -100 are ignored, matching the training loss. Batches
    whose shifted labels are all -100 are skipped: the mean-reduced cross
    entropy over zero targets is NaN and would poison the average. Skipped
    batches still count towards ``val_steps``.

    Args:
        model: Model called as ``model(x)`` and returning ``(logits, _)``.
        val_loader: Iterable yielding ``(x, y)`` batches.
        val_steps: Maximum number of batches to draw from ``val_loader``.
        device: Device the batches are moved to.
        dtype_ctx: Context manager wrapped around the forward pass and loss.

    Returns:
        Mean of the per-batch losses, or NaN if no batch was evaluated.
        The model is put back into train mode on exit, even on error.
    """
    model.eval()
    losses = []
    try:
        for i, (x, y) in enumerate(val_loader):
            if i >= val_steps:
                break
            # No supervised target in this batch: nothing to score.
            if not (y[..., 1:] != -100).any():
                continue
            x, y = x.to(device), y.to(device)
            with dtype_ctx:
                logits, _ = model(x)
                # Shift logits and labels by 1 to predict the next token
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = y[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                )
            losses.append(loss.item())
    finally:
        model.train()
    return sum(losses) / len(losses) if losses else float("nan")
# ------------------------------------------------------------------ #
# METRIC LOGGER
# ------------------------------------------------------------------ #
class MetricLogger:
    """Append-only JSON-lines metric writer (one JSON object per call)."""

    def __init__(self, log_path: str):
        """Remember ``log_path`` and make sure its parent directory exists."""
        parent_dir = os.path.dirname(os.path.abspath(log_path))
        os.makedirs(parent_dir, exist_ok=True)
        self.log_path = log_path
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """Serialise the keyword arguments as one JSON line and append it."""
        with open(self.log_path, "a") as sink:
            record = json.dumps(kwargs)
            sink.write(record)
            sink.write("\n")
# ------------------------------------------------------------------ #
# MAIN TRAINING LOOP
# ------------------------------------------------------------------ #
def _sft_checkpoint_exists(run_dir: str) -> bool:
    """Return True when ``run_dir`` holds at least one ``ckpt_sft_*.pt`` file."""
    try:
        names = os.listdir(run_dir)
    except (FileNotFoundError, NotADirectoryError):
        return False
    return any(n.startswith("ckpt_sft_") and n.endswith(".pt") for n in names)


def train():
    """
    Run supervised fine-tuning end to end.

    Steps: parse CLI arguments, pick device and precision, load the tokenizer,
    build the model, load weights (base checkpoint, or the latest SFT
    checkpoint when resuming), grow the embedding to the tokenizer's vocab
    size, then train with gradient accumulation, cosine LR schedule, periodic
    logging / validation / checkpointing and a final checkpoint on exit.

    Only label positions different from -100 contribute to the loss.

    Behaviour worth knowing:
      * ``--resume`` with no SFT checkpoint in ``--run_dir`` falls back to the
        base checkpoint (the weights are actually loaded in that case).
      * Ctrl+C sets a flag; the current optimizer step finishes, a checkpoint
        is written, and the function returns.
      * Micro-batches without any supervised token are skipped, because their
        mean-reduced loss would be NaN.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n{'='*60}")
    print(f" SLLM-150M → Chat Model (SFT)")
    print(f"{'='*60}")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    # ---- dtype ----------------------------------------------------- #
    # Reduced precision is only used on CUDA; anything else runs in fp32.
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch, dtype_name = torch.bfloat16, "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch, dtype_name = torch.float16, "fp16"
    else:
        dtype_torch, dtype_name = torch.float32, "fp32"
    print(f"dtype : {dtype_name}")
    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    # Loss scaling is only needed for fp16; for bf16/fp32 the scaler is a no-op.
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))

    # ---- Tokenizer ------------------------------------------------- #
    print("\n[1/5] Loading tokenizer...")
    tok_path = args.data_dir
    if os.path.exists(os.path.join(tok_path, "tokenizer.json")):
        # Prefer the saved tokenizer from prepare_data.py (has special tokens)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        print(f" Loaded from data dir: {tok_path}")
    else:
        # Fallback: load base tokenizer and add special tokens manually
        base_tok_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(base_tok_dir)
        tokenizer.add_special_tokens({"additional_special_tokens":
                                      ["<|im_start|>", "<|im_end|>"]})
        print(f" Loaded base tokenizer + added special tokens")
    new_vocab_size = len(tokenizer)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None \
        else tokenizer.eos_token_id
    print(f" Vocab size : {new_vocab_size:,}")
    print(f" Pad token : {pad_id}")

    # ---- Model ----------------------------------------------------- #
    print("\n[2/5] Loading model...")
    cfg = SLLM_150M
    model = SLLM(cfg).to(device)

    # Decide up front where the weights come from. The base checkpoint has to
    # be loaded BEFORE the embedding is resized, so a missing SFT checkpoint
    # must be detected here rather than after the resize.
    resume_from_sft = args.resume and _sft_checkpoint_exists(args.run_dir)
    if args.resume and not resume_from_sft:
        print(f" [WARN] No SFT checkpoints found in {args.run_dir}"
              f" — starting SFT from base checkpoint.")
    if not resume_from_sft:
        # Load pretrained base weights (step 11,500)
        print(f" Loading base checkpoint: {args.base_ckpt}")
        # NOTE: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        base_ckpt = torch.load(args.base_ckpt, map_location=device, weights_only=False)
        model.load_state_dict(base_ckpt["model_state_dict"])
        base_step = base_ckpt.get("step", "?")
        base_loss = base_ckpt.get("loss", float("nan"))
        print(f" Base model step={base_step} loss={base_loss:.4f}")
        del base_ckpt

    # Grow embedding for the 2 new special tokens (also needed on resume so
    # the SFT state dict matches the model's shapes).
    resize_token_embeddings(model, new_vocab_size)
    # Apply SFT dropout (was 0.0 in pretraining)
    set_dropout(model, args.dropout)
    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")
    print(f" Model params: {model.count_params()/1e6:.1f}M")

    # ---- Optimizer ------------------------------------------------- #
    print("\n[3/5] Building optimizer...")
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Resume from SFT checkpoint -------------------------------- #
    start_step = 0
    if resume_from_sft:
        start_step, _ = load_sft_checkpoint(args.run_dir, model, optimizer, device)

    # Resolve --extra_steps → --max_steps
    if args.extra_steps is not None:
        args.max_steps = start_step + args.extra_steps
        print(f" --extra_steps {args.extra_steps} → max_steps={args.max_steps}")
    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] Already at step {start_step} >= max_steps {args.max_steps}.")
        print(f" Use --extra_steps N to run N more steps.")
        return

    # ---- Data ------------------------------------------------------ #
    print("\n[4/5] Loading SFT dataset...")
    train_path = os.path.join(args.data_dir, "train_sft.pt")
    val_path = os.path.join(args.data_dir, "val_sft.pt")
    train_loader = build_sft_dataloader(
        data_path=train_path, batch_size=args.batch_size,
        pad_token_id=pad_id, context_length=cfg.context_length,
        num_workers=args.num_workers, shuffle=True,
    )
    val_loader = build_sft_dataloader(
        data_path=val_path, batch_size=args.batch_size,
        pad_token_id=pad_id, context_length=cfg.context_length,
        num_workers=0, shuffle=False,
    )

    # ---- Run dir + logger ------------------------------------------ #
    os.makedirs(args.run_dir, exist_ok=True)
    log_path = os.path.join(args.run_dir, "sft_log.jsonl")
    logger = MetricLogger(log_path)

    # ---- Training info --------------------------------------------- #
    eff_batch = args.batch_size * args.grad_accum
    print(f"\n[5/5] Training config:")
    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} → eff={eff_batch})")
    print(f" max_steps : {args.max_steps}")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else '∞'}")
    print(f" max_lr / min_lr: {args.max_lr:.2e} / {args.min_lr:.2e}")
    print(f" warmup_steps : {args.warmup_steps}")
    print(f" save_every : {args.save_every}")
    print(f" val_every : {args.val_every}")

    # ---- Ctrl+C handler -------------------------------------------- #
    # The handler only raises a flag, so the step in flight completes and the
    # final checkpoint below is still written.
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C — will save and exit after this step.")
        stop_flag["stop"] = True

    signal.signal(signal.SIGINT, _signal_handler)

    # ================================================================ #
    # TRAINING LOOP
    # ================================================================ #
    model.train()
    step = start_step
    running_loss = 0.0
    t_start = time.time()
    t_step_start = time.time()
    data_iter = iter(train_loader)
    print(f"\n{'='*60}")
    print(f" SFT STARTED (step {step} → {args.max_steps})")
    print(f"{'='*60}\n")
    pbar = tqdm(
        initial=step, total=args.max_steps,
        desc="SFT", unit="step", dynamic_ncols=True,
    )
    while True:
        # ---- Stop conditions --------------------------------------- #
        if stop_flag["stop"]:
            break
        if args.max_steps is not None and step >= args.max_steps:
            print(f"\n [DONE] Reached max_steps={args.max_steps}")
            break

        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0
        contributing = 0  # micro-batches that produced a gradient

        # ---- Gradient accumulation micro-steps --------------------- #
        for _ in range(args.grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Epoch boundary: restart the loader and keep going.
                data_iter = iter(train_loader)
                x, y = next(data_iter)

            # With every shifted label at -100 the mean-reduced loss is NaN;
            # back-propagating it would corrupt the weights (no scaler
            # protection in bf16/fp32), so such micro-batches are skipped.
            if not (y[..., 1:] != -100).any():
                continue
            contributing += 1

            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
                logits, _ = model(x)  # (B, T, V) — don't use built-in loss
                # Shift logits and labels by 1 to predict the next token
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = y[..., 1:].contiguous()
                # Use ignore_index=-100 so only assistant tokens drive the loss
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                ) / args.grad_accum  # scale for accumulation
            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # ---- LR ---------------------------------------------------- #
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        if contributing:
            # ---- Grad clip ----------------------------------------- #
            if args.grad_clip > 0:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            else:
                grad_norm = float("nan")
            # ---- Optimizer step ------------------------------------ #
            scaler.step(optimizer)
            scaler.update()
        else:
            # Nothing was back-propagated: skip the update but still advance
            # the step counter so the loop always terminates.
            grad_norm = float("nan")
            tqdm.write(f" [WARN] step {step}: no supervised tokens in any micro-batch"
                       f" — optimizer update skipped.")

        step += 1
        running_loss = accum_loss
        t_now = time.time()
        elapsed_step = t_now - t_step_start
        t_step_start = t_now
        pbar.update(1)
        pbar.set_postfix({"loss": f"{running_loss:.4f}", "lr": f"{lr:.1e}"})

        # ---- Logging ----------------------------------------------- #
        if step % args.log_every == 0:
            grad_norm_value = float(grad_norm)
            entry = {
                "step": step,
                "loss": round(running_loss, 6),
                "lr": lr,
                # NaN/Infinity are not valid JSON, so they are logged as null.
                "grad_norm": round(grad_norm_value, 4)
                if math.isfinite(grad_norm_value) else None,
                "elapsed_s": round(t_now - t_start, 1),
            }
            if device.type == "cuda":
                entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
            logger.log(**entry)

        # ---- Validation -------------------------------------------- #
        if step % args.val_every == 0:
            v_ctx = autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)
            val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, v_ctx)
            tqdm.write(
                f" [STEP {step:5d}] train={running_loss:.4f} "
                f"val={val_loss:.4f} lr={lr:.1e}"
            )
            logger.log(step=step, val_loss=round(val_loss, 6))

        # ---- Checkpoint -------------------------------------------- #
        if step % args.save_every == 0:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)

    # ================================================================ #
    # FINAL SAVE
    # ================================================================ #
    pbar.close()
    steps_done = step - start_step
    if steps_done > 0:
        ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
        save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
    else:
        print("\n [SKIP] No steps taken — skipping checkpoint save.")
    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(f" SFT COMPLETE")
    print(f"{'='*60}")
    print(f" Steps done : {steps_done}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nStart chatting:")
    print(f" python finetune/chat.py --run_dir {args.run_dir}")
if __name__ == "__main__":
    # Script entry point: run the whole SFT pipeline when executed directly.
    train()