| |
| """ |
| KL Distillation Training - TOML-driven, accelerate multi-GPU. |
| |
| Run with: |
| accelerate launch --config_file configs/accelerate.yaml distill.py --config configs/base.toml |
| |
| The TOML config is the single source of truth - no hardcoded defaults in this file. |
| The only command line argument is --config <path-to-toml>. |
| """ |
|
|
| import os |
| |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") |
|
|
| import argparse |
| import gc |
| import json |
| import logging |
| import shutil |
| import time |
| import tomllib |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as checkpoint_utils |
| from torch.optim import AdamW |
|
|
| from accelerate import Accelerator |
| from accelerate.utils import set_seed |
|
|
# Configure the root logger at import time: INFO level, with a wall-clock
# timestamp and the level name in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
# Module-level logger used by the functions in this file.
log = logging.getLogger("distill")
|
|
|
|
| |
| |
| |
|
|
# Schema of the TOML config. Each top-level key is a required section and the
# tuple under it lists the keys that section must contain; load_config raises
# KeyError when any of them is absent.
REQUIRED_KEYS = {
    "model": ("teacher", "student", "tokenizer"),
    "data": (
        "dataset",
        "text_field",
        "min_chars",
        "max_seq_len",
        "kl_start_pos",
        "seed",
        "shuffle_buffer",
    ),
    "train": (
        "seed",
        "lr",
        "schedule",
        "warmup_steps",
        "weight_decay",
        "grad_clip",
        "betas",
        "eps",
        "samples_per_step",
        "max_steps",
        "grad_checkpointing",
        "attn_implementation",
        "student_dtype",
        "teacher_dtype",
        "mixed_precision",
        "kl_chunk_size",
        "micro_batch_size",
        "new_layer_lr_mul",
    ),
    "eval": ("every_steps", "samples", "seed"),
    "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
    "init": ("zero_layers", "target_num_layers"),
}

# Section names in validation order (dicts preserve insertion order).
REQUIRED_SECTIONS = tuple(REQUIRED_KEYS)
|
|
# Dtype names accepted in the config, mapped to the torch dtype they select.
DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def parse_dtype(s):
    """Translate a dtype name from the config into a ``torch.dtype``.

    Args:
        s: dtype name; must be one of the keys of ``DTYPE_MAP``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: if ``s`` is not a key of ``DTYPE_MAP``.
    """
    resolved = DTYPE_MAP.get(s)
    if resolved is None:
        raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
    return resolved
|
|
|
|
def load_config(path):
    """Read the TOML config at ``path`` and check that it is complete.

    Args:
        path: filesystem path of the TOML file.

    Returns:
        The parsed config as a nested dict.

    Raises:
        KeyError: on the first section from ``REQUIRED_SECTIONS`` that is
            missing, or the first key from ``REQUIRED_KEYS`` missing in a
            present section.
    """
    with open(path, "rb") as fh:
        cfg = tomllib.load(fh)

    for section in REQUIRED_SECTIONS:
        if section not in cfg:
            raise KeyError(f"config missing required section [{section}]")
        # Report only the first absent key, in declaration order.
        absent = next(
            (name for name in REQUIRED_KEYS[section] if name not in cfg[section]),
            None,
        )
        if absent is not None:
            raise KeyError(f"config missing required key [{section}].{absent}")
    return cfg
|
|
|
|
| |
| |
| |
|
|
def get_inner_with_layers(model):
    """Find the first nested object that exposes a ``.layers`` attribute.

    Starting at ``model``, follows the wrapper attributes ``model``,
    ``language_model``, ``transformer`` and ``base_model`` depth-first
    (last-pushed attribute is visited first) and returns the first object
    reached that has ``.layers``. Objects are tracked by identity so
    self-referencing wrappers cannot cause an endless loop.

    Raises:
        RuntimeError: if no reachable object has ``.layers``.
    """
    wrapper_attrs = ("model", "language_model", "transformer", "base_model")
    visited = set()
    pending = [model]
    while pending:
        node = pending.pop()
        key = id(node)
        if key in visited:
            continue
        visited.add(key)
        if hasattr(node, "layers"):
            return node
        # Queue every wrapper attribute that is present and not None.
        pending.extend(
            child
            for child in (getattr(node, name, None) for name in wrapper_attrs)
            if child is not None
        )
    raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}")
