| """ |
| finetune/sft_train.py |
| |
| Full Supervised Fine-Tuning (SFT) of SLLM-150M → Chat Model. |
| |
| Starts from the pretrained base checkpoint, resizes the token embedding |
| for 2 new ChatML special tokens, then trains with masked CrossEntropy |
| so only assistant response tokens contribute to the loss. |
| |
| Usage (first run): |
| python finetune/sft_train.py \\ |
| --base_ckpt runs/sllm_150m/ckpt_0011500.pt \\ |
| --run_dir runs/sllm_150m_chat \\ |
| --max_steps 2000 \\ |
| --batch_size 4 --grad_accum 8 \\ |
| --grad_checkpoint |
| |
| Resume: |
| python finetune/sft_train.py \\ |
| --resume --run_dir runs/sllm_150m_chat \\ |
| --extra_steps 1000 |
| """ |
|
|
| import os |
| import sys |
| import json |
| import math |
| import time |
| import signal |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.amp import autocast, GradScaler |
| from transformers import PreTrainedTokenizerFast |
| from tqdm import tqdm |
|
|
| |
| |
| |
|
|
# All paths are derived from this file's location, so the script behaves the
# same regardless of the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
DATA_DIR = SCRIPT_DIR / "data"

# Prepend both directories to the import path (presumably so the project-level
# `model` package and modules next to this script, such as `sft_dataset`, are
# importable). Resulting order: SCRIPT_DIR, PROJECT_ROOT, <previous entries>.
sys.path[:0] = [str(SCRIPT_DIR), str(PROJECT_ROOT)]

