"""
finetune/sft_train.py

Full Supervised Fine-Tuning (SFT) of SLLM-150M → Chat Model.

Starts from the pretrained base checkpoint, resizes the token embedding for
2 new ChatML special tokens, then trains with masked CrossEntropy so only
assistant response tokens contribute to the loss.

Usage (first run):
    python finetune/sft_train.py \\
        --base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
        --run_dir runs/sllm_150m_chat \\
        --max_steps 2000 \\
        --batch_size 4 --grad_accum 8 \\
        --grad_checkpoint

Resume:
    python finetune/sft_train.py \\
        --resume --run_dir runs/sllm_150m_chat \\
        --extra_steps 1000
"""

import os
import sys
import json
import math
import time
import signal
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from transformers import PreTrainedTokenizerFast
from tqdm import tqdm

# ------------------------------------------------------------------ #
#   Resolve project root so model/ is importable
# ------------------------------------------------------------------ #
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
DATA_DIR = SCRIPT_DIR / "data"
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(SCRIPT_DIR))  # so we can import sft_dataset

from model.config import SLLM_150M
from model.model import SLLM
from sft_dataset import build_sft_dataloader


# ------------------------------------------------------------------ #
#   ARG PARSING
# ------------------------------------------------------------------ #
def parse_args():
    """Parse the command-line options for an SFT run.

    Returns:
        The populated ``argparse.Namespace``.

    Exits with an argparse usage error when a batch size, accumulation
    count or logging/saving/validation interval is below 1, since the
    training loop uses those values as range bounds and modulo divisors.
    """
    p = argparse.ArgumentParser(description="SLLM-150M SFT Training")

    # Checkpoints
    p.add_argument("--base_ckpt", type=str,
                   default=str(PROJECT_ROOT / "runs" / "sllm_150m" / "ckpt_0011500.pt"),
                   help="Path to pretrained base checkpoint (.pt)")
    p.add_argument("--run_dir", type=str, default="runs/sllm_150m_chat",
                   help="Output directory for SFT checkpoints and logs")
    p.add_argument("--resume", action="store_true",
                   help="Resume from latest SFT checkpoint in --run_dir")
    p.add_argument("--max_steps", type=int, default=2000,
                   help="Absolute step target for this run")
    p.add_argument("--extra_steps", type=int, default=None,
                   help="Run N more steps from current checkpoint (relative)")

    # Data
    p.add_argument("--data_dir", type=str, default=str(DATA_DIR),
                   help="Directory with train_sft.pt, val_sft.pt, and tokenizer files")
    p.add_argument("--num_workers", type=int, default=0)

    # Optimisation — note: much lower LR than pretraining
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--grad_accum", type=int, default=8)
    p.add_argument("--max_lr", type=float, default=1e-5,
                   help="Peak LR (10x lower than pretraining)")
    p.add_argument("--min_lr", type=float, default=1e-6)
    p.add_argument("--warmup_steps", type=int, default=30)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--dropout", type=float, default=0.1,
                   help="Dropout rate during SFT (0.0 in pretraining)")

    # Memory
    p.add_argument("--grad_checkpoint", action="store_true",
                   help="Enable gradient checkpointing (saves VRAM)")
    p.add_argument("--dtype", type=str, default="bf16",
                   choices=["fp32", "fp16", "bf16"])

    # Logging
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--save_every", type=int, default=500)
    p.add_argument("--val_every", type=int, default=250)
    p.add_argument("--val_steps", type=int, default=20)

    args = p.parse_args()

    # Fail early with a readable message instead of a ZeroDivisionError
    # (``step % 0``) or an empty accumulation loop deep inside training.
    for opt in ("batch_size", "grad_accum", "log_every", "save_every", "val_every"):
        if getattr(args, opt) < 1:
            p.error(f"--{opt} must be >= 1 (got {getattr(args, opt)})")

    return args