|
|
|
|
def zero_layers(model, layer_indices):
    """Set every parameter of the selected decoder layers to zero, in place.

    Args:
        model: model whose layer container is found via
            ``get_inner_with_layers``.
        layer_indices: iterable of non-negative layer indices.

    Returns:
        The total number of layers in the model.

    Raises:
        IndexError: when an index is negative or past the last layer. Layers
            listed before the offending index have already been zeroed.
    """
    blocks = get_inner_with_layers(model).layers
    count = len(blocks)
    with torch.no_grad():
        for layer_idx in layer_indices:
            if not 0 <= layer_idx < count:
                raise IndexError(f"layer {layer_idx} out of range (0..{count - 1})")
            for param in blocks[layer_idx].parameters():
                param.zero_()
    return count
|
|
|
|
| def _zero_output_projections(layer): |
| """Zero out attention and MLP output projections so the layer is identity |
| at init while still allowing gradients to flow into o_proj/down_proj first |
| (and from there back into the rest of the layer's params after one step). |
| |
| Knows about Qwen3.5 names: self_attn.o_proj (full attention), |
| linear_attn.out_proj (linear attention), mlp.down_proj. |
| """ |
| zeroed = [] |
| with torch.no_grad(): |
| if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"): |
| layer.self_attn.o_proj.weight.zero_() |
| zeroed.append("self_attn.o_proj") |
| if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"): |
| layer.linear_attn.out_proj.weight.zero_() |
| zeroed.append("linear_attn.out_proj") |
| if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"): |
| layer.mlp.down_proj.weight.zero_() |
| zeroed.append("mlp.down_proj") |
| return zeroed |
|
|
|
|
def grow_layers(model, target_n):
    """Grow the student to `target_n` decoder layers by appending new ones at the end.

    New layers are constructed via the existing decoder layer class with the model's
    own _init_weights, then their output projections are zeroed so each new layer
    starts as the identity but is still trainable.

    The model config is updated in place: ``layer_types`` is extended and
    ``num_hidden_layers`` is set to ``target_n`` (on the text config and, when
    it is a separate object that has the attribute, on the outer config too).

    Args:
        model: the student model; must expose ``config`` and ``_init_weights``.
        target_n: desired total number of decoder layers.

    Returns:
        A ``(num_layers, new_layer_zeroed)`` tuple. ``new_layer_zeroed`` is a
        list of ``(layer_index, zeroed_projection_names)`` pairs, one per
        appended layer; it is empty when the model already has ``target_n``
        layers.

    Raises:
        ValueError: if ``target_n`` is smaller than the current layer count.
        RuntimeError: if the text config has no (or an empty) ``layer_types``.
    """
    inner = get_inner_with_layers(model)
    cur_n = len(inner.layers)
    if target_n == cur_n:
        # Same (count, details) shape as the growth path, so callers can
        # always unpack the result.
        return cur_n, []
    if target_n < cur_n:
        raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")

    # Multimodal wrappers keep the language-model settings under text_config.
    cfg = model.config
    text_cfg = getattr(cfg, "text_config", cfg)

    if not getattr(text_cfg, "layer_types", None):
        raise RuntimeError("text config has no layer_types; cannot extend pattern")
    # Continue the per-layer type pattern: the new entry at position i copies
    # the type found at position i % period.
    period = getattr(text_cfg, "full_attention_interval", 4)
    new_types = list(text_cfg.layer_types)
    while len(new_types) < target_n:
        new_types.append(new_types[len(new_types) % period])
    text_cfg.layer_types = new_types
    text_cfg.num_hidden_layers = target_n
    if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
        cfg.num_hidden_layers = target_n

    # New layers reuse the class of the first existing layer and are placed on
    # the device/dtype of the model's first parameter.
    layer_cls = type(inner.layers[0])
    ref_param = next(inner.parameters())
    device = ref_param.device
    dtype = ref_param.dtype

    new_layer_zeroed = []
    for i in range(cur_n, target_n):
        new_layer = layer_cls(text_cfg, layer_idx=i)
        new_layer.apply(model._init_weights)
        new_layer.to(device=device, dtype=dtype)
        zeroed = _zero_output_projections(new_layer)
        new_layer_zeroed.append((i, zeroed))
        inner.layers.append(new_layer)

    return target_n, new_layer_zeroed
