"""
Supervised fine-tuning (SFT) of the model.

Run as:

python -m scripts.chat_sft

Or with torchrun for distributed training, e.g.:

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""

import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"  # configure the CUDA allocator before torch is imported (reduces fragmentation)
import time

import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext

from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.loss_eval import evaluate_bpb

from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.spellingbee import SimpleSpelling, SpellingBee
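
# Command-line arguments; the parsed values double as the run config logged to wandb and checkpoints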
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) of the model")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")

parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")

parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")

parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")

parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")

parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")

parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")

parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()
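
# Compute/device setup: initialize distributed training (when launched with torchrun), choose the
# autocast dtype, and stub out the CUDA-only helpers when running on cpu/mps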
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
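
# wandb logging: only the master process logs, and --run=dummy swaps in a no-op logger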
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
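
# Load the pretrained base model (and its tokenizer) that we are going to fine-tune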
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model  # keep a reference to the underlying model, used for eval and checkpointing
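
# Batch-size bookkeeping: the desired total batch size (in tokens) is split across ranks and
# gradient-accumulation micro-steps. For example, with the defaults (device_batch_size=32,
# max_seq_len=2048) on 8 GPUs, each micro-batch is 32*2048 = 65,536 tokens per rank and
# 524,288 tokens across ranks, so total_batch_size=524,288 gives grad_accum_steps = 1.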
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert args.total_batch_size % world_tokens_per_fwdbwd == 0, "total batch size must be a multiple of the tokens per micro-batch across all ranks"
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch (all ranks): {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
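
# Optimizer setup: Adam for the embedding/unembedding parameters, Muon for the matrix parameters.
# Each group's base LR is scaled by --init-lr-frac and remembered as "initial_lr" so the schedule
# below can rescale it every step.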
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]
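
# Task mixtures for training and validation. The identity-conversations JSONL is listed twice,
# which (assuming TaskMixture simply concatenates its tasks) upweights it 2x in the mixture.
# The other imported tasks (GSM8K, MMLU, SimpleSpelling, SpellingBee) can be added here as well.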
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")

train_dataset = TaskMixture([
    SmolTalk(split="train"),
    CustomJSON(filepath=identity_conversations_filepath),
    CustomJSON(filepath=identity_conversations_filepath),
])
val_dataset = TaskMixture([
    SmolTalk(split="test"),
])
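
# Globals that the training data generator below mutates to communicate with the training loop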
last_step = False
approx_progress = 0.0
current_epoch = 1

def sft_data_generator_bos_bestfit(split, buffer_size=100):
    """
    BOS-aligned dataloader for SFT with bestfit-pad packing.

    Each row in the batch starts with BOS (the beginning of a conversation).
    Conversations are packed using a best-fit algorithm. When no conversation fits,
    the remainder of the row is padded (instead of cropped) so that no tokens are ever discarded.
    Padding positions have their targets masked with -1 (the ignore_index for cross-entropy).
    """
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    row_capacity = args.max_seq_len + 1  # +1 so each row yields max_seq_len input/target pairs after the shift
    bos_token = tokenizer.get_bos_token_id()
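
    # Per-rank state: each rank starts at index ddp_rank and strides by ddp_world_size,
    # so ranks consume disjoint conversations without any coordination.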
    conv_buffer = []     # tokenized conversations waiting to be packed into rows
    cursor = ddp_rank    # index of the next dataset example this rank will tokenize
    consumed = ddp_rank  # this rank's running estimate of conversations consumed across all ranks
    epoch = 1
    it = 0

    def refill_buffer():
        nonlocal cursor, epoch
        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            conv_buffer.append(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor = cursor % dataset_size
                epoch += 1
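
    # Main loop: emit one (device_batch_size, max_seq_len) batch per iteration. Each row is
    # filled with the longest buffered conversation that still fits (best-fit); once nothing
    # fits, the rest of the row is padded with BOS tokens, which are later masked out of the loss.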
    while True:
        rows = []
        row_lengths = []  # number of real (non-padding) tokens in each row
        for _ in range(args.device_batch_size):
            row = []
            padded = False
            while len(row) < row_capacity:
                # make sure there are candidate conversations to choose from
                if len(conv_buffer) < buffer_size:
                    refill_buffer()
                remaining = row_capacity - len(row)

                # best-fit: pick the longest buffered conversation that still fits in this row
                best_idx = -1
                best_len = 0
                for i, conv in enumerate(conv_buffer):
                    conv_len = len(conv)
                    if conv_len <= remaining and conv_len > best_len:
                        best_idx = i
                        best_len = conv_len
                if best_idx >= 0:
                    # pack the chosen conversation into the row
                    conv = conv_buffer.pop(best_idx)
                    row.extend(conv)
                    consumed += ddp_world_size
                else:
                    # nothing fits: pad the rest of the row with BOS and mask it out later
                    content_len = len(row)
                    row.extend([bos_token] * remaining)
                    padded = True
                    break
            row_lengths.append(content_len if padded else row_capacity)
            rows.append(row[:row_capacity])
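
        # Bookkeeping: estimate training progress and decide when to stop. With --num-iterations > 0
        # progress is measured against the requested number of steps, otherwise against the fraction
        # of the training mixture consumed (one epoch).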
        it += 1
        if 0 < args.num_iterations <= it and split == "train":
            last_step = True
        if split == "train":
            current_epoch = epoch
            if args.num_iterations > 0:
                approx_progress = it / args.num_iterations
            else:
                approx_progress = consumed / dataset_size
                if consumed >= dataset_size:
                    last_step = True
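
        # Assemble the batch and move it to the device: inputs/targets are the rows shifted by
        # one token, and any padded tail positions get target -1 (ignored by the loss).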
        use_cuda = device_type == "cuda"
        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
        for i, content_len in enumerate(row_lengths):
            if content_len < row_capacity:
                targets[i, content_len-1:] = -1
        yield inputs, targets

train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0.0  # this rank's view of training progress in [0, 1], driven by the train generator
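
# LR schedule: hold the learning rate constant for the first 80% of training, then decay it
# linearly to 0 over the final 20% (e.g. at 90% progress the multiplier is 0.5).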
def get_lr_multiplier(progress):
    # clamp at 0 in case the progress estimate slightly overshoots 1.0 on the final step
    return 1 if progress < 0.8 else max(0.0, 1 - (progress - 0.8) / 0.2)
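
# Muon momentum warmup: ramp the momentum linearly from 0.85 to 0.95 over the first 300 steps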
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
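
# Training loop: prefetch the first batch, keep an EMA of the train loss for logging, and time each step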
x, y = next(train_loader)  # prefetch the very first batch
min_val_bpb = float("inf")
smooth_train_loss = 0.0  # EMA of the training loss, for smoother logging
ema_beta = 0.9
total_training_time = 0.0
step = 0
while True:
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # Ranks can run out of data at slightly different times, so agree on last_step across all
    # ranks (max-reduce) to make sure everyone evaluates, checkpoints and exits together.
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())
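
    # Periodically (and always on the last step) evaluate validation bits-per-byte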
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(orig_model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()
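
    # On the last step, the master process writes the SFT checkpoint (model, optimizer, metadata)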
    if master_process and last_step and not args.dry_run:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}"
        checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb,
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                    "window_pattern": model.config.window_pattern,
                },
                "user_config": user_config,
            },
        )

    if last_step:
        break
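
    # One optimization step: forward/backward over grad_accum_steps micro-batches, then a single update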
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # keep the unscaled loss for logging
        loss = loss / grad_accum_steps  # scale so the accumulated gradients average over micro-steps
        loss.backward()
        x, y = next(train_loader)  # fetch the next batch (this also advances approx_progress)
        progress = max(progress, approx_progress)
    # set the learning rate and Muon momentum for this step, then update the parameters
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group["kind"] == "muon":
            group["momentum"] = muon_momentum
    optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
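
    # End of the optimization step: bump the step counter, then log a debiased EMA of the loss,
    # throughput, and MFU against the bf16 peak assumed in promised_flops_per_sec_h100 below.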
    step += 1

    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step)  # step already counts this update
    pct_done = 100 * progress
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size  # bf16 dense peak of one H100, summed over ranks
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100
    if step > 10:  # exclude the first few (warmup) steps from the timing
        total_training_time += dt

    eta_str = ""
    if step > 10 and progress > 0.01 and total_training_time > 0:
        eta_min = total_training_time / 60 * (1 - progress) / progress
        eta_str = f" | ETA: {eta_min:.1f}m"
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": current_epoch,
        })

print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

if not args.dry_run:
    from nanochat.report import get_report
    get_report().log(section="SFT", data=[
        user_config,
        {
            "Number of iterations": step,
            "DDP world size": ddp_world_size,
        },
        {
            "Minimum validation bpb": min_val_bpb,
        },
    ])

wandb_run.finish()
compute_cleanup()