# ------------------------------------------------------------------ #
#   VOCAB RESIZE
# ------------------------------------------------------------------ #
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """Grow ``model.token_emb`` from the configured vocab size to a larger one.

    New rows are initialised to the mean of the existing embedding rows so
    training starts from a stable point rather than random noise. The LM
    head is re-tied to the new embedding weight and ``model.config.vocab_size``
    is updated in place.

    Args:
        model: Model exposing ``token_emb``, ``lm_head`` and ``config``.
        new_vocab_size: Target number of embedding rows.

    Raises:
        ValueError: If ``new_vocab_size`` is smaller than the current size.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size < old_size:
        raise ValueError(f"Cannot shrink vocab ({old_size} → {new_vocab_size})")

    d_model = model.config.d_model
    device = model.token_emb.weight.device
    dtype = model.token_emb.weight.dtype

    old_weight = model.token_emb.weight.data.clone()  # (old_size, d)
    mean_vec = old_weight.mean(dim=0)                  # (d,)

    new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device)
    new_weight[:old_size] = old_weight
    # Broadcast mean_vec into new rows
    new_weight[old_size:] = mean_vec.unsqueeze(0).expand(new_vocab_size - old_size, -1)

    # Replace the embedding module in-place
    new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
    new_emb.weight.data = new_weight
    model.token_emb = new_emb

    # Re-tie the LM head to the (now larger) embedding
    model.lm_head.weight = model.token_emb.weight
    # Keep the head's bookkeeping in line with its new weight shape.
    # NOTE(review): presumably lm_head is a bias-free nn.Linear — confirm in model/model.py.
    if hasattr(model.lm_head, "out_features"):
        model.lm_head.out_features = new_vocab_size

    # Keep config consistent
    model.config.vocab_size = new_vocab_size

    n_new = new_vocab_size - old_size
    print(f" Vocab resized: {old_size:,} → {new_vocab_size:,} (+{n_new} tokens, init=mean)")


# ------------------------------------------------------------------ #
#   DROPOUT
# ------------------------------------------------------------------ #
def set_dropout(model: SLLM, rate: float):
    """Set ``p = rate`` on every ``nn.Dropout`` module in the model.

    Only ``nn.Dropout`` instances are touched; dropout applied any other
    way is not affected, so a warning is printed when nothing matched.
    """
    count = 0
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.p = rate
            count += 1
    if count:
        print(f" Dropout set to {rate} on {count} layer(s)")
    else:
        print(f" [WARN] No nn.Dropout modules found — dropout={rate} was not applied")


# ------------------------------------------------------------------ #
#   LR SCHEDULE (cosine with linear warmup, same shape as train.py)
# ------------------------------------------------------------------ #
def get_lr(step: int, warmup_steps: int, total_steps: int,
           max_lr: float, min_lr: float) -> float:
    """Return the learning rate for ``step``.

    Linear warmup up to ``max_lr`` over ``warmup_steps`` steps, then cosine
    decay down to ``min_lr`` at ``total_steps``; ``min_lr`` afterwards.
    A falsy ``total_steps`` falls back to a 5,000-step decay horizon.
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    decay_steps = total_steps if total_steps else 5_000
    if step >= decay_steps:
        return min_lr
    progress = (step - warmup_steps) / max(1, decay_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coeff * (max_lr - min_lr)


# ------------------------------------------------------------------ #
#   OPTIMIZER (mirrors train.py — AdamW selective decay)
# ------------------------------------------------------------------ #
def build_optimizer(model: SLLM, lr: float, weight_decay: float):
    """Build AdamW with weight decay applied only to tensors of rank >= 2.

    Trainable parameters with fewer than two dimensions go into a
    zero-decay group. Frozen parameters are skipped.

    Returns:
        A ``torch.optim.AdamW`` instance with two parameter groups.
    """
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if param.dim() >= 2:
            decay.append(param)
        else:
            no_decay.append(param)

    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    n_d = sum(p.numel() for p in decay)
    n_nd = sum(p.numel() for p in no_decay)
    print(f" Optimizer: {n_d/1e6:.1f}M decay | {n_nd/1e6:.1f}M no-decay | lr={lr:.2e}")

    # Note: no fused=True here — new embedding rows need correct grad flow
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)