|
|
|
|
def load_student(model_id, dtype, grad_ckpt, attn_impl):
    """Load the trainable student as a causal LM.

    Args:
        model_id: model name or path passed to ``from_pretrained``.
        dtype: torch dtype to load the weights in.
        grad_ckpt: when truthy, enable non-reentrant gradient checkpointing.
        attn_impl: value forwarded as ``attn_implementation``.

    Returns:
        The loaded model with ``config.use_cache`` switched off.
    """
    from transformers import AutoModelForCausalLM

    log.info(f"Loading student: {model_id} (dtype={dtype})")
    load_kwargs = {
        "dtype": dtype,
        "low_cpu_mem_usage": True,
        "attn_implementation": attn_impl,
    }
    student = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    student.config.use_cache = False
    if not grad_ckpt:
        return student
    student.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    return student
|
|
|
|
def load_teacher(model_id, dtype, attn_impl):
    """Load the frozen teacher, as a plain causal LM or a multimodal wrapper.

    The first entry of the config's ``architectures`` decides the loader: a
    name containing "ConditionalGeneration" or "ImageText" is loaded through
    ``AutoModelForImageTextToText``, anything else through
    ``AutoModelForCausalLM``.

    Args:
        model_id: model name or path passed to ``from_pretrained``.
        dtype: torch dtype to load the weights in.
        attn_impl: value forwarded as ``attn_implementation``.

    Returns:
        The model in eval mode, with ``config.use_cache`` off and every
        parameter's ``requires_grad`` set to False.
    """
    from transformers import AutoConfig

    hf_cfg = AutoConfig.from_pretrained(model_id)
    arch_names = list(getattr(hf_cfg, "architectures", []) or [])
    arch = arch_names[0] if arch_names else ""
    is_multimodal = any(tag in arch for tag in ("ConditionalGeneration", "ImageText"))
    log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal}, dtype={dtype})")

    # Import only the loader that is actually needed.
    if is_multimodal:
        from transformers import AutoModelForImageTextToText as loader_cls
    else:
        from transformers import AutoModelForCausalLM as loader_cls

    model = loader_cls.from_pretrained(
        model_id,
        dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
    )
    model.config.use_cache = False
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model
|
|
|
|
def teacher_forward(teacher, input_ids, attention_mask):
    """Run the teacher and return the ``.logits`` of its output.

    Works for any teacher callable that accepts ``input_ids`` and
    ``attention_mask`` keywords, unimodal or multimodal.

    Raises:
        RuntimeError: if the output has no ``logits`` attribute or it is None.
    """
    logits = getattr(
        teacher(input_ids=input_ids, attention_mask=attention_mask),
        "logits",
        None,
    )
    if logits is not None:
        return logits
    raise RuntimeError("teacher forward did not return .logits")