# These imports rely on the sys.path adjustment above.
from model.config import SLLM_150M
from model.model import SLLM
from sft_dataset import build_sft_dataloader
|
|
|
|
| |
| |
| |
|
|
def parse_args():
    """Define and parse the command-line interface for SFT training.

    Flags are grouped as run/checkpoint control, data, optimisation
    hyper-parameters, memory/precision and logging cadence. The defaults for
    ``--base_ckpt`` and ``--data_dir`` are anchored at ``PROJECT_ROOT`` /
    ``DATA_DIR``; the default ``--run_dir`` is a relative path, resolved
    against the current working directory.

    Returns:
        argparse.Namespace with one attribute per flag (leading ``--`` removed).
    """
    default_base_ckpt = str(PROJECT_ROOT / "runs" / "sllm_150m" / "ckpt_0011500.pt")

    # (flag, add_argument keyword arguments), listed in --help display order.
    arg_specs = [
        # Run / checkpoint control
        ("--base_ckpt", dict(type=str, default=default_base_ckpt,
                             help="Path to pretrained base checkpoint (.pt)")),
        ("--run_dir", dict(type=str, default="runs/sllm_150m_chat",
                           help="Output directory for SFT checkpoints and logs")),
        ("--resume", dict(action="store_true",
                          help="Resume from latest SFT checkpoint in --run_dir")),
        ("--max_steps", dict(type=int, default=2000,
                             help="Absolute step target for this run")),
        ("--extra_steps", dict(type=int, default=None,
                               help="Run N more steps from current checkpoint (relative)")),
        # Data
        ("--data_dir", dict(type=str, default=str(DATA_DIR),
                            help="Directory with train_sft.pt, val_sft.pt, and tokenizer files")),
        ("--num_workers", dict(type=int, default=0)),
        # Optimisation
        ("--batch_size", dict(type=int, default=4)),
        ("--grad_accum", dict(type=int, default=8)),
        ("--max_lr", dict(type=float, default=1e-5,
                          help="Peak LR (10x lower than pretraining)")),
        ("--min_lr", dict(type=float, default=1e-6)),
        ("--warmup_steps", dict(type=int, default=30)),
        ("--weight_decay", dict(type=float, default=0.1)),
        ("--grad_clip", dict(type=float, default=1.0)),
        ("--dropout", dict(type=float, default=0.1,
                           help="Dropout rate during SFT (0.0 in pretraining)")),
        # Memory / precision
        ("--grad_checkpoint", dict(action="store_true",
                                   help="Enable gradient checkpointing (saves VRAM)")),
        ("--dtype", dict(type=str, default="bf16",
                         choices=["fp32", "fp16", "bf16"])),
        # Logging cadence
        ("--log_every", dict(type=int, default=10)),
        ("--save_every", dict(type=int, default=500)),
        ("--val_every", dict(type=int, default=250)),
        ("--val_steps", dict(type=int, default=20)),
    ]

    parser = argparse.ArgumentParser(description="SLLM-150M SFT Training")
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
| |
| |
| |
|
|
def resize_token_embeddings(model: "SLLM", new_vocab_size: int):
    """Grow ``model.token_emb`` from ``model.config.vocab_size`` rows to ``new_vocab_size``.

    Existing rows are copied unchanged. Each new row is set to the mean of the
    existing rows (averaged in fp32), so new tokens start from a "typical"
    embedding rather than random noise. The old module's embedding options
    (``padding_idx``, ``max_norm``, ``norm_type``, ``scale_grad_by_freq``,
    ``sparse``) are carried over to the replacement module.

    Afterwards ``model.lm_head.weight`` is set to the new embedding weight,
    which (re-)ties the output projection to the input embedding. This assumes
    the model is meant to use weight tying. ``model.config.vocab_size`` is
    updated in place; if ``model.config`` is a shared object (e.g. a
    module-level preset), that object is mutated as well.

    Args:
        model: module exposing ``config.vocab_size``, ``config.d_model``,
            ``token_emb`` (an ``nn.Embedding``) and ``lm_head``.
        new_vocab_size: target number of rows; must be >= the current size.

    Raises:
        ValueError: if ``new_vocab_size`` is smaller than the current size.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size < old_size:
        raise ValueError(f"Cannot shrink vocab ({old_size} → {new_vocab_size})")

    old_emb = model.token_emb
    d_model = model.config.d_model
    n_new = new_vocab_size - old_size

    with torch.no_grad():
        old_weight = old_emb.weight.detach()
        # Average in fp32: reducing many rows directly in bf16/fp16 loses precision.
        mean_vec = old_weight.float().mean(dim=0).to(old_weight.dtype)
        new_weight = torch.empty(new_vocab_size, d_model,
                                 dtype=old_weight.dtype, device=old_weight.device)
        new_weight[:old_size] = old_weight
        new_weight[old_size:] = mean_vec  # broadcast over all new rows

    # from_pretrained wraps new_weight directly. Nothing is re-initialised, so
    # a padding row keeps its trained value. freeze=False keeps it trainable.
    model.token_emb = nn.Embedding.from_pretrained(
        new_weight,
        freeze=False,
        padding_idx=old_emb.padding_idx,
        max_norm=old_emb.max_norm,
        norm_type=old_emb.norm_type,
        scale_grad_by_freq=old_emb.scale_grad_by_freq,
        sparse=old_emb.sparse,
    )

    # Re-tie the output projection to the resized embedding matrix.
    model.lm_head.weight = model.token_emb.weight
    if hasattr(model.lm_head, "out_features"):
        # Keep nn.Linear metadata consistent with the new weight shape.
        model.lm_head.out_features = new_vocab_size

    model.config.vocab_size = new_vocab_size
    print(f" Vocab resized: {old_size:,} → {new_vocab_size:,} (+{n_new} tokens, init=mean)")
|
|
|
|
| |
| |
| |
|
|
def set_dropout(model: "SLLM", rate: float) -> int:
    """Set the drop probability of every ``nn.Dropout`` submodule of ``model``.

    Only modules that are instances of ``nn.Dropout`` are changed. Dropout
    applied functionally inside ``forward`` is not affected, so a warning is
    printed when no such module exists (the requested rate would otherwise be
    silently ignored).

    Args:
        model: model whose dropout modules are updated in place.
        rate: new drop probability, must lie in [0, 1].

    Returns:
        The number of ``nn.Dropout`` modules updated.

    Raises:
        ValueError: if ``rate`` is outside [0, 1]. Assigning ``.p`` directly
            bypasses the range check that ``nn.Dropout``'s constructor performs.
    """
    if not 0.0 <= rate <= 1.0:
        raise ValueError(f"Dropout rate must be in [0, 1], got {rate}")

    layers = [m for m in model.modules() if isinstance(m, nn.Dropout)]
    for layer in layers:
        layer.p = rate

    if layers:
        print(f" Dropout set to {rate} on {len(layers)} layer(s)")
    else:
        print(f" [WARN] No nn.Dropout layers found — dropout={rate} has no effect")
    return len(layers)
|
|
|
|
| |
| |
| |
|
|
def get_lr(step: int, warmup_steps: int, total_steps: int,
           max_lr: float, min_lr: float) -> float:
    """Learning rate for ``step``: linear warmup followed by cosine decay.

    * ``step < warmup_steps``: ramps linearly, reaching ``max_lr`` at
      ``step == warmup_steps - 1``.
    * Afterwards: half-cosine from ``max_lr`` down to ``min_lr`` over the
      horizon ``total_steps``. A falsy ``total_steps`` (0 or None) uses a
      5000-step horizon.
    * ``step >= horizon``: constant ``min_lr``.
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps

    horizon = total_steps or 5_000
    if step >= horizon:
        return min_lr

    span = max(1, horizon - warmup_steps)
    frac = (step - warmup_steps) / span
    cosine = (1.0 + math.cos(math.pi * frac)) / 2.0
    return min_lr + (max_lr - min_lr) * cosine