# ------------------------------------------------------------------ #
#   CHECKPOINT SAVE / LOAD
# ------------------------------------------------------------------ #
def save_checkpoint(path: str, model: SLLM, optimizer, step: int,
                    loss: float, vocab_size: int):
    """Write model + optimizer state to ``path`` atomically.

    The payload is first written to ``<path>.tmp`` and then moved into
    place with ``os.replace``, so an interrupted save never leaves a
    truncated file under the final name (which resume would then pick up
    as the latest checkpoint). The ``.tmp`` suffix keeps the partial file
    outside the ``ckpt_sft_*.pt`` pattern used when resuming.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        torch.save({
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            "vocab_size": vocab_size,
        }, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"\n [CKPT] Saved: {path} (step={step}, loss={loss:.4f})")


def _sft_ckpt_step(filename: str) -> int:
    """Return the step encoded in ``ckpt_sft_<step>.pt``, or -1 if not numeric."""
    stem = filename[len("ckpt_sft_"):-len(".pt")]
    return int(stem) if stem.isdigit() else -1


def load_sft_checkpoint(run_dir: str, model: SLLM, optimizer, device):
    """Load the latest ``ckpt_sft_*.pt`` from ``run_dir`` into model and optimizer.

    "Latest" is the file with the highest numeric step in its name, so the
    choice does not depend on zero-padding width; names without a numeric
    step rank below all numbered checkpoints.

    Returns:
        ``(step, vocab_size)`` as stored in the checkpoint; ``vocab_size``
        falls back to ``model.config.vocab_size`` when absent.

    Raises:
        FileNotFoundError: If ``run_dir`` is missing or holds no SFT checkpoint.
    """
    ckpts = []
    if os.path.isdir(run_dir):
        ckpts = [
            f for f in os.listdir(run_dir)
            if f.startswith("ckpt_sft_") and f.endswith(".pt")
        ]
    if not ckpts:
        raise FileNotFoundError(f"No SFT checkpoints found in {run_dir}")

    latest = max(ckpts, key=lambda f: (_sft_ckpt_step(f), f))
    path = os.path.join(run_dir, latest)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only
    # load checkpoints produced by this project.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    step = ckpt["step"]
    vocab_size = ckpt.get("vocab_size", model.config.vocab_size)
    loss = ckpt.get("loss", float("nan"))
    print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
    return step, vocab_size


# ------------------------------------------------------------------ #
#   VALIDATION (uses ignore_index=-100 like training)
# ------------------------------------------------------------------ #
@torch.no_grad()
def estimate_val_loss(model: SLLM, val_loader, val_steps: int,
                      device, dtype_ctx) -> float:
    """Average the masked next-token loss over up to ``val_steps`` batches.

    The model is switched to eval mode for the duration and put back into
    its previous mode afterwards, even if a batch raises.

    Args:
        model: Model whose forward returns ``(logits, _)``.
        val_loader: Iterable of ``(x, y)`` batches.
        val_steps: Maximum number of batches to evaluate.
        device: Device the batches are moved to.
        dtype_ctx: Context manager (e.g. autocast) wrapping the forward pass.

    Returns:
        Mean loss over the evaluated batches, or NaN if none were evaluated.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        for i, (x, y) in enumerate(val_loader):
            if i >= val_steps:
                break
            x, y = x.to(device), y.to(device)
            with dtype_ctx:
                logits, _ = model(x)
                # Shift logits and labels by 1 to predict the next token
                # NOTE(review): assumes y is aligned with x (not pre-shifted)
                # — confirm against sft_dataset.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = y[..., 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                )
            losses.append(loss.item())
    finally:
        model.train(was_training)
    return sum(losses) / len(losses) if losses else float("nan")


# ------------------------------------------------------------------ #
#   METRIC LOGGER
# ------------------------------------------------------------------ #
class MetricLogger:
    """Append-only JSON-lines metric log."""

    def __init__(self, log_path: str):
        """Remember ``log_path`` and create its parent directory if needed."""
        self.log_path = log_path
        os.makedirs(os.path.dirname(os.path.abspath(log_path)), exist_ok=True)
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """Append the keyword arguments as one JSON object on its own line."""
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(kwargs) + "\n")
# ------------------------------------------------------------------ #
#   MAIN TRAINING LOOP
# ------------------------------------------------------------------ #
def _has_sft_checkpoint(run_dir: str) -> bool:
    """Return True if ``run_dir`` exists and holds at least one ckpt_sft_*.pt."""
    if not os.path.isdir(run_dir):
        return False
    return any(
        f.startswith("ckpt_sft_") and f.endswith(".pt")
        for f in os.listdir(run_dir)
    )