|
|
|
|
| |
| |
| |
|
|
class StreamingTextLoader:
    """Per-rank shard of a HF streaming dataset, yielding tokenized samples."""

    def __init__(
        self,
        name,
        text_field,
        min_chars,
        max_seq_len,
        kl_start_pos,
        tokenizer,
        rank,
        world_size,
        seed,
        shuffle_buffer,
        max_retries=8,
    ):
        """Open the ``train`` split of ``name`` as a stream and shard it.

        Args:
            name: dataset name passed to ``load_dataset``.
            text_field: key of the text column in each item.
            min_chars: items with shorter text are skipped before tokenizing.
            max_seq_len: tokenizer truncation length.
            kl_start_pos: samples need at least ``kl_start_pos + 16`` tokens.
            tokenizer: callable tokenizer returning an object with ``input_ids``.
            rank: index of this process, used for sharding.
            world_size: total number of processes, used for sharding.
            seed: seed of the streaming shuffle.
            shuffle_buffer: buffer size of the streaming shuffle.
            max_retries: how many times ``load_dataset`` is attempted.

        Raises:
            RuntimeError: if every ``load_dataset`` attempt fails; the last
                error is attached as the cause.
        """
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node

        last_err = None
        for attempt in range(1, max_retries + 1):
            try:
                ds = load_dataset(name, split="train", streaming=True)
                break
            except Exception as e:  # any load failure is retried with backoff
                last_err = e
                reason = f"{type(e).__name__}: {e}"
                if attempt == max_retries:
                    # No point sleeping when there is no further attempt.
                    log.warning(
                        f"load_dataset({name!r}) failed (attempt {attempt}/{max_retries}): "
                        f"{reason}; giving up"
                    )
                    continue
                # Exponential backoff 1, 2, 4, ... capped at 30 seconds.
                wait = min(2 ** (attempt - 1), 30)
                log.warning(
                    f"load_dataset({name!r}) failed (attempt {attempt}/{max_retries}): "
                    f"{reason}; sleeping {wait}s"
                )
                time.sleep(wait)
        else:
            # Reached only when no attempt hit `break`.
            raise RuntimeError(f"load_dataset failed after {max_retries} retries") from last_err

        ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
        ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
        self._ds = iter(ds)
        self._text_field = text_field
        self._min_chars = min_chars
        self._max_seq_len = max_seq_len
        # Require 16 tokens beyond the position where the KL loss starts.
        self._min_tokens = kl_start_pos + 16
        self._tokenizer = tokenizer

    def next_batch(self, n):
        """Return up to ``n`` tokenized samples as 1-D id tensors.

        Items whose text is missing, shorter than ``min_chars`` characters, or
        shorter than the minimum token count after truncation are skipped.
        Scanning stops after ``n * 50`` items or when the stream ends, so the
        returned list may hold fewer than ``n`` entries (possibly none).
        """
        out = []
        scanned = 0
        while len(out) < n and scanned < n * 50:
            try:
                item = next(self._ds)
            except StopIteration:
                break
            scanned += 1
            text = item.get(self._text_field, "") or ""
            if len(text) < self._min_chars:
                continue
            ids = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self._max_seq_len,
            ).input_ids.squeeze(0)
            if ids.shape[0] < self._min_tokens:
                continue
            out.append(ids)
        return out
|
|
|
|
def collate_pad(token_lists, pad_id):
    """Right-pad 1-D id tensors to a common length.

    Args:
        token_lists: non-empty list of 1-D tensors of token ids.
        pad_id: id written into the padding positions.

    Returns:
        ``(input_ids, attention_mask)``, both ``torch.long`` tensors of shape
        ``[len(token_lists), longest_length]``; the mask is 1 on real tokens
        and 0 on padding.
    """
    lengths = [seq.shape[0] for seq in token_lists]
    width = max(lengths)

    input_ids = torch.full((len(token_lists), width), pad_id, dtype=torch.long)
    for row, (seq, seq_len) in enumerate(zip(token_lists, lengths)):
        input_ids[row, :seq_len] = seq

    # A position is real when its column index is below the row's length.
    columns = torch.arange(width).unsqueeze(0)
    row_lengths = torch.tensor(lengths, dtype=torch.long).unsqueeze(1)
    attention_mask = (columns < row_lengths).to(torch.long)
    return input_ids, attention_mask
|
|
|
|
| |
| |
| |
|
|
| def _kl_chunk_sum(s_chunk, t_chunk, m_chunk): |
| """Compute (sum of masked KL) over a slice. Used as a checkpointed unit so the |
| fp32 softmax intermediates only live for one chunk's worth of memory at a time.""" |
| s = s_chunk.float() |
| t = t_chunk.float() |
| t_log_p = F.log_softmax(t, dim=-1) |
| s_log_p = F.log_softmax(s, dim=-1) |
| t_p = t_log_p.exp() |
| per_token = (t_p * (t_log_p - s_log_p)).sum(-1) |
| return (per_token * m_chunk).sum() |
|
|
|
|
def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
    """Mean forward KL(teacher || student) over unmasked positions from ``start_pos``.

    Positions before ``start_pos`` and positions where ``attention_mask`` is 0
    do not contribute. The teacher logits are detached. The masked KL sum is
    divided by the number of unmasked positions, clamped to at least 1.

    When ``0 < chunk_size < T`` (``T`` = number of positions after
    ``start_pos``), the sequence is processed ``chunk_size`` positions at a
    time, each chunk under non-reentrant gradient checkpointing, so peak
    memory is bounded by one chunk's intermediates. Otherwise the whole range
    is handled in a single call.

    Returns:
        A scalar loss tensor.
    """
    student_tail = student_logits[:, start_pos:, :]
    teacher_tail = teacher_logits[:, start_pos:, :].detach()
    mask_tail = attention_mask[:, start_pos:].float()
    denom = mask_tail.sum().clamp_min(1.0)

    n_pos = student_tail.shape[1]
    if not 0 < chunk_size < n_pos:
        return _kl_chunk_sum(student_tail, teacher_tail, mask_tail) / denom

    running = torch.zeros((), device=student_tail.device, dtype=torch.float32)
    for lo in range(0, n_pos, chunk_size):
        # Slicing clamps at the end, so the last window may be shorter.
        window = slice(lo, lo + chunk_size)
        running = running + checkpoint_utils.checkpoint(
            _kl_chunk_sum,
            student_tail[:, window, :],
            teacher_tail[:, window, :],
            mask_tail[:, window],
            use_reentrant=False,
        )
    return running / denom