|
|
|
|
| |
| |
| |
|
|
def build_optimizer(model: SLLM, lr: float, weight_decay: float):
    """Create the AdamW optimizer used for SFT.

    Trainable parameters with two or more dimensions (weight matrices,
    embeddings) receive ``weight_decay``. Parameters with fewer dimensions
    (typically biases and normalisation gains) get no decay. Parameters with
    ``requires_grad=False`` are excluded entirely. Fixed hyper-parameters:
    ``betas=(0.9, 0.95)``, ``eps=1e-8``.

    Args:
        model: model whose parameters are optimised.
        lr: initial learning rate stored on both param groups.
        weight_decay: decay coefficient for the >=2-D group.

    Returns:
        ``torch.optim.AdamW`` with two param groups, in the order
        [decayed, non-decayed].
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    matrices = [p for p in trainable if p.dim() >= 2]
    vectors = [p for p in trainable if p.dim() < 2]

    n_matrix_params = sum(p.numel() for p in matrices)
    n_vector_params = sum(p.numel() for p in vectors)
    print(f" Optimizer: {n_matrix_params/1e6:.1f}M decay | {n_vector_params/1e6:.1f}M no-decay | lr={lr:.2e}")

    param_groups = [
        {"params": matrices, "weight_decay": weight_decay},
        {"params": vectors, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)
|
|
|
|
| |
| |
| |
|
|
def save_checkpoint(path: str, model: "SLLM", optimizer, step: int,
                    loss: float, vocab_size: int):
    """Write an SFT training checkpoint to ``path`` without ever exposing a partial file.

    The saved dict holds ``step``, ``model_state_dict``,
    ``optimizer_state_dict``, ``loss`` and ``vocab_size``. It is first written
    to ``<path>.tmp`` and then moved into place with ``os.replace``. An
    interrupted save therefore never leaves a truncated file under the final
    ``.pt`` name, where resume logic scanning for ``ckpt_sft_*.pt`` would
    pick it up. The temp file is removed if writing fails.

    Args:
        path: destination file; parent directories are created if missing.
        model: model whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        step: global step stored in the checkpoint.
        loss: most recent training loss (stored and printed).
        vocab_size: vocabulary size the model was resized to.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "vocab_size": vocab_size,
    }

    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present if torch.save / os.replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"\n [CKPT] Saved: {path} (step={step}, loss={loss:.4f})")