def train():
    """Run supervised fine-tuning end to end.

    Parses CLI arguments, loads the tokenizer and model, builds the
    optimizer and data loaders, then runs the gradient-accumulation loop
    with periodic logging, validation and checkpointing. Ctrl+C requests
    a clean stop: the current step finishes and a final checkpoint is
    written.

    When ``--resume`` is given but the run directory has no SFT
    checkpoint, the base checkpoint is loaded instead, so the model never
    trains from random initialisation.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'='*60}")
    print(" SLLM-150M → Chat Model (SFT)")
    print(f"{'='*60}")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    # ---- dtype ----------------------------------------------------- #
    # Reduced precision is only used on CUDA; anything else runs in fp32.
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch, dtype_name = torch.bfloat16, "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch, dtype_name = torch.float16, "fp16"
    else:
        dtype_torch, dtype_name = torch.float32, "fp32"
    print(f"dtype : {dtype_name}")

    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    # Loss scaling is only needed for fp16; it is a no-op otherwise.
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))

    # ---- Tokenizer ------------------------------------------------- #
    print("\n[1/5] Loading tokenizer...")
    tok_path = args.data_dir
    if os.path.exists(os.path.join(tok_path, "tokenizer.json")):
        # Prefer the saved tokenizer from prepare_data.py (has special tokens)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        print(f" Loaded from data dir: {tok_path}")
    else:
        # Fallback: load base tokenizer and add special tokens manually
        base_tok_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(base_tok_dir)
        tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
        )
        print(" Loaded base tokenizer + added special tokens")

    new_vocab_size = len(tokenizer)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None \
        else tokenizer.eos_token_id
    print(f" Vocab size : {new_vocab_size:,}")
    print(f" Pad token : {pad_id}")

    # ---- Model ----------------------------------------------------- #
    print("\n[2/5] Loading model...")
    cfg = SLLM_150M
    model = SLLM(cfg).to(device)

    # Decide up front whether a resume is actually possible. The base
    # weights must be loaded BEFORE the embedding is resized, so the
    # fallback cannot be deferred until the SFT checkpoint load fails.
    resuming = args.resume and _has_sft_checkpoint(args.run_dir)
    if args.resume and not resuming:
        print(f" [WARN] No SFT checkpoints found in {args.run_dir} — "
              f"starting SFT from base checkpoint.")

    if not resuming:
        # Load pretrained base weights (step 11,500)
        print(f" Loading base checkpoint: {args.base_ckpt}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
        # only load checkpoints produced by this project.
        base_ckpt = torch.load(args.base_ckpt, map_location=device, weights_only=False)
        model.load_state_dict(base_ckpt["model_state_dict"])
        base_step = base_ckpt.get("step", "?")
        base_loss = base_ckpt.get("loss", float("nan"))
        print(f" Base model step={base_step} loss={base_loss:.4f}")
        del base_ckpt

    # Grow embedding for the 2 new special tokens
    resize_token_embeddings(model, new_vocab_size)

    # Apply SFT dropout (was 0.0 in pretraining)
    set_dropout(model, args.dropout)

    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")

    print(f" Model params: {model.count_params()/1e6:.1f}M")

    # ---- Optimizer ------------------------------------------------- #
    print("\n[3/5] Building optimizer...")
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Resume from SFT checkpoint -------------------------------- #
    start_step = 0
    if resuming:
        start_step, _ = load_sft_checkpoint(args.run_dir, model, optimizer, device)

    # Resolve --extra_steps → --max_steps
    if args.extra_steps is not None:
        args.max_steps = start_step + args.extra_steps
        print(f" --extra_steps {args.extra_steps} → max_steps={args.max_steps}")

    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] Already at step {start_step} >= max_steps {args.max_steps}.")
        print(" Use --extra_steps N to run N more steps.")
        return

    # ---- Data ------------------------------------------------------ #
    print("\n[4/5] Loading SFT dataset...")
    train_path = os.path.join(args.data_dir, "train_sft.pt")
    val_path = os.path.join(args.data_dir, "val_sft.pt")

    train_loader = build_sft_dataloader(
        data_path=train_path,
        batch_size=args.batch_size,
        pad_token_id=pad_id,
        context_length=cfg.context_length,
        num_workers=args.num_workers,
        shuffle=True,
    )
    val_loader = build_sft_dataloader(
        data_path=val_path,
        batch_size=args.batch_size,
        pad_token_id=pad_id,
        context_length=cfg.context_length,
        num_workers=0,
        shuffle=False,
    )

    # ---- Run dir + logger ------------------------------------------ #
    os.makedirs(args.run_dir, exist_ok=True)
    log_path = os.path.join(args.run_dir, "sft_log.jsonl")
    logger = MetricLogger(log_path)

    # ---- Training info --------------------------------------------- #
    eff_batch = args.batch_size * args.grad_accum
    print("\n[5/5] Training config:")
    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} → eff={eff_batch})")
    print(f" max_steps : {args.max_steps}")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else '∞'}")
    print(f" max_lr / min_lr: {args.max_lr:.2e} / {args.min_lr:.2e}")
    print(f" warmup_steps : {args.warmup_steps}")
    print(f" save_every : {args.save_every}")
    print(f" val_every : {args.val_every}")

    # ---- Ctrl+C handler -------------------------------------------- #
    # A mutable dict lets the handler flip the flag without `nonlocal`.
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C — will save and exit after this step.")
        stop_flag["stop"] = True

    signal.signal(signal.SIGINT, _signal_handler)

    # ================================================================ #
    #   TRAINING LOOP
    # ================================================================ #
    model.train()
    step = start_step
    running_loss = 0.0
    last_saved_step = None  # step of the most recent periodic checkpoint
    t_start = time.time()
    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" SFT STARTED (step {step} → {args.max_steps})")
    print(f"{'='*60}\n")

    pbar = tqdm(
        initial=step,
        total=args.max_steps,
        desc="SFT",
        unit="step",
        dynamic_ncols=True,
    )

    while True:
        # ---- Stop conditions --------------------------------------- #
        if stop_flag["stop"]:
            break
        if args.max_steps is not None and step >= args.max_steps:
            print(f"\n [DONE] Reached max_steps={args.max_steps}")
            break

        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0

        # ---- Gradient accumulation micro-steps --------------------- #
        for _ in range(args.grad_accum):
            try:
                x, y = next(data_iter)
            except StopIteration:
                # Epoch boundary: restart the loader and keep going.
                data_iter = iter(train_loader)
                x, y = next(data_iter)

            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

            with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
                logits, _ = model(x)  # (B, T, V) — don't use built-in loss

                # Shift logits and labels by 1 to predict the next token
                # NOTE(review): assumes y is aligned with x (not pre-shifted)
                # — confirm against sft_dataset.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = y[..., 1:].contiguous()

                # Use ignore_index=-100 so only assistant tokens drive the loss
                # NOTE(review): a micro-batch whose labels are all -100 yields
                # a NaN mean here — confirm the dataset never produces one.
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                ) / args.grad_accum  # scale for accumulation

            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # ---- Grad clip --------------------------------------------- #
        if args.grad_clip > 0:
            # Gradients must be unscaled before their norm is measured.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        else:
            grad_norm = float("nan")

        # ---- LR ---------------------------------------------------- #
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # ---- Optimizer step ---------------------------------------- #
        scaler.step(optimizer)
        scaler.update()

        step += 1
        running_loss = accum_loss
        t_now = time.time()

        pbar.update(1)
        pbar.set_postfix({"loss": f"{running_loss:.4f}", "lr": f"{lr:.1e}"})

        # ---- Logging ----------------------------------------------- #
        if step % args.log_every == 0:
            grad_norm_val = float(grad_norm)
            entry = {
                "step": step,
                "loss": round(running_loss, 6),
                "lr": lr,
                "grad_norm": None if math.isnan(grad_norm_val) else round(grad_norm_val, 4),
                "elapsed_s": round(t_now - t_start, 1),
            }
            if device.type == "cuda":
                entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
            logger.log(**entry)

        # ---- Validation -------------------------------------------- #
        if step % args.val_every == 0:
            v_ctx = autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)
            val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, v_ctx)
            tqdm.write(
                f" [STEP {step:5d}] train={running_loss:.4f} "
                f"val={val_loss:.4f} lr={lr:.1e}"
            )
            logger.log(step=step, val_loss=round(val_loss, 6))

        # ---- Checkpoint -------------------------------------------- #
        if step % args.save_every == 0:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
            last_saved_step = step

    # ================================================================ #
    #   FINAL SAVE
    # ================================================================ #
    pbar.close()
    steps_done = step - start_step

    if steps_done <= 0:
        print("\n [SKIP] No steps taken — skipping checkpoint save.")
    elif last_saved_step == step:
        # The periodic save already wrote this exact step; don't rewrite it.
        print(f"\n [CKPT] Step {step} already saved — skipping duplicate final save.")
    else:
        ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
        save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(" SFT COMPLETE")
    print(f"{'='*60}")
    print(f" Steps done : {steps_done}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print("\nStart chatting:")
    print(f" python finetune/chat.py --run_dir {args.run_dir}")


if __name__ == "__main__":
    train()