|
|
|
|
| |
| |
| |
|
|
def make_optimizer(model, train_cfg, new_layer_indices=None):
    """Build AdamW over the trainable parameters of ``model``.

    With no ``new_layer_indices``, or ``new_layer_lr_mul == 1.0``, a single
    parameter group at ``train_cfg["lr"]`` is used. Otherwise the parameters
    of the listed layers (as produced by grow_layers) go into a second group
    whose LR is ``lr * new_layer_lr_mul`` — the 'wake up new layers without
    disturbing the old ones' regime. Group 0 is always the pre-existing
    parameters, group 1 the new-layer ones.

    Args:
        model: model to optimise; parameters with ``requires_grad`` False are
            left out.
        train_cfg: dict providing ``lr``, ``new_layer_lr_mul``,
            ``weight_decay``, ``betas`` and ``eps``.
        new_layer_indices: indices of freshly added layers, or None/empty.

    Returns:
        The configured ``AdamW`` optimizer.
    """
    base_lr = train_cfg["lr"]
    mul = train_cfg["new_layer_lr_mul"]
    shared_kwargs = {
        "weight_decay": train_cfg["weight_decay"],
        "betas": tuple(train_cfg["betas"]),
        "eps": train_cfg["eps"],
    }

    single_group = not new_layer_indices or mul == 1.0
    if single_group:
        trainable = [p for p in model.parameters() if p.requires_grad]
        return AdamW(trainable, lr=base_lr, **shared_kwargs)

    # Identify new-layer parameters by object identity.
    layers = get_inner_with_layers(model).layers
    new_ids = {
        id(p)
        for layer_idx in new_layer_indices
        for p in layers[layer_idx].parameters()
        if p.requires_grad
    }

    trainable = [p for p in model.parameters() if p.requires_grad]
    rest_params = [p for p in trainable if id(p) not in new_ids]
    new_params = [p for p in trainable if id(p) in new_ids]

    param_groups = [
        {"params": rest_params, "lr": base_lr},
        {"params": new_params, "lr": base_lr * mul},
    ]
    return AdamW(param_groups, **shared_kwargs)
