# wire up trl SFTTrainer with completion-only loss masking and the autodetected
# precision. saves the lora adapter, tokenizer, and training history under
# the run directory.

import json
from pathlib import Path
from typing import Optional

import torch
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

from cleanup.config import TrainConfig
from cleanup.data.download import load_pairs
from cleanup.data.tokenize import RESPONSE_TEMPLATE, to_dataset
from cleanup.train.model import load_base_and_tokenizer, wrap_with_lora


def _tf32_supported() -> bool:
    """Return True when the current CUDA device can run TF32 matmuls.

    TF32 needs a CUDA (not ROCm) build of torch and a device with compute
    capability 8.0 or newer (Ampere+).
    """
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8


def _detect_precision_flags(cfg: TrainConfig) -> dict:
    """Pick the bf16 / fp16 / tf32 flags to hand to ``SFTConfig``.

    The requested settings on ``cfg`` are honoured only where the hardware
    supports them:

    * no CUDA device: everything off.
    * ``cfg.bf16`` wins over ``cfg.fp16`` when the device supports bf16;
      otherwise ``cfg.fp16`` is used if set, else full precision.
    * ``cfg.tf32`` is forwarded only on devices that support TF32, because
      the training arguments reject ``tf32=True`` on older GPUs with a
      ValueError instead of silently ignoring it.

    Returns:
        A dict with exactly the keys ``"bf16"``, ``"fp16"`` and ``"tf32"``.
    """
    if not torch.cuda.is_available():
        return {"bf16": False, "fp16": False, "tf32": False}

    # downgrade an unsupported tf32 request to False rather than crashing
    # later while the training arguments are being validated.
    tf32 = cfg.tf32 if _tf32_supported() else False

    use_bf16 = bool(cfg.bf16) and torch.cuda.is_bf16_supported()
    use_fp16 = (not use_bf16) and bool(cfg.fp16)
    return {"bf16": use_bf16, "fp16": use_fp16, "tf32": tf32}


def train(
    cfg: TrainConfig,
    run_dir: Path,
    smoke: bool = False,
    epochs_override: Optional[int] = None,
    lr_override: Optional[float] = None,
) -> dict:
    """Fine-tune the LoRA-wrapped base model and persist the run artifacts.

    Args:
        cfg: training configuration (data location, hyperparameters,
            precision preferences).
        run_dir: output directory for checkpoints and artifacts; created if
            it does not exist.
        smoke: when True, run a tiny sanity pass: 200 train / 40 val rows,
            batch size 2, no gradient accumulation, a single epoch, frequent
            eval/save/logging, and gradient checkpointing disabled.
        epochs_override: replaces ``cfg.num_epochs`` when given. Ignored in
            smoke mode, which always trains for one epoch.
        lr_override: replaces ``cfg.learning_rate`` when given.

    Returns:
        The summary dict that is also written to ``train_summary.json``
        (epochs, learning rate, row counts, best metric, best checkpoint).
    """
    run_dir = Path(run_dir)
    run_dir.mkdir(parents=True, exist_ok=True)

    model, tokenizer = load_base_and_tokenizer(cfg)
    model = wrap_with_lora(model, cfg)

    train_rows = load_pairs(cfg.data_dir, "train", cfg.train_rows if not smoke else 200)
    val_rows = load_pairs(cfg.data_dir, "val", cfg.val_rows if not smoke else 40)
    train_ds = to_dataset(train_rows)
    val_ds = to_dataset(val_rows)

    # the collator scans each tokenized example for response_template and masks
    # everything up to and including it with -100 in the labels. cross entropy
    # then only fires on assistant tokens.
    collator = DataCollatorForCompletionOnlyLM(
        response_template=RESPONSE_TEMPLATE,
        tokenizer=tokenizer,
    )

    epochs = epochs_override if epochs_override is not None else cfg.num_epochs
    if smoke:
        # smoke runs are always a single epoch, even with an explicit override
        epochs = 1
    lr = lr_override if lr_override is not None else cfg.learning_rate

    precision = _detect_precision_flags(cfg)

    args = SFTConfig(
        output_dir=str(run_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=cfg.train_batch_size if not smoke else 2,
        per_device_eval_batch_size=cfg.eval_batch_size if not smoke else 2,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps if not smoke else 1,
        learning_rate=lr,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        max_grad_norm=cfg.max_grad_norm,
        adam_beta1=cfg.adam_beta1,
        adam_beta2=cfg.adam_beta2,
        adam_epsilon=cfg.adam_epsilon,
        lr_scheduler_type=cfg.lr_scheduler_type,
        eval_strategy="steps",
        eval_steps=cfg.eval_steps if not smoke else 25,
        save_strategy="steps",
        save_steps=cfg.save_steps if not smoke else 25,
        save_total_limit=cfg.save_total_limit,
        logging_steps=cfg.logging_steps if not smoke else 5,
        report_to=[],
        metric_for_best_model=cfg.metric_for_best_model,
        greater_is_better=cfg.greater_is_better,
        load_best_model_at_end=True,
        seed=cfg.seed,
        dataloader_num_workers=cfg.dataloader_num_workers,
        max_seq_length=cfg.max_seq_length,
        packing=False,
        **precision,
        gradient_checkpointing=not smoke,
    )

    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        data_collator=collator,
    )
    trainer.train()

    # save lora adapter + tokenizer under <run_dir>/model
    model_dir = run_dir / "model"
    model_dir.mkdir(exist_ok=True)
    trainer.model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)

    # dump the training log history so the report builder can chart it
    history_path = run_dir / "training_history.json"
    history_path.write_text(
        json.dumps(trainer.state.log_history, indent=2),
        encoding="utf-8",
    )

    summary = {
        "epochs": epochs,
        "learning_rate": lr,
        "train_rows": len(train_rows),
        "val_rows": len(val_rows),
        "best_metric": trainer.state.best_metric,
        "best_model_checkpoint": trainer.state.best_model_checkpoint,
    }
    (run_dir / "train_summary.json").write_text(
        json.dumps(summary, indent=2),
        encoding="utf-8",
    )
    return summary