#!/usr/bin/env python3
"""
KL Distillation Training - TOML-driven, accelerate multi-GPU.
Run with:
accelerate launch --config_file configs/accelerate.yaml distill.py --config configs/base.toml
The TOML config is the single source of truth - no hardcoded defaults in this file.
The only command line argument is --config <path-to-toml>.
"""
import os
# Reduce fragmentation; large vocab + long seq creates many short-lived big tensors.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import argparse
import gc
import json
import logging
import shutil
import time
import tomllib
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint_utils
from torch.optim import AdamW
from accelerate import Accelerator
from accelerate.utils import set_seed
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
)
log = logging.getLogger("distill")
# ----------------------------------------------------------------------------
# Config
# ----------------------------------------------------------------------------
REQUIRED_SECTIONS = ("model", "data", "train", "eval", "log", "init")
REQUIRED_KEYS = {
"model": ("teacher", "student", "tokenizer"),
"data": (
"dataset",
"text_field",
"min_chars",
"max_seq_len",
"kl_start_pos",
"seed",
"shuffle_buffer",
),
"train": (
"seed",
"lr",
"schedule",
"warmup_steps",
"weight_decay",
"grad_clip",
"betas",
"eps",
"samples_per_step",
"max_steps",
"grad_checkpointing",
"attn_implementation",
"student_dtype",
"teacher_dtype",
"mixed_precision",
"kl_chunk_size",
"micro_batch_size",
"new_layer_lr_mul",
),
"eval": ("every_steps", "samples", "seed"),
"log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
"init": ("zero_layers", "target_num_layers"),
}
DTYPE_MAP = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def parse_dtype(s):
if s not in DTYPE_MAP:
raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
return DTYPE_MAP[s]
def load_config(path):
with open(path, "rb") as f:
cfg = tomllib.load(f)
for sec in REQUIRED_SECTIONS:
if sec not in cfg:
raise KeyError(f"config missing required section [{sec}]")
for key in REQUIRED_KEYS[sec]:
if key not in cfg[sec]:
raise KeyError(f"config missing required key [{sec}].{key}")
return cfg
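# For reference, a minimal TOML satisfying the schema above might look like the
# sketch below (illustrative values and placeholder repo ids only; real runs
# use the files under configs/):
#
#   [model]
#   teacher = "org/teacher-model"
#   student = "org/student-model"
#   tokenizer = "org/teacher-model"
#
#   [data]
#   dataset = "org/text-dataset"
#   text_field = "text"
#   min_chars = 200
#   max_seq_len = 2048
#   kl_start_pos = 1
#   seed = 0
#   shuffle_buffer = 10000
#
#   [train]
#   seed = 0
#   lr = 1e-4
#   schedule = "cosine"
#   warmup_steps = 50
#   weight_decay = 0.01
#   grad_clip = 1.0
#   betas = [0.9, 0.95]
#   eps = 1e-8
#   samples_per_step = 32
#   max_steps = 1000
#   grad_checkpointing = true
#   attn_implementation = "sdpa"
#   student_dtype = "bfloat16"
#   teacher_dtype = "bfloat16"
#   mixed_precision = "bf16"
#   kl_chunk_size = 256
#   micro_batch_size = 4
#   new_layer_lr_mul = 1.0
#
#   [eval]
#   every_steps = 50
#   samples = 64
#   seed = 1
#
#   [log]
#   wandb = false
#   wandb_project = "distill"
#   wandb_run = "base"
#   log_every = 10
#   output_dir = "runs/base"
#
#   [init]
#   zero_layers = []           # indices to zero out; [] for none
#   target_num_layers = 36     # set equal to the student's depth to skip growth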
# ----------------------------------------------------------------------------
# Model loading
# ----------------------------------------------------------------------------
def get_inner_with_layers(model):
"""Walk wrappers (model, language_model, transformer, ...) to find an
object that has `.layers`. Used by zero_layers."""
seen = set()
stack = [model]
while stack:
m = stack.pop()
if id(m) in seen:
continue
seen.add(id(m))
if hasattr(m, "layers"):
return m
for attr in ("model", "language_model", "transformer", "base_model"):
child = getattr(m, attr, None)
if child is not None:
stack.append(child)
raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}")
def zero_layers(model, layer_indices):
inner = get_inner_with_layers(model)
layers = inner.layers
n = len(layers)
for idx in layer_indices:
if idx < 0 or idx >= n:
raise IndexError(f"layer {idx} out of range (0..{n - 1})")
with torch.no_grad():
for p in layers[idx].parameters():
p.zero_()
return n
def _zero_output_projections(layer):
"""Zero out attention and MLP output projections so the layer is identity
at init while still allowing gradients to flow into o_proj/down_proj first
(and from there back into the rest of the layer's params after one step).
Knows about Qwen3.5 names: self_attn.o_proj (full attention),
linear_attn.out_proj (linear attention), mlp.down_proj.
"""
zeroed = []
with torch.no_grad():
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"):
layer.self_attn.o_proj.weight.zero_()
zeroed.append("self_attn.o_proj")
if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"):
layer.linear_attn.out_proj.weight.zero_()
zeroed.append("linear_attn.out_proj")
if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"):
layer.mlp.down_proj.weight.zero_()
zeroed.append("mlp.down_proj")
return zeroed
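# Why zeroing only the output projections yields an identity layer: a pre-norm
# decoder layer computes, schematically,
#     h = x + o_proj(attn(norm1(x)));  y = h + down_proj(act(up(norm2(h))))
# so with o_proj == down_proj == 0 both residual branches contribute nothing
# and y == x. The grads dL/d(o_proj.weight) and dL/d(down_proj.weight) are
# still generally nonzero, so those weights move on the first optimizer step,
# after which gradients start reaching the rest of the layer's parameters.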
def grow_layers(model, target_n):
"""Grow the student to `target_n` decoder layers by appending new ones at the end.
New layers are constructed via the existing decoder layer class with the model's
own _init_weights, then their output projections are zeroed so each new layer
    starts as the identity but is still trainable. Returns a tuple of
    (new_layer_count, [(layer_idx, zeroed_projection_names), ...]).
    """
inner = get_inner_with_layers(model)
cur_n = len(inner.layers)
    if target_n == cur_n:
        return cur_n, []
if target_n < cur_n:
raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")
# Locate the (text) config that the layers are built from. For multimodal
# wrappers this lives at .text_config; for the dense student it's the same
# object as model.config.
cfg = model.config
text_cfg = getattr(cfg, "text_config", cfg)
# Extend layer_types by repeating the existing periodic pattern
if not hasattr(text_cfg, "layer_types") or not text_cfg.layer_types:
raise RuntimeError("text config has no layer_types; cannot extend pattern")
period = getattr(text_cfg, "full_attention_interval", 4)
new_types = list(text_cfg.layer_types)
while len(new_types) < target_n:
new_types.append(new_types[len(new_types) % period])
text_cfg.layer_types = new_types
text_cfg.num_hidden_layers = target_n
if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
cfg.num_hidden_layers = target_n
# Construct new layers using the same class as the existing ones
layer_cls = type(inner.layers[0])
device = next(inner.parameters()).device
dtype = next(inner.parameters()).dtype
new_layer_zeroed = []
for i in range(cur_n, target_n):
new_layer = layer_cls(text_cfg, layer_idx=i)
# Apply the parent model's init scheme (std=initializer_range etc.)
new_layer.apply(model._init_weights)
new_layer.to(device=device, dtype=dtype)
# Zero output projections -> identity at init, gradients still flow
zeroed = _zero_output_projections(new_layer)
new_layer_zeroed.append((i, zeroed))
inner.layers.append(new_layer)
return target_n, new_layer_zeroed
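# Worked example of the pattern extension (illustrative type names): with
# layer_types = ["linear_attention"] * 3 + ["full_attention"] and period 4,
# growing from 8 to 10 layers appends the types for indices 8 and 9, i.e.
# new_types[8 % 4] and new_types[9 % 4] -> two more linear-attention layers,
# preserving the periodic full-attention placement.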
def load_student(model_id, dtype, grad_ckpt, attn_impl):
from transformers import AutoModelForCausalLM
log.info(f"Loading student: {model_id} (dtype={dtype})")
model = AutoModelForCausalLM.from_pretrained(
model_id,
dtype=dtype,
low_cpu_mem_usage=True,
attn_implementation=attn_impl,
)
model.config.use_cache = False
if grad_ckpt:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
return model
def load_teacher(model_id, dtype, attn_impl):
"""Load teacher model. Handles both pure CausalLM and multimodal
(ConditionalGeneration) wrappers."""
from transformers import AutoConfig
cfg = AutoConfig.from_pretrained(model_id)
archs = list(getattr(cfg, "architectures", []) or [])
arch = archs[0] if archs else ""
is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal}, dtype={dtype})")
if is_multimodal:
from transformers import AutoModelForImageTextToText
model = AutoModelForImageTextToText.from_pretrained(
model_id,
dtype=dtype,
low_cpu_mem_usage=True,
attn_implementation=attn_impl,
)
else:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
model_id,
dtype=dtype,
low_cpu_mem_usage=True,
attn_implementation=attn_impl,
)
model.config.use_cache = False
model.eval()
for p in model.parameters():
p.requires_grad_(False)
return model
def teacher_forward(teacher, input_ids, attention_mask):
"""Get teacher logits whether the model is unimodal or multimodal."""
out = teacher(input_ids=input_ids, attention_mask=attention_mask)
logits = getattr(out, "logits", None)
if logits is None:
raise RuntimeError("teacher forward did not return .logits")
return logits
# ----------------------------------------------------------------------------
# Data
# ----------------------------------------------------------------------------
class StreamingTextLoader:
"""Per-rank shard of a HF streaming dataset, yielding tokenized samples."""
def __init__(
self,
name,
text_field,
min_chars,
max_seq_len,
kl_start_pos,
tokenizer,
rank,
world_size,
seed,
shuffle_buffer,
):
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
# HF Hub occasionally returns 5xx during dataset metadata crawl. Retry.
last_err = None
for attempt in range(8):
try:
ds = load_dataset(name, split="train", streaming=True)
break
except Exception as e:
last_err = e
wait = min(2 ** attempt, 30)
log.warning(
f"load_dataset({name!r}) failed (attempt {attempt + 1}/8): "
f"{type(e).__name__}: {e}; sleeping {wait}s"
)
time.sleep(wait)
else:
raise RuntimeError(f"load_dataset failed after 8 retries") from last_err
ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
self._ds = iter(ds)
self._text_field = text_field
self._min_chars = min_chars
self._max_seq_len = max_seq_len
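        # Require some scoreable positions beyond kl_start_pos; the +16 is a
        # small floor so a sample contributes more than a token or two of loss.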
self._min_tokens = kl_start_pos + 16
self._tokenizer = tokenizer
def next_batch(self, n):
out = []
scanned = 0
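        # Bound the scan at 50 raw records per requested sample so a heavily
        # filtered stream cannot loop indefinitely.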
while len(out) < n and scanned < n * 50:
try:
item = next(self._ds)
except StopIteration:
break
scanned += 1
text = item.get(self._text_field, "") or ""
if len(text) < self._min_chars:
continue
ids = self._tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=self._max_seq_len,
).input_ids.squeeze(0)
if ids.shape[0] < self._min_tokens:
continue
out.append(ids)
return out
def collate_pad(token_lists, pad_id):
"""Right-pad a list of [L_i] tensors into [B, max_L] + attention_mask."""
max_len = max(t.shape[0] for t in token_lists)
B = len(token_lists)
input_ids = torch.full((B, max_len), pad_id, dtype=torch.long)
attention_mask = torch.zeros((B, max_len), dtype=torch.long)
for i, t in enumerate(token_lists):
L = t.shape[0]
input_ids[i, :L] = t
attention_mask[i, :L] = 1
return input_ids, attention_mask
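# Example: collate_pad([tensor([5, 6, 7]), tensor([8, 9, 10, 11, 12])], pad_id=0)
# returns
#   input_ids      = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 12]]
#   attention_mask = [[1, 1, 1, 0, 0], [1, 1,  1,  1,  1]]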
# ----------------------------------------------------------------------------
# Loss
# ----------------------------------------------------------------------------
def _kl_chunk_sum(s_chunk, t_chunk, m_chunk):
"""Compute (sum of masked KL) over a slice. Used as a checkpointed unit so the
fp32 softmax intermediates only live for one chunk's worth of memory at a time."""
s = s_chunk.float()
t = t_chunk.float()
t_log_p = F.log_softmax(t, dim=-1)
s_log_p = F.log_softmax(s, dim=-1)
t_p = t_log_p.exp()
per_token = (t_p * (t_log_p - s_log_p)).sum(-1)
return (per_token * m_chunk).sum()
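# Per position this is the forward KL
#     KL(T || S) = sum_v p_T(v) * (log p_T(v) - log p_S(v))
# computed in fp32; the mask zeroes padded positions before the per-chunk sum
# (positions before start_pos are sliced off by the caller).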
def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
"""Forward KL(teacher || student), masked for padding & start_pos, in fp32.
If chunk_size > 0, processes the [start_pos:] sequence in chunks of that many
positions, with gradient checkpointing on each chunk so peak memory is bounded
by one chunk's intermediates rather than the full sequence's.
"""
s_full = student_logits[:, start_pos:, :]
t_full = teacher_logits[:, start_pos:, :].detach()
m_full = attention_mask[:, start_pos:].float()
T = s_full.shape[1]
if chunk_size <= 0 or chunk_size >= T:
return _kl_chunk_sum(s_full, t_full, m_full) / m_full.sum().clamp_min(1.0)
total_kl = torch.zeros((), device=s_full.device, dtype=torch.float32)
for i in range(0, T, chunk_size):
end = min(i + chunk_size, T)
s_c = s_full[:, i:end, :]
t_c = t_full[:, i:end, :]
m_c = m_full[:, i:end]
chunk_kl = checkpoint_utils.checkpoint(
_kl_chunk_sum, s_c, t_c, m_c, use_reentrant=False
)
total_kl = total_kl + chunk_kl
return total_kl / m_full.sum().clamp_min(1.0)
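# Quick sanity check (illustrative, not executed here): identical logits give
# ~zero KL regardless of chunking.
#
#   s = torch.randn(1, 8, 16, requires_grad=True)
#   m = torch.ones(1, 8)
#   assert kl_loss_masked(s, s.detach().clone(), m, 2, 4).item() < 1e-6
#   assert kl_loss_masked(s, s.detach().clone(), m, 2, 0).item() < 1e-6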
# ----------------------------------------------------------------------------
# Optimizer / scheduler
# ----------------------------------------------------------------------------
def make_optimizer(model, train_cfg, new_layer_indices=None):
"""Create AdamW. If `new_layer_lr_mul != 1.0` and we know which layers are
'new' (returned from grow_layers), put their params in a separate group with
a multiplied LR. Useful for the 'wake up new layers without disturbing the
old ones' regime."""
base_lr = train_cfg["lr"]
mul = train_cfg["new_layer_lr_mul"]
common = dict(
weight_decay=train_cfg["weight_decay"],
betas=tuple(train_cfg["betas"]),
eps=train_cfg["eps"],
)
if not new_layer_indices or mul == 1.0:
return AdamW(
[p for p in model.parameters() if p.requires_grad],
lr=base_lr,
**common,
)
inner = get_inner_with_layers(model)
new_pids = set()
for idx in new_layer_indices:
for p in inner.layers[idx].parameters():
if p.requires_grad:
new_pids.add(id(p))
new_params = []
rest_params = []
for p in model.parameters():
if not p.requires_grad:
continue
(new_params if id(p) in new_pids else rest_params).append(p)
return AdamW(
[
{"params": rest_params, "lr": base_lr},
{"params": new_params, "lr": base_lr * mul},
],
**common,
)
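# Example: with lr = 1e-4 and new_layer_lr_mul = 5.0, parameters of the grown
# layers train at 5e-4 while everything else stays at 1e-4; the schedulers from
# make_scheduler scale both groups by the same warmup/decay factor.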
def make_scheduler(optimizer, train_cfg):
schedule = train_cfg["schedule"]
warmup = train_cfg["warmup_steps"]
total = train_cfg["max_steps"]
if schedule == "constant":
from transformers import get_constant_schedule_with_warmup
return get_constant_schedule_with_warmup(optimizer, warmup)
if schedule == "cosine":
from transformers import get_cosine_schedule_with_warmup
return get_cosine_schedule_with_warmup(optimizer, warmup, total)
if schedule == "linear":
from transformers import get_linear_schedule_with_warmup
return get_linear_schedule_with_warmup(optimizer, warmup, total)
raise ValueError(f"unknown schedule: {schedule!r}")
# ----------------------------------------------------------------------------
# Eval
# ----------------------------------------------------------------------------
@torch.no_grad()
def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size):
student.eval()
sdev = accelerator.device
total = 0.0
n = 0
for sample in eval_batches:
ids, mask = collate_pad([sample], pad_id)
ids = ids.to(sdev)
mask = mask.to(sdev)
t_logits = teacher_forward(teacher, ids, mask)
s_logits = student(input_ids=ids, attention_mask=mask).logits
loss = kl_loss_masked(
s_logits, t_logits, mask,
start_pos=kl_start_pos, chunk_size=kl_chunk_size,
)
total += loss.item()
n += 1
del t_logits, s_logits, loss
student.train()
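    # A rank that drew zero eval samples reports +inf; the cross-rank mean
    # below then also becomes inf, surfacing a mis-sized eval split in the
    # logged metric rather than silently skewing it.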
if n == 0:
local = torch.tensor(float("inf"), device=sdev)
else:
local = torch.tensor(total / n, device=sdev)
gathered = accelerator.gather(local.unsqueeze(0))
return gathered.mean().item()
def save_best(accelerator, student, tokenizer, output_dir, step, eval_kl):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
out_dir = Path(output_dir) / "best"
if out_dir.exists():
shutil.rmtree(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
unwrapped = accelerator.unwrap_model(student)
unwrapped.save_pretrained(out_dir, safe_serialization=True)
tokenizer.save_pretrained(out_dir)
with open(out_dir / "best.json", "w") as f:
json.dump({"step": step, "eval_kl": eval_kl}, f, indent=2)
log.info(f" saved best @ step {step}: eval_kl={eval_kl:.6f} -> {out_dir}")
accelerator.wait_for_everyone()
# ----------------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------------
def main():
p = argparse.ArgumentParser()
p.add_argument("--config", required=True, help="Path to TOML config")
args = p.parse_args()
cfg = load_config(args.config)
accelerator = Accelerator(mixed_precision=cfg["train"]["mixed_precision"])
set_seed(cfg["train"]["seed"])
if accelerator.is_main_process:
log.info(f"Loaded config from {args.config}")
log.info(f"World size: {accelerator.num_processes}")
log.info(f"Mixed precision: {cfg['train']['mixed_precision']}")
# ---- Tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
pad_id = tokenizer.pad_token_id
# ---- Models (separate dtypes per config)
student_dtype = parse_dtype(cfg["train"]["student_dtype"])
teacher_dtype = parse_dtype(cfg["train"]["teacher_dtype"])
student = load_student(
cfg["model"]["student"],
student_dtype,
grad_ckpt=cfg["train"]["grad_checkpointing"],
attn_impl=cfg["train"]["attn_implementation"],
)
teacher = load_teacher(
cfg["model"]["teacher"],
teacher_dtype,
attn_impl=cfg["train"]["attn_implementation"],
)
# ---- Layer modifications: grow first, then zero (composable)
target_n = cfg["init"]["target_num_layers"]
cur_n = len(get_inner_with_layers(student).layers)
new_layer_indices = []
if target_n != cur_n:
new_n, new_zeroed = grow_layers(student, target_n)
new_layer_indices = [idx for idx, _ in new_zeroed]
if accelerator.is_main_process:
log.info(f"Grew student from {cur_n} -> {new_n} layers")
for idx, names in new_zeroed:
log.info(f" layer {idx}: zeroed {names}")
zero_idx = cfg["init"]["zero_layers"]
if zero_idx:
n = zero_layers(student, zero_idx)
if accelerator.is_main_process:
log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)")
teacher = teacher.to(accelerator.device)
# ---- Optimizer / scheduler
optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices)
scheduler = make_scheduler(optimizer, cfg["train"])
if accelerator.is_main_process and len(optimizer.param_groups) > 1:
log.info(
f"Param groups: rest lr={optimizer.param_groups[0]['lr']:.2e}, "
f"new lr={optimizer.param_groups[1]['lr']:.2e} "
f"({len(new_layer_indices)} layers grown)"
)
# NB: do NOT pass `scheduler` to accelerator.prepare. When prepared, accelerate
# advances the scheduler by `num_processes` steps per call (to match the
# "single-GPU equivalent" timeline). Combined with our explicit max_steps
# accounting, that causes the cosine to cycle multiple times mid-run. By
# leaving the scheduler unprepared, scheduler.step() advances exactly once
# per training step, matching how max_steps is interpreted in this script.
student, optimizer = accelerator.prepare(student, optimizer)
# ---- Output dir + config snapshot
output_dir = Path(cfg["log"]["output_dir"])
if accelerator.is_main_process:
output_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(args.config, output_dir / "config.snapshot.toml")
# ---- Wandb
use_wandb = cfg["log"]["wandb"]
if use_wandb and accelerator.is_main_process:
import wandb
wandb.init(
project=cfg["log"]["wandb_project"],
name=cfg["log"]["wandb_run"],
config=cfg,
)
# ---- Data loaders
train_loader = StreamingTextLoader(
name=cfg["data"]["dataset"],
text_field=cfg["data"]["text_field"],
min_chars=cfg["data"]["min_chars"],
max_seq_len=cfg["data"]["max_seq_len"],
kl_start_pos=cfg["data"]["kl_start_pos"],
tokenizer=tokenizer,
rank=accelerator.process_index,
world_size=accelerator.num_processes,
seed=cfg["data"]["seed"],
shuffle_buffer=cfg["data"]["shuffle_buffer"],
)
eval_loader = StreamingTextLoader(
name=cfg["data"]["dataset"],
text_field=cfg["data"]["text_field"],
min_chars=cfg["data"]["min_chars"],
max_seq_len=cfg["data"]["max_seq_len"],
kl_start_pos=cfg["data"]["kl_start_pos"],
tokenizer=tokenizer,
rank=accelerator.process_index,
world_size=accelerator.num_processes,
seed=cfg["eval"]["seed"],
shuffle_buffer=cfg["data"]["shuffle_buffer"],
)
eval_per_rank = max(1, cfg["eval"]["samples"] // accelerator.num_processes)
eval_batches = eval_loader.next_batch(eval_per_rank)
if accelerator.is_main_process:
log.info(
f"Eval set: {len(eval_batches)}/rank x {accelerator.num_processes} ranks "
f"= {len(eval_batches) * accelerator.num_processes} samples"
)
# ---- Train loop
samples_per_step = cfg["train"]["samples_per_step"]
micro_batch_size = cfg["train"]["micro_batch_size"]
grad_clip = cfg["train"]["grad_clip"]
kl_start_pos = cfg["data"]["kl_start_pos"]
kl_chunk_size = cfg["train"]["kl_chunk_size"]
max_steps = cfg["train"]["max_steps"]
eval_every = cfg["eval"]["every_steps"]
log_every = cfg["log"]["log_every"]
if accelerator.is_main_process:
log.info(
f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
f"(per rank, micro={micro_batch_size}), "
f"effective batch={samples_per_step * accelerator.num_processes}"
)
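    # Worked example (illustrative numbers): samples_per_step=32 on 4 ranks is
    # an effective batch of 128 samples per optimizer step, with each rank
    # running its 32 samples through forward/backward in micro_batch_size
    # sized slices.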
student.train()
best_kl = float("inf")
global_step = 0
while global_step < max_steps:
t0 = time.time()
        batch = train_loader.next_batch(samples_per_step)
        # If any rank runs dry, all ranks must stop together; a lone `break`
        # here would leave the other ranks blocked in the next collective.
        have_data = torch.tensor(1 if batch else 0, device=accelerator.device)
        if accelerator.gather(have_data.unsqueeze(0)).min().item() == 0:
            if not batch:
                log.warning(f"rank {accelerator.process_index}: data exhausted")
            break
optimizer.zero_grad()
batch_n = len(batch)
kl_sum = 0.0
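        # Note: the micro loop calls accelerator.backward per slice without an
        # accumulate()/no_sync context, so when the student is wrapped in DDP
        # gradients are all-reduced on every micro-batch. Correct, just more
        # communication than strictly needed.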
for mb_start in range(0, batch_n, micro_batch_size):
micro = batch[mb_start : mb_start + micro_batch_size]
mb_n = len(micro)
ids, mask = collate_pad(micro, pad_id)
ids = ids.to(accelerator.device)
mask = mask.to(accelerator.device)
with torch.no_grad():
t_logits = teacher_forward(teacher, ids, mask)
s_logits = student(input_ids=ids, attention_mask=mask).logits
loss = kl_loss_masked(
s_logits, t_logits, mask,
start_pos=kl_start_pos, chunk_size=kl_chunk_size,
)
            # Weight by micro size so the summed micro losses reproduce the
            # batch mean (exact when micros have equal token counts, otherwise
            # a close sample-weighted approximation, since each micro's loss is
            # a per-token mean).
scaled = loss * (mb_n / batch_n)
accelerator.backward(scaled)
kl_sum += loss.item() * mb_n
del t_logits, s_logits, loss, scaled
if grad_clip > 0:
accelerator.clip_grad_norm_(student.parameters(), grad_clip)
optimizer.step()
scheduler.step()
global_step += 1
elapsed = time.time() - t0
kl_local = torch.tensor(kl_sum / batch_n, device=accelerator.device)
kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
del kl_local
if accelerator.is_main_process and global_step % log_every == 0:
lr_now = scheduler.get_last_lr()[0]
log.info(
f"step {global_step}/{max_steps} | kl {kl_avg:.4f} | "
f"lr {lr_now:.2e} | {elapsed:.2f}s"
)
if use_wandb:
import wandb
wandb.log(
{
"train/kl": kl_avg,
"train/lr": lr_now,
"perf/step_time_s": elapsed,
},
step=global_step,
)
if global_step % eval_every == 0:
eval_kl = evaluate(
accelerator, student, teacher, eval_batches,
pad_id, kl_start_pos, kl_chunk_size,
)
if accelerator.is_main_process:
log.info(
f" eval @ step {global_step}: kl={eval_kl:.6f} "
f"(best={best_kl:.6f})"
)
if use_wandb:
import wandb
wandb.log({"eval/kl": eval_kl}, step=global_step)
if eval_kl < best_kl:
best_kl = eval_kl
save_best(
accelerator, student, tokenizer, output_dir, global_step, eval_kl
)
student.train()
if global_step % 20 == 0:
gc.collect()
torch.cuda.empty_cache()
# Final eval
eval_kl = evaluate(
accelerator, student, teacher, eval_batches,
pad_id, kl_start_pos, kl_chunk_size,
)
if accelerator.is_main_process:
log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
if use_wandb:
import wandb
wandb.log({"eval/kl": eval_kl}, step=global_step)
if eval_kl < best_kl:
best_kl = eval_kl
save_best(accelerator, student, tokenizer, output_dir, global_step, eval_kl)
if accelerator.is_main_process:
log.info(f"Done. Best eval KL = {best_kl:.6f}")
if use_wandb:
import wandb
wandb.finish()
if __name__ == "__main__":
main()