| """Main training loop for SAGE.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| import yaml |
|
|
| from data.dataset import DatasetConfig, PackedDataset |
| from eval.perplexity import evaluate_perplexity |
| from model.config import ModelConfig |
| from model.model import SageTransformer |
| from train.checkpoint import load_latest_checkpoint, save_checkpoint |
| from train.hardware import HardwareConfig |
| from train.loss import masked_cross_entropy |
| from train.optimizer import ScheduleConfig, create_optimizer, create_scheduler |
|
|
|
|
@dataclass
class TrainerConfig:
    """Run-level knobs for the training loop.

    Attributes:
        output_dir: Directory that receives ``metrics.jsonl`` and checkpoints.
        checkpoint_interval: Save a checkpoint every this many steps.
        log_interval: Emit a metrics record every this many steps.
        eval_interval: Run validation every this many steps.
        total_steps: Step index at which the training loop stops.
        seed: Seed handed to the torch random number generators.
        use_wandb: Whether a wandb run should be started.
    """

    # Field order is part of the positional constructor signature.
    output_dir: str = "runs/default"
    checkpoint_interval: int = 1000
    log_interval: int = 10
    eval_interval: int = 1000
    total_steps: int = 25_000
    seed: int = 42
    use_wandb: bool = True
|
|
|
|
def collate_batch(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Combine per-example tensor dicts into one dict of batched tensors.

    The key set is taken from the first example; every example must provide
    each of those keys. Tensors are stacked along a new leading dimension.
    """
    first_example = batch[0]
    stacked: dict[str, torch.Tensor] = {}
    for name in first_example.keys():
        per_example = [example[name] for example in batch]
        stacked[name] = torch.stack(per_example, dim=0)
    return stacked