|
|
|
|
def load_sft_checkpoint(run_dir: str, model: "SLLM", optimizer, device):
    """Restore model and optimizer state from the newest SFT checkpoint in ``run_dir``.

    Only files named ``ckpt_sft_<digits>.pt`` (the pattern this script
    writes) are considered. "Newest" means the largest step number, not the
    lexicographically last filename, so a hand-named file such as
    ``ckpt_sft_best.pt`` or a step past 7 digits is handled correctly.

    Args:
        run_dir: directory containing the SFT checkpoints.
        model: model to load ``model_state_dict`` into.
        optimizer: optimizer to load ``optimizer_state_dict`` into.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(step, vocab_size)`` from the checkpoint. ``vocab_size`` falls back
        to ``model.config.vocab_size`` when the checkpoint lacks that key.

    Raises:
        FileNotFoundError: if ``run_dir`` is missing or holds no matching file.

    NOTE: ``torch.load(weights_only=False)`` unpickles arbitrary objects, so
    only use it with checkpoints you trust.
    """
    prefix, suffix = "ckpt_sft_", ".pt"
    candidates = []
    if os.path.isdir(run_dir):
        for fname in os.listdir(run_dir):
            if not (fname.startswith(prefix) and fname.endswith(suffix)):
                continue
            stem = fname[len(prefix):-len(suffix)]
            if stem.isdecimal():
                candidates.append((int(stem), fname))
    if not candidates:
        raise FileNotFoundError(f"No SFT checkpoints found in {run_dir}")

    _, latest = max(candidates)
    path = os.path.join(run_dir, latest)
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    step = ckpt["step"]
    # Not dict.get(): its default argument would be evaluated even when the key exists.
    vocab_size = ckpt["vocab_size"] if "vocab_size" in ckpt else model.config.vocab_size
    loss = ckpt.get("loss", float("nan"))
    print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
    return step, vocab_size
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def estimate_val_loss(model: "SLLM", val_loader, val_steps: int,
                      device, dtype_ctx) -> float:
    """Estimate validation loss on the first ``val_steps`` batches of ``val_loader``.

    Each batch's loss is the mean next-token cross-entropy over positions
    whose label is not -100: logits at ``t`` are scored against labels at
    ``t + 1``, and the loss is computed on fp32 logits. The result is the
    unweighted mean of those per-batch losses. A batch in which every label
    is -100 is skipped, because a mean over zero tokens is NaN and would
    poison the average.

    Args:
        model: called as ``logits, _ = model(x)``.
        val_loader: iterable of ``(x, y)`` batches; ``y`` uses -100 for
            positions excluded from the loss.
        val_steps: maximum number of batches to consume (skipped ones count).
        device: target device for each batch.
        dtype_ctx: context manager entered around each forward pass (e.g. an
            autocast context). It is entered once per batch, so it must be
            re-enterable.

    Returns:
        The average loss, or NaN if no consumed batch had a supervised token.

    The model's previous train/eval mode is restored on exit, even on error.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        for i, (x, y) in enumerate(val_loader):
            if i >= val_steps:
                break
            x, y = x.to(device), y.to(device)
            with dtype_ctx:
                logits, _ = model(x)
            shift_logits = logits[..., :-1, :].float()
            shift_labels = y[..., 1:]
            if not (shift_labels != -100).any():
                continue  # fully masked batch: mean CE would be 0/0 = NaN
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
            losses.append(loss.item())
    finally:
        model.train(was_training)
    return sum(losses) / len(losses) if losses else float("nan")
|
|
|
|
| |
| |
| |
|
|
class MetricLogger:
    """Minimal JSON-lines metrics sink.

    Each :meth:`log` call appends one line containing ``json.dumps`` of its
    keyword arguments to ``log_path``. The parent directory is created on
    construction. The file is opened in append mode and closed on every call,
    so lines from earlier runs are kept and each record is handed to the OS
    before ``log`` returns.
    """

    def __init__(self, log_path: str):
        target_dir = os.path.dirname(os.path.abspath(log_path))
        self.log_path = log_path
        os.makedirs(target_dir, exist_ok=True)
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """Append ``kwargs`` as a single JSON object on its own line."""
        with open(self.log_path, "a") as handle:
            handle.write(f"{json.dumps(kwargs)}\n")
|
|
|
|
| |
| |
| |
|
|
def _resolve_amp_dtype(requested: str, device: torch.device):
    """Map the ``--dtype`` flag to ``(torch dtype, display name)``.

    bf16 is used only on CUDA devices reporting bf16 support, and fp16 only on
    CUDA. Every other combination falls back to fp32.
    """
    if requested == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16, "bf16"
    if requested == "fp16" and device.type == "cuda":
        return torch.float16, "fp16"
    return torch.float32, "fp32"


def _load_chat_tokenizer(data_dir: str):
    """Load the chat tokenizer.

    Prefers a ``tokenizer.json`` saved in ``data_dir``. Otherwise it loads the
    base tokenizer from ``PROJECT_ROOT/tokenizer/fineweb_edu_tokenizer`` and
    registers the ChatML tokens ``<|im_start|>`` / ``<|im_end|>``.
    """
    if os.path.exists(os.path.join(data_dir, "tokenizer.json")):
        tokenizer = PreTrainedTokenizerFast.from_pretrained(data_dir)
        print(f" Loaded from data dir: {data_dir}")
    else:
        base_tok_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(base_tok_dir)
        tokenizer.add_special_tokens({"additional_special_tokens":
                                      ["<|im_start|>", "<|im_end|>"]})
        print(f" Loaded base tokenizer + added special tokens")
    return tokenizer


def _has_sft_checkpoint(run_dir: str) -> bool:
    """Return True if ``run_dir`` exists and holds at least one ``ckpt_sft_*.pt`` file."""
    if not os.path.isdir(run_dir):
        return False
    return any(f.startswith("ckpt_sft_") and f.endswith(".pt")
               for f in os.listdir(run_dir))


def _masked_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean next-token cross-entropy over positions whose label is not -100.

    ``logits[..., t, :]`` is scored against ``labels[..., t + 1]``. When every
    label is masked, this returns an exact zero (with zero gradient) instead of
    the NaN that ``reduction="mean"`` would produce. A NaN would otherwise
    propagate into the gradients and the weights.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_sum = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    n_valid = (shift_labels != -100).sum()
    return loss_sum / n_valid.clamp(min=1)


def _next_batch(data_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` when an epoch ends.

    Raises:
        RuntimeError: if a freshly restarted ``loader`` yields nothing.
    """
    try:
        return next(data_iter), data_iter
    except StopIteration:
        data_iter = iter(loader)
        try:
            return next(data_iter), data_iter
        except StopIteration:
            raise RuntimeError("Training dataloader yielded no batches") from None


