#!/usr/bin/env python3
"""
KL Distillation Training - TOML-driven, accelerate multi-GPU.

Run with:

    accelerate launch --config_file configs/accelerate.yaml distill.py --config configs/base.toml

The TOML config is the single source of truth - no hardcoded defaults in this
file. The only command line argument is --config.
"""

import os

# Reduce fragmentation; large vocab + long seq creates many short-lived big tensors.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import argparse
import gc
import json
import logging
import shutil
import time
import tomllib
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint_utils
from torch.optim import AdamW

from accelerate import Accelerator
from accelerate.utils import set_seed

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("distill")

# ----------------------------------------------------------------------------
# Config
# ----------------------------------------------------------------------------

REQUIRED_SECTIONS = ("model", "data", "train", "eval", "log", "init")

REQUIRED_KEYS = {
    "model": ("teacher", "student", "tokenizer"),
    "data": (
        "dataset",
        "text_field",
        "min_chars",
        "max_seq_len",
        "kl_start_pos",
        "seed",
        "shuffle_buffer",
    ),
    "train": (
        "seed",
        "lr",
        "schedule",
        "warmup_steps",
        "weight_decay",
        "grad_clip",
        "betas",
        "eps",
        "samples_per_step",
        "max_steps",
        "grad_checkpointing",
        "attn_implementation",
        "student_dtype",
        "teacher_dtype",
        "mixed_precision",
        "kl_chunk_size",
        "micro_batch_size",
        "new_layer_lr_mul",
    ),
    "eval": ("every_steps", "samples", "seed"),
    "log": ("wandb", "wandb_project", "wandb_run", "log_every", "output_dir"),
    "init": ("zero_layers", "target_num_layers"),
}

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def parse_dtype(s):
    if s not in DTYPE_MAP:
        raise ValueError(f"unknown dtype {s!r}; must be one of {list(DTYPE_MAP)}")
    return DTYPE_MAP[s]


def load_config(path):
    with open(path, "rb") as f:
        cfg = tomllib.load(f)
    for sec in REQUIRED_SECTIONS:
        if sec not in cfg:
            raise KeyError(f"config missing required section [{sec}]")
        for key in REQUIRED_KEYS[sec]:
            if key not in cfg[sec]:
                raise KeyError(f"config missing required key [{sec}].{key}")
    return cfg
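
# For reference, a minimal config covering REQUIRED_KEYS looks like the sketch
# below. Every value is an illustrative placeholder, not a recommendation;
# only the section/key names are checked by load_config.
#
#   [model]
#   teacher = "org/teacher-model"        # placeholder HF model ids
#   student = "org/student-model"
#   tokenizer = "org/teacher-model"
#
#   [data]
#   dataset = "org/streaming-text-dataset"
#   text_field = "text"
#   min_chars = 200
#   max_seq_len = 4096
#   kl_start_pos = 0
#   seed = 0
#   shuffle_buffer = 10000
#
#   [train]
#   seed = 0
#   lr = 1e-4
#   schedule = "cosine"                  # constant | cosine | linear
#   warmup_steps = 100
#   weight_decay = 0.01
#   grad_clip = 1.0
#   betas = [0.9, 0.95]
#   eps = 1e-8
#   samples_per_step = 8
#   max_steps = 1000
#   grad_checkpointing = true
#   attn_implementation = "sdpa"
#   student_dtype = "bfloat16"
#   teacher_dtype = "bfloat16"
#   mixed_precision = "bf16"
#   kl_chunk_size = 256
#   micro_batch_size = 1
#   new_layer_lr_mul = 1.0
#
#   [eval]
#   every_steps = 100
#   samples = 64
#   seed = 1
#
#   [log]
#   wandb = false
#   wandb_project = "distill"
#   wandb_run = "run1"
#   log_every = 10
#   output_dir = "outputs/run1"
#
#   [init]
#   zero_layers = []                     # e.g. [3, 7] to zero existing layers
#   target_num_layers = 36               # keep equal to the student's layer
#                                        # count unless growing (grow_layers)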
# ----------------------------------------------------------------------------
# Model loading
# ----------------------------------------------------------------------------

def get_inner_with_layers(model):
    """Walk wrappers (model, language_model, transformer, ...) to find an
    object that has `.layers`. Used by zero_layers."""
    seen = set()
    stack = [model]
    while stack:
        m = stack.pop()
        if id(m) in seen:
            continue
        seen.add(id(m))
        if hasattr(m, "layers"):
            return m
        for attr in ("model", "language_model", "transformer", "base_model"):
            child = getattr(m, attr, None)
            if child is not None:
                stack.append(child)
    raise RuntimeError(f"Could not locate `.layers` inside {type(model).__name__}")


def zero_layers(model, layer_indices):
    inner = get_inner_with_layers(model)
    layers = inner.layers
    n = len(layers)
    for idx in layer_indices:
        if idx < 0 or idx >= n:
            raise IndexError(f"layer {idx} out of range (0..{n - 1})")
        with torch.no_grad():
            for p in layers[idx].parameters():
                p.zero_()
    return n


def _zero_output_projections(layer):
    """Zero out attention and MLP output projections so the layer is identity
    at init while still allowing gradients to flow into o_proj/down_proj first
    (and from there back into the rest of the layer's params after one step).

    Knows about Qwen3.5 names: self_attn.o_proj (full attention),
    linear_attn.out_proj (linear attention), mlp.down_proj.
    """
    zeroed = []
    with torch.no_grad():
        if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "o_proj"):
            layer.self_attn.o_proj.weight.zero_()
            zeroed.append("self_attn.o_proj")
        if hasattr(layer, "linear_attn") and hasattr(layer.linear_attn, "out_proj"):
            layer.linear_attn.out_proj.weight.zero_()
            zeroed.append("linear_attn.out_proj")
        if hasattr(layer, "mlp") and hasattr(layer.mlp, "down_proj"):
            layer.mlp.down_proj.weight.zero_()
            zeroed.append("mlp.down_proj")
    return zeroed
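
# Why zeroing only the output projections gives an identity layer: every
# sub-block is residual, y = x + f(x), and f(x) ends in the zeroed matmul, so
# f(x) == 0 exactly at init. The zeroed weight still receives gradient (its
# grad multiplies the incoming activations, not itself), while params upstream
# of it get zero grad until it moves off zero -- hence "wake up after one
# step". Toy residual-MLP sketch (not the real decoder layer; run standalone):
#
#   import torch, torch.nn as nn
#   up = nn.Linear(8, 32, bias=False)
#   down = nn.Linear(32, 8, bias=False)
#   with torch.no_grad():
#       down.weight.zero_()
#   x = torch.randn(2, 8)
#   y = x + down(torch.relu(up(x)))
#   assert torch.equal(y, x)                  # identity at init
#   y.sum().backward()
#   assert down.weight.grad.abs().sum() > 0   # down_proj trains immediately
#   assert up.weight.grad.abs().sum() == 0    # upstream follows a step later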
def grow_layers(model, target_n):
    """Grow the student to `target_n` decoder layers by appending new ones at
    the end. New layers are constructed via the existing decoder layer class
    with the model's own _init_weights, then their output projections are
    zeroed so each new layer starts as the identity but is still trainable.
    """
    inner = get_inner_with_layers(model)
    cur_n = len(inner.layers)
    if target_n == cur_n:
        return cur_n, []  # nothing to grow; keep the return type consistent
    if target_n < cur_n:
        raise ValueError(f"target_num_layers={target_n} < current {cur_n}; cannot shrink")

    # Locate the (text) config that the layers are built from. For multimodal
    # wrappers this lives at .text_config; for the dense student it's the same
    # object as model.config.
    cfg = model.config
    text_cfg = getattr(cfg, "text_config", cfg)

    # Extend layer_types by repeating the existing periodic pattern
    if not hasattr(text_cfg, "layer_types") or not text_cfg.layer_types:
        raise RuntimeError("text config has no layer_types; cannot extend pattern")
    period = getattr(text_cfg, "full_attention_interval", 4)
    new_types = list(text_cfg.layer_types)
    while len(new_types) < target_n:
        new_types.append(new_types[len(new_types) % period])
    text_cfg.layer_types = new_types
    text_cfg.num_hidden_layers = target_n
    if hasattr(cfg, "num_hidden_layers") and cfg is not text_cfg:
        cfg.num_hidden_layers = target_n

    # Construct new layers using the same class as the existing ones
    layer_cls = type(inner.layers[0])
    device = next(inner.parameters()).device
    dtype = next(inner.parameters()).dtype
    new_layer_zeroed = []
    for i in range(cur_n, target_n):
        new_layer = layer_cls(text_cfg, layer_idx=i)
        # Apply the parent model's init scheme (std=initializer_range etc.)
        new_layer.apply(model._init_weights)
        new_layer.to(device=device, dtype=dtype)
        # Zero output projections -> identity at init, gradients still flow
        zeroed = _zero_output_projections(new_layer)
        new_layer_zeroed.append((i, zeroed))
        inner.layers.append(new_layer)
    return target_n, new_layer_zeroed
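
# Worked example of the layer_types extension above (pattern illustrative; the
# real one comes from the checkpoint's text config). With a period-4 cadence
# and 8 existing layers, growing to 12 continues the pattern:
#
#   types = ["linear_attention", "linear_attention",
#            "linear_attention", "full_attention"] * 2   # 8 layers
#   period = 4
#   while len(types) < 12:
#       types.append(types[len(types) % period])
#   assert types[8:] == ["linear_attention"] * 3 + ["full_attention"]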
def load_student(model_id, dtype, grad_ckpt, attn_impl):
    from transformers import AutoModelForCausalLM

    log.info(f"Loading student: {model_id} (dtype={dtype})")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
    )
    model.config.use_cache = False
    if grad_ckpt:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    return model


def load_teacher(model_id, dtype, attn_impl):
    """Load teacher model. Handles both pure CausalLM and multimodal
    (ConditionalGeneration) wrappers."""
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(model_id)
    archs = list(getattr(cfg, "architectures", []) or [])
    arch = archs[0] if archs else ""
    is_multimodal = "ConditionalGeneration" in arch or "ImageText" in arch
    log.info(f"Loading teacher: {model_id} (arch={arch}, multimodal={is_multimodal}, dtype={dtype})")
    if is_multimodal:
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
    else:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation=attn_impl,
        )
    model.config.use_cache = False
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model


def teacher_forward(teacher, input_ids, attention_mask):
    """Get teacher logits whether the model is unimodal or multimodal."""
    out = teacher(input_ids=input_ids, attention_mask=attention_mask)
    logits = getattr(out, "logits", None)
    if logits is None:
        raise RuntimeError("teacher forward did not return .logits")
    return logits


# ----------------------------------------------------------------------------
# Data
# ----------------------------------------------------------------------------

class StreamingTextLoader:
    """Per-rank shard of a HF streaming dataset, yielding tokenized samples."""

    def __init__(
        self,
        name,
        text_field,
        min_chars,
        max_seq_len,
        kl_start_pos,
        tokenizer,
        rank,
        world_size,
        seed,
        shuffle_buffer,
    ):
        from datasets import load_dataset
        from datasets.distributed import split_dataset_by_node

        # HF Hub occasionally returns 5xx during dataset metadata crawl. Retry.
        last_err = None
        for attempt in range(8):
            try:
                ds = load_dataset(name, split="train", streaming=True)
                break
            except Exception as e:
                last_err = e
                wait = min(2 ** attempt, 30)
                log.warning(
                    f"load_dataset({name!r}) failed (attempt {attempt + 1}/8): "
                    f"{type(e).__name__}: {e}; sleeping {wait}s"
                )
                time.sleep(wait)
        else:
            raise RuntimeError("load_dataset failed after 8 retries") from last_err

        ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
        ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
        self._ds = iter(ds)
        self._text_field = text_field
        self._min_chars = min_chars
        self._max_seq_len = max_seq_len
        self._min_tokens = kl_start_pos + 16
        self._tokenizer = tokenizer

    def next_batch(self, n):
        out = []
        scanned = 0
        while len(out) < n and scanned < n * 50:
            try:
                item = next(self._ds)
            except StopIteration:
                break
            scanned += 1
            text = item.get(self._text_field, "") or ""
            if len(text) < self._min_chars:
                continue
            ids = self._tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self._max_seq_len,
            ).input_ids.squeeze(0)
            if ids.shape[0] < self._min_tokens:
                continue
            out.append(ids)
        return out


def collate_pad(token_lists, pad_id):
    """Right-pad a list of [L_i] tensors into [B, max_L] + attention_mask."""
    max_len = max(t.shape[0] for t in token_lists)
    B = len(token_lists)
    input_ids = torch.full((B, max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((B, max_len), dtype=torch.long)
    for i, t in enumerate(token_lists):
        L = t.shape[0]
        input_ids[i, :L] = t
        attention_mask[i, :L] = 1
    return input_ids, attention_mask
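
# collate_pad at a glance (toy token ids; pad_id=0 illustrative):
#
#   ids, mask = collate_pad([torch.tensor([5, 6, 7]), torch.tensor([8, 9])], 0)
#   # ids  == [[5, 6, 7],        mask == [[1, 1, 1],
#   #          [8, 9, 0]]                 [1, 1, 0]]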
# ----------------------------------------------------------------------------
# Loss
# ----------------------------------------------------------------------------

def _kl_chunk_sum(s_chunk, t_chunk, m_chunk):
    """Compute (sum of masked KL) over a slice. Used as a checkpointed unit so
    the fp32 softmax intermediates only live for one chunk's worth of memory
    at a time."""
    s = s_chunk.float()
    t = t_chunk.float()
    t_log_p = F.log_softmax(t, dim=-1)
    s_log_p = F.log_softmax(s, dim=-1)
    t_p = t_log_p.exp()
    per_token = (t_p * (t_log_p - s_log_p)).sum(-1)
    return (per_token * m_chunk).sum()


def kl_loss_masked(student_logits, teacher_logits, attention_mask, start_pos, chunk_size):
    """Forward KL(teacher || student), masked for padding & start_pos, in fp32.

    If chunk_size > 0, processes the [start_pos:] sequence in chunks of that
    many positions, with gradient checkpointing on each chunk so peak memory
    is bounded by one chunk's intermediates rather than the full sequence's.
    """
    s_full = student_logits[:, start_pos:, :]
    t_full = teacher_logits[:, start_pos:, :].detach()
    m_full = attention_mask[:, start_pos:].float()
    T = s_full.shape[1]
    if chunk_size <= 0 or chunk_size >= T:
        return _kl_chunk_sum(s_full, t_full, m_full) / m_full.sum().clamp_min(1.0)

    total_kl = torch.zeros((), device=s_full.device, dtype=torch.float32)
    for i in range(0, T, chunk_size):
        end = min(i + chunk_size, T)
        s_c = s_full[:, i:end, :]
        t_c = t_full[:, i:end, :]
        m_c = m_full[:, i:end]
        chunk_kl = checkpoint_utils.checkpoint(
            _kl_chunk_sum, s_c, t_c, m_c, use_reentrant=False
        )
        total_kl = total_kl + chunk_kl
    return total_kl / m_full.sum().clamp_min(1.0)
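
# The chunked and unchunked paths compute the same value: per-token KL terms
# are summed either way and normalized once by the total mask count. Quick
# agreement check (random logits, illustrative sizes; run standalone):
#
#   s = torch.randn(2, 64, 100, requires_grad=True)
#   t = torch.randn(2, 64, 100)
#   m = torch.ones(2, 64, dtype=torch.long)
#   a = kl_loss_masked(s, t, m, start_pos=8, chunk_size=0)   # single pass
#   b = kl_loss_masked(s, t, m, start_pos=8, chunk_size=16)  # checkpointed
#   assert torch.allclose(a, b, atol=1e-6)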
# ----------------------------------------------------------------------------
# Optimizer / scheduler
# ----------------------------------------------------------------------------

def make_optimizer(model, train_cfg, new_layer_indices=None):
    """Create AdamW.

    If `new_layer_lr_mul != 1.0` and we know which layers are 'new' (returned
    from grow_layers), put their params in a separate group with a multiplied
    LR. Useful for the 'wake up new layers without disturbing the old ones'
    regime."""
    base_lr = train_cfg["lr"]
    mul = train_cfg["new_layer_lr_mul"]
    common = dict(
        weight_decay=train_cfg["weight_decay"],
        betas=tuple(train_cfg["betas"]),
        eps=train_cfg["eps"],
    )
    if not new_layer_indices or mul == 1.0:
        return AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=base_lr,
            **common,
        )

    inner = get_inner_with_layers(model)
    new_pids = set()
    for idx in new_layer_indices:
        for p in inner.layers[idx].parameters():
            if p.requires_grad:
                new_pids.add(id(p))
    new_params = []
    rest_params = []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (new_params if id(p) in new_pids else rest_params).append(p)
    return AdamW(
        [
            {"params": rest_params, "lr": base_lr},
            {"params": new_params, "lr": base_lr * mul},
        ],
        **common,
    )
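
# Usage sketch for the two-group case (all numbers illustrative): with lr=1e-4
# and new_layer_lr_mul=10.0, parameters of freshly grown layers train at 1e-3
# while everything else stays at 1e-4. The LambdaLR-based schedules below
# multiply every group by the same factor, so the ratio between the groups is
# preserved throughout warmup and decay.
#
#   opt = make_optimizer(
#       student,
#       {"lr": 1e-4, "new_layer_lr_mul": 10.0, "weight_decay": 0.01,
#        "betas": [0.9, 0.95], "eps": 1e-8},
#       new_layer_indices=[24, 25, 26, 27],   # hypothetical grown layers
#   )
#   assert [g["lr"] for g in opt.param_groups] == [1e-4, 1e-3]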
def make_scheduler(optimizer, train_cfg):
    schedule = train_cfg["schedule"]
    warmup = train_cfg["warmup_steps"]
    total = train_cfg["max_steps"]
    if schedule == "constant":
        from transformers import get_constant_schedule_with_warmup

        return get_constant_schedule_with_warmup(optimizer, warmup)
    if schedule == "cosine":
        from transformers import get_cosine_schedule_with_warmup

        return get_cosine_schedule_with_warmup(optimizer, warmup, total)
    if schedule == "linear":
        from transformers import get_linear_schedule_with_warmup

        return get_linear_schedule_with_warmup(optimizer, warmup, total)
    raise ValueError(f"unknown schedule: {schedule!r}")


# ----------------------------------------------------------------------------
# Eval
# ----------------------------------------------------------------------------

@torch.no_grad()
def evaluate(accelerator, student, teacher, eval_batches, pad_id, kl_start_pos, kl_chunk_size):
    student.eval()
    sdev = accelerator.device
    total = 0.0
    n = 0
    for sample in eval_batches:
        ids, mask = collate_pad([sample], pad_id)
        ids = ids.to(sdev)
        mask = mask.to(sdev)
        t_logits = teacher_forward(teacher, ids, mask)
        s_logits = student(input_ids=ids, attention_mask=mask).logits
        loss = kl_loss_masked(
            s_logits,
            t_logits,
            mask,
            start_pos=kl_start_pos,
            chunk_size=kl_chunk_size,
        )
        total += loss.item()
        n += 1
        del t_logits, s_logits, loss
    student.train()
    if n == 0:
        local = torch.tensor(float("inf"), device=sdev)
    else:
        local = torch.tensor(total / n, device=sdev)
    gathered = accelerator.gather(local.unsqueeze(0))
    return gathered.mean().item()


def save_best(accelerator, student, tokenizer, output_dir, step, eval_kl):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        out_dir = Path(output_dir) / "best"
        if out_dir.exists():
            shutil.rmtree(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        unwrapped = accelerator.unwrap_model(student)
        unwrapped.save_pretrained(out_dir, safe_serialization=True)
        tokenizer.save_pretrained(out_dir)
        with open(out_dir / "best.json", "w") as f:
            json.dump({"step": step, "eval_kl": eval_kl}, f, indent=2)
        log.info(f" saved best @ step {step}: eval_kl={eval_kl:.6f} -> {out_dir}")
    accelerator.wait_for_everyone()


# ----------------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------------

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True, help="Path to TOML config")
    args = p.parse_args()
    cfg = load_config(args.config)

    accelerator = Accelerator(mixed_precision=cfg["train"]["mixed_precision"])
    set_seed(cfg["train"]["seed"])

    if accelerator.is_main_process:
        log.info(f"Loaded config from {args.config}")
        log.info(f"World size: {accelerator.num_processes}")
        log.info(f"Mixed precision: {cfg['train']['mixed_precision']}")

    # ---- Tokenizer
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    # ---- Models (separate dtypes per config)
    student_dtype = parse_dtype(cfg["train"]["student_dtype"])
    teacher_dtype = parse_dtype(cfg["train"]["teacher_dtype"])
    student = load_student(
        cfg["model"]["student"],
        student_dtype,
        grad_ckpt=cfg["train"]["grad_checkpointing"],
        attn_impl=cfg["train"]["attn_implementation"],
    )
    teacher = load_teacher(
        cfg["model"]["teacher"],
        teacher_dtype,
        attn_impl=cfg["train"]["attn_implementation"],
    )

    # ---- Layer modifications: grow first, then zero (composable)
    target_n = cfg["init"]["target_num_layers"]
    cur_n = len(get_inner_with_layers(student).layers)
    new_layer_indices = []
    if target_n != cur_n:
        new_n, new_zeroed = grow_layers(student, target_n)
        new_layer_indices = [idx for idx, _ in new_zeroed]
        if accelerator.is_main_process:
            log.info(f"Grew student from {cur_n} -> {new_n} layers")
            for idx, names in new_zeroed:
                log.info(f" layer {idx}: zeroed {names}")

    zero_idx = cfg["init"]["zero_layers"]
    if zero_idx:
        n = zero_layers(student, zero_idx)
        if accelerator.is_main_process:
            log.info(f"Zeroed student layers {zero_idx} (model has {n} layers)")

    teacher = teacher.to(accelerator.device)

    # ---- Optimizer / scheduler
    optimizer = make_optimizer(student, cfg["train"], new_layer_indices=new_layer_indices)
    scheduler = make_scheduler(optimizer, cfg["train"])
    if accelerator.is_main_process and len(optimizer.param_groups) > 1:
        log.info(
            f"Param groups: rest lr={optimizer.param_groups[0]['lr']:.2e}, "
            f"new lr={optimizer.param_groups[1]['lr']:.2e} "
            f"({len(new_layer_indices)} layers grown)"
        )

    # NB: do NOT pass `scheduler` to accelerator.prepare. When prepared, accelerate
    # advances the scheduler by `num_processes` steps per call (to match the
    # "single-GPU equivalent" timeline). Combined with our explicit max_steps
    # accounting, that causes the cosine to cycle multiple times mid-run. By
    # leaving the scheduler unprepared, scheduler.step() advances exactly once
    # per training step, matching how max_steps is interpreted in this script.
    student, optimizer = accelerator.prepare(student, optimizer)

    # ---- Output dir + config snapshot
    output_dir = Path(cfg["log"]["output_dir"])
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(args.config, output_dir / "config.snapshot.toml")

    # ---- Wandb
    use_wandb = cfg["log"]["wandb"]
    if use_wandb and accelerator.is_main_process:
        import wandb

        wandb.init(
            project=cfg["log"]["wandb_project"],
            name=cfg["log"]["wandb_run"],
            config=cfg,
        )

    # ---- Data loaders
    train_loader = StreamingTextLoader(
        name=cfg["data"]["dataset"],
        text_field=cfg["data"]["text_field"],
        min_chars=cfg["data"]["min_chars"],
        max_seq_len=cfg["data"]["max_seq_len"],
        kl_start_pos=cfg["data"]["kl_start_pos"],
        tokenizer=tokenizer,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
        seed=cfg["data"]["seed"],
        shuffle_buffer=cfg["data"]["shuffle_buffer"],
    )
    eval_loader = StreamingTextLoader(
        name=cfg["data"]["dataset"],
        text_field=cfg["data"]["text_field"],
        min_chars=cfg["data"]["min_chars"],
        max_seq_len=cfg["data"]["max_seq_len"],
        kl_start_pos=cfg["data"]["kl_start_pos"],
        tokenizer=tokenizer,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
        seed=cfg["eval"]["seed"],
        shuffle_buffer=cfg["data"]["shuffle_buffer"],
    )
    eval_per_rank = max(1, cfg["eval"]["samples"] // accelerator.num_processes)
    eval_batches = eval_loader.next_batch(eval_per_rank)
    if accelerator.is_main_process:
        log.info(
            f"Eval set: {len(eval_batches)}/rank x {accelerator.num_processes} ranks "
            f"= {len(eval_batches) * accelerator.num_processes} samples"
        )

    # ---- Train loop
    samples_per_step = cfg["train"]["samples_per_step"]
    micro_batch_size = cfg["train"]["micro_batch_size"]
    grad_clip = cfg["train"]["grad_clip"]
    kl_start_pos = cfg["data"]["kl_start_pos"]
    kl_chunk_size = cfg["train"]["kl_chunk_size"]
    max_steps = cfg["train"]["max_steps"]
    eval_every = cfg["eval"]["every_steps"]
    log_every = cfg["log"]["log_every"]

    if accelerator.is_main_process:
        log.info(
            f"=== Training: max_steps={max_steps}, samples_per_step={samples_per_step} "
            f"(per rank, micro={micro_batch_size}), "
            f"effective batch={samples_per_step * accelerator.num_processes}"
        )

    student.train()
    best_kl = float("inf")
    global_step = 0

    while global_step < max_steps:
        t0 = time.time()
        batch = train_loader.next_batch(samples_per_step)
        if not batch:
            log.warning(f"rank {accelerator.process_index}: data exhausted")
            break

        optimizer.zero_grad()
        batch_n = len(batch)
        kl_sum = 0.0
        for mb_start in range(0, batch_n, micro_batch_size):
            micro = batch[mb_start : mb_start + micro_batch_size]
            mb_n = len(micro)
            ids, mask = collate_pad(micro, pad_id)
            ids = ids.to(accelerator.device)
            mask = mask.to(accelerator.device)
            with torch.no_grad():
                t_logits = teacher_forward(teacher, ids, mask)
            s_logits = student(input_ids=ids, attention_mask=mask).logits
            loss = kl_loss_masked(
                s_logits,
                t_logits,
                mask,
                start_pos=kl_start_pos,
                chunk_size=kl_chunk_size,
            )
            # Weight by micro size so summing micros gives the batch mean
            scaled = loss * (mb_n / batch_n)
            accelerator.backward(scaled)
            kl_sum += loss.item() * mb_n
            del t_logits, s_logits, loss, scaled
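        # Note: each micro's loss is already a mean over its own valid tokens,
        # and micros are recombined by sample count. The step loss is thus a
        # sample-weighted mean of per-micro token means, not a global
        # per-token mean; the two coincide only when micros carry similar
        # numbers of valid (non-pad, >= kl_start_pos) tokens.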
        if grad_clip > 0:
            accelerator.clip_grad_norm_(student.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        global_step += 1
        elapsed = time.time() - t0

        kl_local = torch.tensor(kl_sum / batch_n, device=accelerator.device)
        kl_avg = accelerator.gather(kl_local.unsqueeze(0)).mean().item()
        del kl_local

        if accelerator.is_main_process and global_step % log_every == 0:
            lr_now = scheduler.get_last_lr()[0]
            log.info(
                f"step {global_step}/{max_steps} | kl {kl_avg:.4f} | "
                f"lr {lr_now:.2e} | {elapsed:.2f}s"
            )
            if use_wandb:
                import wandb

                wandb.log(
                    {
                        "train/kl": kl_avg,
                        "train/lr": lr_now,
                        "perf/step_time_s": elapsed,
                    },
                    step=global_step,
                )

        if global_step % eval_every == 0:
            eval_kl = evaluate(
                accelerator,
                student,
                teacher,
                eval_batches,
                pad_id,
                kl_start_pos,
                kl_chunk_size,
            )
            if accelerator.is_main_process:
                log.info(
                    f" eval @ step {global_step}: kl={eval_kl:.6f} "
                    f"(best={best_kl:.6f})"
                )
                if use_wandb:
                    import wandb

                    wandb.log({"eval/kl": eval_kl}, step=global_step)
            if eval_kl < best_kl:
                best_kl = eval_kl
                save_best(
                    accelerator, student, tokenizer, output_dir, global_step, eval_kl
                )
            student.train()

        if global_step % 20 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    # Final eval
    eval_kl = evaluate(
        accelerator,
        student,
        teacher,
        eval_batches,
        pad_id,
        kl_start_pos,
        kl_chunk_size,
    )
    if accelerator.is_main_process:
        log.info(f" final eval: kl={eval_kl:.6f} (best={best_kl:.6f})")
        if use_wandb:
            import wandb

            wandb.log({"eval/kl": eval_kl}, step=global_step)
    if eval_kl < best_kl:
        best_kl = eval_kl
        save_best(accelerator, student, tokenizer, output_dir, global_step, eval_kl)

    if accelerator.is_main_process:
        log.info(f"Done. Best eval KL = {best_kl:.6f}")
        if use_wandb:
            import wandb

            wandb.finish()


if __name__ == "__main__":
    main()
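
# Consuming the result (sketch; the actual path depends on [log].output_dir):
# save_best writes a standard HF checkpoint, so the best student reloads with
# plain from_pretrained.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   student = AutoModelForCausalLM.from_pretrained("outputs/run1/best")
#   tokenizer = AutoTokenizer.from_pretrained("outputs/run1/best")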