|
|
|
|
| def make_scheduler(optimizer, train_cfg): |
| schedule = train_cfg["schedule"] |
| warmup = train_cfg["warmup_steps"] |
| total = train_cfg["max_steps"] |
|
|
| if schedule == "constant": |
| from transformers import get_constant_schedule_with_warmup |
| return get_constant_schedule_with_warmup(optimizer, warmup) |
| if schedule == "cosine": |
| from transformers import get_cosine_schedule_with_warmup |
| return get_cosine_schedule_with_warmup(optimizer, warmup, total) |
| if schedule == "linear": |
| from transformers import get_linear_schedule_with_warmup |
| return get_linear_schedule_with_warmup(optimizer, warmup, total) |
| raise ValueError(f"unknown schedule: {schedule!r}") |
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size):
    """Compute the mean eval KL(teacher || student) across all ranks.

    Each sample in ``eval_batches`` is evaluated on its own (batch size 1) on
    ``accelerator.device``. Every rank contributes its KL sum and its sample
    count; the result is the global sum divided by the global count, so ranks
    holding different numbers of samples are weighted correctly and a rank
    with no samples does not distort the result.

    The student is put in eval mode for the duration and switched back to
    train mode afterwards, even if evaluation raises.

    Args:
        accelerator: the ``Accelerator`` providing device and ``gather``.
        student: student model (called with ``input_ids``/``attention_mask``).
        teacher: teacher model, run through ``teacher_forward``.
        eval_batches: this rank's list of 1-D token id tensors.
        pad_id: padding id forwarded to ``collate_pad``.
        kl_start_pos: first position included in the KL.
        kl_chunk_size: chunk size forwarded to ``kl_loss_masked``.

    Returns:
        The mean KL as a float, or ``inf`` when no rank had any sample.
    """
    student.eval()
    sdev = accelerator.device
    total = 0.0
    n = 0
    try:
        for sample in eval_batches:
            ids, mask = collate_pad([sample], pad_id)
            ids = ids.to(sdev)
            mask = mask.to(sdev)
            t_logits = teacher_forward(teacher, ids, mask)
            s_logits = student(input_ids=ids, attention_mask=mask).logits
            loss = kl_loss_masked(
                s_logits, t_logits, mask,
                start_pos=kl_start_pos, chunk_size=kl_chunk_size,
            )
            total += loss.item()
            n += 1
            # Drop the large logit tensors before the next sample.
            del t_logits, s_logits, loss
    finally:
        student.train()

    # One row per rank: [sum of per-sample KL, number of samples].
    local = torch.tensor([[total, float(n)]], device=sdev)
    gathered = accelerator.gather(local)
    kl_total, count = gathered.sum(dim=0).tolist()
    if count == 0:
        return float("inf")
    return kl_total / count
|
|
|
|
def save_best(accelerator, student, tokenizer, output_dir, step, eval_kl):
    """Overwrite ``<output_dir>/best`` with the current student checkpoint.

    All ranks synchronise before and after; only the main process writes. Any
    existing ``best`` directory is removed first, then the unwrapped student
    (safetensors), the tokenizer and a ``best.json`` holding ``step`` and
    ``eval_kl`` are saved into it.
    """
    best_dir = Path(output_dir) / "best"

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if best_dir.exists():
            shutil.rmtree(best_dir)
        best_dir.mkdir(parents=True, exist_ok=True)

        accelerator.unwrap_model(student).save_pretrained(best_dir, safe_serialization=True)
        tokenizer.save_pretrained(best_dir)

        summary = {"step": step, "eval_kl": eval_kl}
        with open(best_dir / "best.json", "w") as fh:
            fh.write(json.dumps(summary, indent=2))
        log.info(f" saved best @ step {step}: eval_kl={eval_kl:.6f} -> {best_dir}")
    accelerator.wait_for_everyone()
|
|
|
|
| |
| |
| |
|
|
def _build_loader(cfg, tokenizer, accelerator, seed):
    """Build this rank's StreamingTextLoader from [data], shuffled with `seed`."""
    data_cfg = cfg["data"]
    return StreamingTextLoader(
        name=data_cfg["dataset"],
        text_field=data_cfg["text_field"],
        min_chars=data_cfg["min_chars"],
        max_seq_len=data_cfg["max_seq_len"],
        kl_start_pos=data_cfg["kl_start_pos"],
        tokenizer=tokenizer,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
        seed=seed,
        shuffle_buffer=data_cfg["shuffle_buffer"],
    )


def _train_step(
    accelerator, student, teacher, optimizer, batch,
    pad_id, micro_batch_size, kl_start_pos, kl_chunk_size,
):
    """Accumulate gradients for one batch in micro-batches; return its mean KL.

    Gradients are zeroed first. The optimizer step itself is left to the caller.
    """
    optimizer.zero_grad()
    batch_n = len(batch)
    kl_sum = 0.0
    for mb_start in range(0, batch_n, micro_batch_size):
        micro = batch[mb_start:mb_start + micro_batch_size]
        mb_n = len(micro)
        ids, mask = collate_pad(micro, pad_id)
        ids = ids.to(accelerator.device)
        mask = mask.to(accelerator.device)

        with torch.no_grad():
            t_logits = teacher_forward(teacher, ids, mask)
        s_logits = student(input_ids=ids, attention_mask=mask).logits
        loss = kl_loss_masked(
            s_logits, t_logits, mask,
            start_pos=kl_start_pos, chunk_size=kl_chunk_size,
        )
        # Weight each micro-batch by its share of the batch so the accumulated
        # gradient corresponds to the sample-weighted mean loss.
        scaled = loss * (mb_n / batch_n)
        accelerator.backward(scaled)
        kl_sum += loss.item() * mb_n
        del t_logits, s_logits, loss, scaled
    return kl_sum / batch_n