|
|
|
|
def create_dataloader(dataset: PackedDataset, batch_size: int) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader that batches with ``collate_batch``."""
    loader_options = {"batch_size": batch_size, "collate_fn": collate_batch}
    return DataLoader(dataset, **loader_options)
|
|
|
|
def _append_jsonl(path: Path, record: dict[str, object]) -> None:
    """Append ``record`` to ``path`` as one JSON-encoded line."""
    with path.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(record) + "\n")


def _autocast_context(hw: HardwareConfig):
    """Return a fresh autocast context when AMP is on, otherwise a no-op context."""
    if hw.use_amp:
        return torch.autocast(device_type=hw.device, dtype=hw.dtype)
    # Function-scope import so no file-level import change is required.
    from contextlib import nullcontext

    return nullcontext()


def train(
    model: SageTransformer,
    train_dataset: PackedDataset,
    validation_dataset: PackedDataset | None,
    model_config: ModelConfig,
    schedule_config: ScheduleConfig,
    trainer_config: TrainerConfig,
) -> dict[str, object]:
    """Run the training loop and return the final summary.

    The starting step is whatever ``load_latest_checkpoint`` returns for
    ``trainer_config.output_dir``. Each optimizer step accumulates gradients
    over ``hw.grad_accum`` micro-batches, clips the gradient norm to 1.0, and
    advances the scheduler. Metrics and evaluation results are appended to
    ``<output_dir>/metrics.jsonl`` and mirrored to wandb when a run is active.

    Args:
        model: Model to train; moved to the device chosen by ``HardwareConfig``.
        train_dataset: Source of training examples.
        validation_dataset: Optional dataset evaluated every ``eval_interval`` steps.
        model_config: Supplies ``context_length`` and is stored in checkpoints.
        schedule_config: Passed to the optimizer and scheduler factories.
        trainer_config: Output directory, intervals, seed and step budget.

    Returns:
        A dict with ``output_dir``, ``tokens_seen`` and the ``hardware`` summary.
    """
    torch.manual_seed(trainer_config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(trainer_config.seed)
    hw = HardwareConfig(model_size_b=1.0, context_length=model_config.context_length)
    device = torch.device(hw.device)
    model = model.to(device)
    optimizer = create_optimizer(model, schedule_config)
    scheduler = create_scheduler(optimizer, schedule_config)
    # Loss scaling is only active for float16 on CUDA; otherwise the scaler is a pass-through.
    scaler = torch.GradScaler("cuda", enabled=(hw.device == "cuda" and hw.dtype == torch.float16))
    start_step = load_latest_checkpoint(model, optimizer, scheduler, scaler, trainer_config.output_dir, device)
    # NOTE(review): assumes skip() counts micro-batches rather than examples; confirm against PackedDataset.
    train_dataset.skip(start_step * hw.grad_accum)
    train_loader = create_dataloader(train_dataset, batch_size=hw.micro_batch)
    train_iter = iter(train_loader)

    output_dir = Path(trainer_config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = output_dir / "metrics.jsonl"
    # Every completed step consumed grad_accum micro-batches, so the resumed
    # counter must include that factor to stay consistent with the per-batch
    # accounting inside the loop.
    tokens_seen = start_step * hw.grad_accum * hw.micro_batch * model_config.context_length
    wandb_run = _init_wandb(trainer_config, model_config, schedule_config, hw.summary())

    try:
        model.train()
        tokens_at_last_log = tokens_seen
        last_log_time = time.perf_counter()
        for step in range(start_step, trainer_config.total_steps):
            completed = step + 1
            optimizer.zero_grad(set_to_none=True)
            step_loss = 0.0
            for _ in range(hw.grad_accum):
                try:
                    batch = next(train_iter)
                except StopIteration:
                    # Loader exhausted: begin another pass over the data.
                    train_iter = iter(train_loader)
                    batch = next(train_iter)
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                loss_mask = batch["loss_mask"].to(device)
                with _autocast_context(hw):
                    logits, _ = model(input_ids)
                    # Divide so the accumulated gradient matches a full-batch average.
                    loss = masked_cross_entropy(logits, labels, loss_mask) / hw.grad_accum
                scaler.scale(loss).backward()
                step_loss += loss.item()
                tokens_seen += int(input_ids.numel())

            # Unscale before clipping so the threshold applies to true gradient magnitudes.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            if completed % trainer_config.log_interval == 0:
                now = time.perf_counter()
                elapsed = max(now - last_log_time, 1.0e-6)
                metrics = {
                    "step": completed,
                    "loss": step_loss,
                    "learning_rate": scheduler.get_last_lr()[0],
                    "tokens_seen": tokens_seen,
                    # Throughput over the whole window since the previous log,
                    # not a single step's tokens divided by the window length.
                    "tokens_per_second": (tokens_seen - tokens_at_last_log) / elapsed,
                    "grad_norm": float(grad_norm),
                }
                _append_jsonl(metrics_path, metrics)
                if wandb_run is not None:
                    wandb_run.log(metrics, step=completed)
                last_log_time = now
                tokens_at_last_log = tokens_seen

            if completed % trainer_config.eval_interval == 0 and validation_dataset is not None:
                val_loader = create_dataloader(validation_dataset, batch_size=1)
                evaluation = evaluate_perplexity(
                    model,
                    val_loader,
                    device=device,
                    dtype=hw.dtype if hw.use_amp else None,
                )
                # The evaluation helper may leave the model in eval mode;
                # restore training mode before the next optimizer step.
                model.train()
                _append_jsonl(metrics_path, {"step": completed, **evaluation})
                if wandb_run is not None:
                    wandb_run.log(evaluation, step=completed)

            if completed % trainer_config.checkpoint_interval == 0:
                save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    scaler=scaler,
                    step=completed,
                    config={
                        "model": model_config.to_dict(),
                        "schedule": asdict(schedule_config),
                        "trainer": asdict(trainer_config),
                    },
                    output_dir=trainer_config.output_dir,
                )
    finally:
        # Close the wandb run even when training aborts with an exception.
        if wandb_run is not None:
            wandb_run.finish()
    return {"output_dir": trainer_config.output_dir, "tokens_seen": tokens_seen, "hardware": hw.summary()}
|
|
|
|
| def _init_wandb( |
| trainer_config: TrainerConfig, |
| model_config: ModelConfig, |
| schedule_config: ScheduleConfig, |
| hardware_summary: dict[str, object], |
| ): |
| """Start a wandb run when available and enabled.""" |
| if not trainer_config.use_wandb: |
| return None |
| try: |
| import wandb |
| except ImportError: |
| return None |
| return wandb.init( |
| project="sage-llm", |
| name=Path(trainer_config.output_dir).name, |
| config={ |
| "model": model_config.to_dict(), |
| "schedule": asdict(schedule_config), |
| "trainer": asdict(trainer_config), |
| "hardware": hardware_summary, |
| }, |
| mode="offline", |
| ) |
|
|
|
|
def build_argparser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the trainer."""
    cli = argparse.ArgumentParser(description="Train the SAGE dense language model.")
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--model-config", {"default": "configs/model/1b.yaml"}),
        ("--schedule-config", {"default": "configs/train/schedule.yaml"}),
        ("--train-shards", {"nargs": "+", "default": []}),
        ("--validation-shards", {"nargs": "*", "default": []}),
        ("--output-dir", {"default": "runs/default"}),
        ("--steps", {"type": int, "default": None}),
        ("--disable-wandb", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
|
|
|
|
def main(argv: Optional[list[str]] = None) -> None:
    """Parse CLI arguments, build configs and datasets, and launch training.

    Prints a notice and returns without training when no training shards are
    given; otherwise prints the training summary as indented JSON.
    """
    args = build_argparser().parse_args(argv)
    model_config = ModelConfig.from_yaml(args.model_config)
    schedule_text = Path(args.schedule_config).read_text(encoding="utf-8")
    payload = yaml.safe_load(schedule_text)

    # Required schedule keys are read first, in declaration order.
    required_schedule_fields = {
        "peak_learning_rate": payload["peak_learning_rate"],
        "min_learning_rate": payload["min_learning_rate"],
        "warmup_steps": payload["warmup_steps"],
        "weight_decay": payload["weight_decay"],
        "betas": tuple(payload["betas"]),
        "adam_eps": payload["adam_eps"],
    }
    # A truthy --steps wins; otherwise the schedule file's value, else 25_000.
    total_steps = args.steps or payload.get("total_steps", 25_000)
    schedule = ScheduleConfig(**required_schedule_fields, total_steps=total_steps)

    trainer_config = TrainerConfig(
        output_dir=args.output_dir,
        checkpoint_interval=payload.get("checkpoint_interval", 1000),
        log_interval=payload.get("log_interval", 10),
        eval_interval=payload.get("eval_interval", 1000),
        total_steps=total_steps,
        seed=payload.get("seed", 42),
        use_wandb=not args.disable_wandb,
    )

    if not args.train_shards:
        print("No training shards provided. The trainer entrypoint is configured correctly but requires shard paths to run.")
        return

    context_length = model_config.context_length
    train_dataset = PackedDataset(DatasetConfig(tuple(args.train_shards), context_length, split="train"))
    if args.validation_shards:
        validation_dataset = PackedDataset(
            DatasetConfig(tuple(args.validation_shards), context_length, split="validation")
        )
    else:
        validation_dataset = None

    model = SageTransformer(model_config)
    summary = train(model, train_dataset, validation_dataset, model_config, schedule, trainer_config)
    print(json.dumps(summary, indent=2))
|
|
|
|
# Run the CLI entrypoint only when this module is executed as a script.
if __name__ == "__main__":
    main()
|
|