def train():
    """Run, or resume, supervised fine-tuning end to end.

    Steps: parse CLI args, then pick device and precision, then load the
    tokenizer, then build the model from either the base checkpoint or the
    latest SFT checkpoint. Next, resize the embeddings for the ChatML tokens,
    build the optimizer and dataloaders, and run the training loop with
    periodic logging, validation and checkpointing. A final checkpoint is
    written at the end.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n{'='*60}")
    print(f" SLLM-150M → Chat Model (SFT)")
    print(f"{'='*60}")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

    # ── Precision ────────────────────────────────────────────────
    dtype_torch, dtype_name = _resolve_amp_dtype(args.dtype, device)
    print(f"dtype : {dtype_name}")
    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    # Loss scaling is only enabled for fp16.
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))

    # ── Tokenizer ────────────────────────────────────────────────
    print("\n[1/5] Loading tokenizer...")
    tokenizer = _load_chat_tokenizer(args.data_dir)
    new_vocab_size = len(tokenizer)
    pad_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None
              else tokenizer.eos_token_id)
    if pad_id is None:
        raise ValueError("Tokenizer defines neither a pad token nor an EOS token; "
                         "cannot pad SFT batches.")
    print(f" Vocab size : {new_vocab_size:,}")
    print(f" Pad token : {pad_id}")

    # ── Model ────────────────────────────────────────────────────
    print("\n[2/5] Loading model...")
    cfg = SLLM_150M
    model = SLLM(cfg).to(device)

    # Decide whether to resume *before* touching weights. The base checkpoint
    # has the pre-resize vocab, so it can only be loaded before the resize
    # below. A missing SFT checkpoint must therefore be detected here.
    resume_from_sft = args.resume and _has_sft_checkpoint(args.run_dir)
    if args.resume and not resume_from_sft:
        print(f" [WARN] No SFT checkpoints found in {args.run_dir} "
              f"— starting SFT from base checkpoint.")

    if not resume_from_sft:
        print(f" Loading base checkpoint: {args.base_ckpt}")
        # NOTE: weights_only=False unpickles arbitrary objects; only point
        # --base_ckpt at checkpoints you trust.
        base_ckpt = torch.load(args.base_ckpt, map_location=device, weights_only=False)
        model.load_state_dict(base_ckpt["model_state_dict"])
        base_step = base_ckpt.get("step", "?")
        base_loss = base_ckpt.get("loss", float("nan"))
        print(f" Base model step={base_step} loss={base_loss:.4f}")
        del base_ckpt

    # Grow the embedding (and the tied lm_head) for the ChatML tokens. When
    # resuming, the SFT checkpoint loaded below overwrites these weights.
    resize_token_embeddings(model, new_vocab_size)
    set_dropout(model, args.dropout)
    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")
    print(f" Model params: {model.count_params()/1e6:.1f}M")

    # ── Optimizer (+ resume) ─────────────────────────────────────
    print("\n[3/5] Building optimizer...")
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    start_step = 0
    if resume_from_sft:
        try:
            start_step, _ = load_sft_checkpoint(args.run_dir, model, optimizer, device)
        except FileNotFoundError as e:
            # Base weights were intentionally skipped above; continuing would
            # fine-tune a randomly initialised model.
            raise RuntimeError(
                f"--resume: failed to load an SFT checkpoint from {args.run_dir}: {e}"
            ) from e

    if args.extra_steps is not None:
        args.max_steps = start_step + args.extra_steps
        print(f" --extra_steps {args.extra_steps} → max_steps={args.max_steps}")

    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] Already at step {start_step} >= max_steps {args.max_steps}.")
        print(f" Use --extra_steps N to run N more steps.")
        return

    # ── Data ─────────────────────────────────────────────────────
    print("\n[4/5] Loading SFT dataset...")
    train_path = os.path.join(args.data_dir, "train_sft.pt")
    val_path = os.path.join(args.data_dir, "val_sft.pt")
    train_loader = build_sft_dataloader(
        data_path=train_path, batch_size=args.batch_size,
        pad_token_id=pad_id, context_length=cfg.context_length,
        num_workers=args.num_workers, shuffle=True,
    )
    val_loader = build_sft_dataloader(
        data_path=val_path, batch_size=args.batch_size,
        pad_token_id=pad_id, context_length=cfg.context_length,
        num_workers=0, shuffle=False,
    )

    os.makedirs(args.run_dir, exist_ok=True)
    logger = MetricLogger(os.path.join(args.run_dir, "sft_log.jsonl"))

    eff_batch = args.batch_size * args.grad_accum
    print(f"\n[5/5] Training config:")
    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} → eff={eff_batch})")
    print(f" max_steps : {args.max_steps}")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else '∞'}")
    print(f" max_lr / min_lr: {args.max_lr:.2e} / {args.min_lr:.2e}")
    print(f" warmup_steps : {args.warmup_steps}")
    print(f" save_every : {args.save_every}")
    print(f" val_every : {args.val_every}")

    # ── Training loop ────────────────────────────────────────────
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C — will save and exit after this step.")
        stop_flag["stop"] = True

    model.train()
    step = start_step
    running_loss = 0.0
    last_saved_step = None
    t_start = time.time()
    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" SFT STARTED (step {step} → {args.max_steps})")
    print(f"{'='*60}\n")

    pbar = tqdm(
        initial=step, total=args.max_steps,
        desc="SFT", unit="step", dynamic_ncols=True,
    )
    # Ctrl+C only sets a flag while training and during the final save. The
    # previous handler is restored afterwards, even if training raises.
    prev_sigint = signal.signal(signal.SIGINT, _signal_handler)
    try:
        while True:
            if stop_flag["stop"]:
                break
            if args.max_steps is not None and step >= args.max_steps:
                print(f"\n [DONE] Reached max_steps={args.max_steps}")
                break

            # Forward/backward over grad_accum micro-batches.
            optimizer.zero_grad(set_to_none=True)
            accum_loss = 0.0
            for _ in range(args.grad_accum):
                (x, y), data_iter = _next_batch(data_iter, train_loader)
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
                    logits, _ = model(x)
                    # Divide so the accumulated gradient is the micro-batch mean.
                    loss = _masked_lm_loss(logits, y) / args.grad_accum
                scaler.scale(loss).backward()
                accum_loss += loss.item()

            # Clip on unscaled gradients (unscale_ is a no-op when the scaler is disabled).
            if args.grad_clip > 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            else:
                grad_norm = float("nan")

            # NOTE: max_steps is the cosine horizon. With --extra_steps it grows
            # on resume, so the LR rises again relative to where the previous
            # run ended.
            lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            scaler.step(optimizer)
            scaler.update()

            step += 1
            running_loss = accum_loss
            t_now = time.time()

            pbar.update(1)
            pbar.set_postfix({"loss": f"{running_loss:.4f}", "lr": f"{lr:.1e}"})

            # A non-positive interval disables the corresponding action.
            if args.log_every > 0 and step % args.log_every == 0:
                gn = float(grad_norm)
                entry = {
                    "step": step,
                    "loss": round(running_loss, 6),
                    "lr": lr,
                    # inf/NaN (fp16 overflow, or clipping disabled) becomes null
                    # so the JSONL stays standard JSON.
                    "grad_norm": round(gn, 4) if math.isfinite(gn) else None,
                    "elapsed_s": round(t_now - t_start, 1),
                }
                if device.type == "cuda":
                    entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
                logger.log(**entry)

            if args.val_every > 0 and step % args.val_every == 0:
                v_ctx = autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)
                val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, v_ctx)
                tqdm.write(
                    f" [STEP {step:5d}] train={running_loss:.4f} "
                    f"val={val_loss:.4f} lr={lr:.1e}"
                )
                logger.log(step=step, val_loss=round(val_loss, 6))

            if args.save_every > 0 and step % args.save_every == 0:
                ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
                save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
                last_saved_step = step

        pbar.close()
        steps_done = step - start_step
        if steps_done == 0:
            print("\n [SKIP] No steps taken — skipping checkpoint save.")
        elif last_saved_step == step:
            print(f"\n [SKIP] Step {step} already checkpointed.")
        else:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_sft_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, running_loss, new_vocab_size)
    finally:
        pbar.close()  # idempotent; also closes the bar when training raised
        signal.signal(signal.SIGINT, prev_sigint)

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(f" SFT COMPLETE")
    print(f"{'='*60}")
    print(f" Steps done : {steps_done}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nStart chatting:")
    print(f" python finetune/chat.py --run_dir {args.run_dir}")
|
|
|
|
# Script entry point: run SFT only when executed directly, not when imported.
if __name__ == "__main__":
    train()
|
|