def _eval_and_checkpoint(
    accelerator, student, teacher, tokenizer, eval_batches,
    pad_id, kl_start_pos, kl_chunk_size,
    output_dir, step, best_kl, label, wandb_mod,
):
    """Evaluate, log under `label`, save a new best if improved; return best KL.

    `wandb_mod` is the imported wandb module on the logging process, else None.
    """
    eval_kl = evaluate(
        accelerator, student, teacher, eval_batches,
        pad_id, kl_start_pos, kl_chunk_size,
    )
    if accelerator.is_main_process:
        log.info(f" {label}: kl={eval_kl:.6f} (best={best_kl:.6f})")
        if wandb_mod is not None:
            wandb_mod.log({"eval/kl": eval_kl}, step=step)
    if eval_kl < best_kl:
        best_kl = eval_kl
        save_best(accelerator, student, tokenizer, output_dir, step, eval_kl)
    return best_kl


def main():
    """Entry point: run KL distillation as described by the --config TOML file.

    Loads tokenizer, student and teacher, optionally grows and/or zeroes
    student layers, then trains the student against the teacher's logits for
    up to ``train.max_steps`` optimizer steps. The student is evaluated every
    ``eval.every_steps`` steps and once more at the end; whenever the eval KL
    improves, the checkpoint under ``<log.output_dir>/best`` is replaced.
    A non-positive ``eval.every_steps`` or ``log.log_every`` disables the
    periodic evaluation or step logging respectively.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to TOML config")
    args = parser.parse_args()

    cfg = load_config(args.config)

    accelerator = Accelerator(mixed_precision=cfg["train"]["mixed_precision"])
    set_seed(cfg["train"]["seed"])
    is_main = accelerator.is_main_process

    if is_main:
        log.info(f"Loaded config from {args.config}")
        log.info(f"World size: {accelerator.num_processes}")
        log.info(f"Mixed precision: {cfg['train']['mixed_precision']}")

    # Tokenizer; fall back to EOS as the pad token when none is defined.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    # Models.
    student_dtype = parse_dtype(cfg["train"]["student_dtype"])
    teacher_dtype = parse_dtype(cfg["train"]["teacher_dtype"])
    student = load_student(
        cfg["model"]["student"],
        student_dtype,
        grad_ckpt=cfg["train"]["grad_checkpointing"],
        attn_impl=cfg["train"]["attn_implementation"],
    )
    teacher = load_teacher(
        cfg["model"]["teacher"],
        teacher_dtype,
        attn_impl=cfg["train"]["attn_implementation"],
    )

    # Optional student surgery: append layers, then zero selected layers.
    target_n = cfg["init"]["target_num_layers"]
    cur_n = len(get_inner_with_layers(student).layers)
    new_layer_indices = []
    if target_n != cur_n:
        new_n, new_zeroed = grow_layers(student, target_n)
        new_layer_indices = [idx for idx, _ in new_zeroed]
        if is_main:
            log.info(f"Grew student from {cur_n} -> {new_n} layers")
            for idx, names in new_zeroed:
                log.info(f" layer {idx}: zeroed {names}")

    zero_idx = cfg["init"]["zero_layers"]
    if zero_idx:
        n = zero_layers(student, zero_idx)
        if is_main:
            log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)")

    teacher = teacher.to(accelerator.device)

    # Optimizer and scheduler are built on the raw student, before wrapping.
    optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices)
    scheduler = make_scheduler(optimizer, cfg["train"])
    if is_main and len(optimizer.param_groups) > 1:
        log.info(
            f"Param groups: rest lr={optimizer.param_groups[0]['lr']:.2e}, "
            f"new lr={optimizer.param_groups[1]['lr']:.2e} "
            f"({len(new_layer_indices)} layers grown)"
        )

    # Only the student and optimizer are wrapped by accelerate; the teacher is
    # frozen and the scheduler is stepped by hand once per optimizer step.
    student, optimizer = accelerator.prepare(student, optimizer)

    # Output directory, with a snapshot of the config used for this run.
    output_dir = Path(cfg["log"]["output_dir"])
    if is_main:
        output_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(args.config, output_dir / "config.snapshot.toml")

    # wandb is imported and used on the main process only, when enabled.
    wandb_mod = None
    if cfg["log"]["wandb"] and is_main:
        import wandb as wandb_mod
        wandb_mod.init(
            project=cfg["log"]["wandb_project"],
            name=cfg["log"]["wandb_run"],
            config=cfg,
        )

    # Data: the eval loader reads the same dataset with the eval seed, and a
    # fixed set of eval samples is drawn once up front.
    train_loader = _build_loader(cfg, tokenizer, accelerator, cfg["data"]["seed"])
    eval_loader = _build_loader(cfg, tokenizer, accelerator, cfg["eval"]["seed"])
    eval_per_rank = max(1, cfg["eval"]["samples"] // accelerator.num_processes)
    eval_batches = eval_loader.next_batch(eval_per_rank)
    if is_main:
        log.info(
            f"Eval set: {len(eval_batches)}/rank x {accelerator.num_processes} ranks "
            f"= {len(eval_batches) * accelerator.num_processes} samples"
        )

    samples_per_step = cfg["train"]["samples_per_step"]
    micro_batch_size = cfg["train"]["micro_batch_size"]
    grad_clip = cfg["train"]["grad_clip"]
    kl_start_pos = cfg["data"]["kl_start_pos"]
    kl_chunk_size = cfg["train"]["kl_chunk_size"]
    max_steps = cfg["train"]["max_steps"]
    eval_every = cfg["eval"]["every_steps"]
    log_every = cfg["log"]["log_every"]

    if is_main:
        log.info(
            f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
            f"(per rank, micro={micro_batch_size}), "
            f"effective batch={samples_per_step * accelerator.num_processes}"
        )

    student.train()
    best_kl = float("inf")
    global_step = 0

    while global_step < max_steps:
        t0 = time.time()
        batch = train_loader.next_batch(samples_per_step)
        if not batch:
            log.warning(f"rank {accelerator.process_index}: data exhausted")
            break

        kl_mean = _train_step(
            accelerator, student, teacher, optimizer, batch,
            pad_id, micro_batch_size, kl_start_pos, kl_chunk_size,
        )

        if grad_clip > 0:
            accelerator.clip_grad_norm_(student.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        global_step += 1

        elapsed = time.time() - t0
        # Average the per-rank mean KL across ranks for logging.
        kl_local = torch.tensor(kl_mean, device=accelerator.device)
        kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
        del kl_local

        if is_main and log_every > 0 and global_step % log_every == 0:
            lr_now = scheduler.get_last_lr()[0]
            log.info(
                f"step {global_step}/{max_steps} | kl {kl_avg:.4f} | "
                f"lr {lr_now:.2e} | {elapsed:.2f}s"
            )
            if wandb_mod is not None:
                wandb_mod.log(
                    {
                        "train/kl": kl_avg,
                        "train/lr": lr_now,
                        "perf/step_time_s": elapsed,
                    },
                    step=global_step,
                )

        if eval_every > 0 and global_step % eval_every == 0:
            best_kl = _eval_and_checkpoint(
                accelerator, student, teacher, tokenizer, eval_batches,
                pad_id, kl_start_pos, kl_chunk_size,
                output_dir, global_step, best_kl,
                f"eval @ step {global_step}", wandb_mod,
            )
            student.train()

        # Periodically release cached memory.
        if global_step % 20 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    # Final evaluation, regardless of where the loop stopped.
    best_kl = _eval_and_checkpoint(
        accelerator, student, teacher, tokenizer, eval_batches,
        pad_id, kl_start_pos, kl_chunk_size,
        output_dir, global_step, best_kl,
        "final eval", wandb_mod,
    )

    if is_main:
        log.info(f"Done. Best eval KL = {best_kl:.6f}")
        if wandb_mod is not None:
            wandb_mod.finish()


if __name__ == "__main__":
    main()
|
|