| | """ |
| | Training loop for SLM. |
| | |
| | Handles the complete training process including: |
| | - Mixed precision training |
| | - Gradient accumulation |
| | - Checkpointing |
| | - Logging |
| | """ |
| |
|
| | import os |
| | import time |
| | import json |
| | from dataclasses import dataclass, asdict |
| | from typing import Optional, Dict, Any, Callable |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader |
| | from torch.cuda.amp import autocast, GradScaler |
| | from tqdm import tqdm |
| |
|
| | from .loss import LanguageModelingLoss, compute_perplexity, compute_accuracy |
| | from .optimizer import create_optimizer, create_scheduler, clip_grad_norm |
| |
|
| |
|
@dataclass
class TrainingConfig:
    """Configuration for training."""

    # --- Optimization ---
    learning_rate: float = 3e-4      # peak learning rate handed to create_optimizer()
    weight_decay: float = 0.1        # L2/decoupled weight decay for the optimizer
    warmup_ratio: float = 0.1        # fraction of total steps spent in LR warmup
    min_lr_ratio: float = 0.1        # floor of the LR schedule, as a fraction of peak LR
    max_grad_norm: float = 1.0       # gradient clipping threshold (clip_grad_norm)
    label_smoothing: float = 0.0     # passed through to LanguageModelingLoss

    # --- Training schedule ---
    num_epochs: int = 5
    gradient_accumulation_steps: int = 4  # micro-batches per optimizer step
    fp16: bool = True                # enable AMP; only takes effect on CUDA devices

    # --- Checkpointing ---
    checkpoint_dir: str = "checkpoints"
    save_steps: int = 1000           # save a "step_N" checkpoint every N optimizer steps (0 disables)
    save_total_limit: int = 3        # max "step_*" checkpoints kept on disk (<= 0 keeps all)

    # --- Evaluation / logging cadence (in optimizer steps) ---
    eval_steps: int = 500            # mid-epoch eval interval (0 disables)
    logging_steps: int = 10

    # --- Early stopping ---
    early_stopping_patience: int = 5      # evals without improvement before stopping (<= 0 disables)
    early_stopping_threshold: float = 0.01  # minimum loss decrease that counts as improvement

    # --- Hardware ---
    device: str = "auto"             # "auto" picks cuda > mps > cpu; otherwise a torch device string

    # NOTE(review): this flag is never read by Trainer in this file — confirm
    # whether torch.compile is applied elsewhere or the flag is dead.
    compile_model: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain dict (for JSON checkpoints / W&B)."""
        return asdict(self)
| |
|
| |
|
class Trainer:
    """Training loop for SLM model.

    Orchestrates mixed-precision training with gradient accumulation,
    periodic evaluation, early stopping, checkpointing, and optional
    Weights & Biases logging.
    """

    def __init__(
        self,
        model: nn.Module,
        config: TrainingConfig,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        wandb_project: Optional[str] = None,
    ) -> None:
        """Initialize trainer.

        Args:
            model: The model to train
            config: Training configuration
            train_dataloader: Training data loader
            val_dataloader: Optional validation data loader
            wandb_project: Optional W&B project name for logging
        """
        self.config = config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Resolve the device: "auto" prefers CUDA, then Apple MPS, then CPU.
        # The hasattr guard keeps this working on torch builds without MPS.
        if config.device == "auto":
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(config.device)

        print(f"Training on device: {self.device}")

        self.model = model.to(self.device)

        # Vocab size for the loss: prefer the model's own config; otherwise
        # fall back to the embedding table size (assumes the model exposes
        # an `embed_tokens` nn.Embedding — TODO confirm for all model types).
        if hasattr(model, "config"):
            self.vocab_size = model.config.vocab_size
        else:
            self.vocab_size = model.embed_tokens.num_embeddings

        self.loss_fn = LanguageModelingLoss(
            vocab_size=self.vocab_size,
            label_smoothing=config.label_smoothing,
        )

        # One optimizer step consumes `gradient_accumulation_steps` batches,
        # so total optimizer steps = batches * epochs // accumulation.
        # NOTE(review): integer floor division drops any trailing partial
        # accumulation window — the scheduler horizon is slightly short when
        # steps_per_epoch is not a multiple of the accumulation factor.
        self.steps_per_epoch = len(train_dataloader)
        self.total_steps = self.steps_per_epoch * config.num_epochs
        self.total_steps = self.total_steps // config.gradient_accumulation_steps

        self.optimizer = create_optimizer(
            model,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        self.scheduler = create_scheduler(
            self.optimizer,
            num_training_steps=self.total_steps,
            warmup_ratio=config.warmup_ratio,
            min_lr_ratio=config.min_lr_ratio,
        )

        # AMP is only enabled on CUDA: GradScaler requires CUDA, so fp16 is
        # silently ignored on MPS/CPU.
        # NOTE(review): torch.cuda.amp.autocast/GradScaler are the deprecated
        # aliases of torch.amp.* in recent torch versions — consider migrating.
        self.use_amp = config.fp16 and self.device.type == "cuda"
        self.scaler = GradScaler() if self.use_amp else None

        # Training-progress state (persisted in checkpoints for resume).
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float("inf")

        # Early-stopping state, shared by the end-of-epoch and mid-epoch
        # evaluation paths.
        self.early_stopping_counter = 0
        self.should_stop = False

        os.makedirs(config.checkpoint_dir, exist_ok=True)

        # W&B is strictly optional: a missing package degrades to no logging.
        self.wandb = None
        if wandb_project:
            try:
                import wandb
                wandb.init(project=wandb_project, config=config.to_dict())
                self.wandb = wandb
            except ImportError:
                print("wandb not installed, skipping logging")

    def train(self) -> Dict[str, Any]:
        """Run the full training loop.

        Returns:
            Dictionary with training results
        """
        print(f"\n{'='*60}")
        print("STARTING TRAINING")
        print(f"{'='*60}")
        print(f"Total epochs: {self.config.num_epochs}")
        print(f"Steps per epoch: {self.steps_per_epoch}")
        print(f"Total optimization steps: {self.total_steps}")
        print(f"Gradient accumulation: {self.config.gradient_accumulation_steps}")
        print(f"Mixed precision: {self.use_amp}")
        if self.config.early_stopping_patience > 0:
            print(f"Early stopping: patience={self.config.early_stopping_patience}")
        print(f"{'='*60}\n")

        training_start = time.time()

        # Resume support: load_checkpoint() may have advanced self.epoch.
        start_epoch = self.epoch
        if start_epoch > 0:
            print(f"Resuming from epoch {start_epoch + 1}")

        for epoch in range(start_epoch, self.config.num_epochs):
            self.epoch = epoch
            epoch_loss = self._train_epoch()

            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs} - Loss: {epoch_loss:.4f}")

            # End-of-epoch validation and best-model tracking.
            if self.val_dataloader is not None:
                val_metrics = self.evaluate()
                print(f"Validation - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")

                # An improvement must beat the best loss by at least
                # `early_stopping_threshold` to reset the patience counter.
                if val_metrics["loss"] < self.best_val_loss - self.config.early_stopping_threshold:
                    self.best_val_loss = val_metrics["loss"]
                    self.early_stopping_counter = 0
                    self.save_checkpoint("best")
                    print(f" New best model saved!")
                else:
                    self.early_stopping_counter += 1
                    print(f" No improvement. Early stopping: {self.early_stopping_counter}/{self.config.early_stopping_patience}")

                if self.config.early_stopping_patience > 0 and self.early_stopping_counter >= self.config.early_stopping_patience:
                    print(f"\nEarly stopping triggered after {self.early_stopping_counter} evaluations without improvement.")
                    self.should_stop = True

            # Always save an end-of-epoch checkpoint; "epoch_*" names are
            # exempt from the step-checkpoint pruning in _cleanup_checkpoints.
            self.save_checkpoint(f"epoch_{epoch + 1}")

            # should_stop may have been set here or mid-epoch in _train_epoch.
            if self.should_stop:
                print("Stopping training early.")
                break

        training_time = time.time() - training_start
        print(f"\n{'='*60}")
        print(f"TRAINING COMPLETE")
        print(f"Total time: {training_time / 3600:.2f} hours")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        if self.should_stop:
            print(f"Stopped early at epoch {self.epoch + 1}")
        print(f"{'='*60}")

        return {
            "total_steps": self.global_step,
            "training_time": training_time,
            "best_val_loss": self.best_val_loss,
        }

    def _train_epoch(self) -> float:
        """Train for one epoch.

        Returns:
            Average training loss for the epoch
        """
        self.model.train()
        total_loss = 0.0            # epoch-level running loss (for the return value)
        num_batches = 0
        accumulated_loss = 0.0      # window-level running loss (reset at each log)
        num_accumulated_batches = 0

        pbar = tqdm(
            enumerate(self.train_dataloader),
            total=len(self.train_dataloader),
            desc=f"Epoch {self.epoch + 1}",
            ncols=100,
        )

        for step, batch in pbar:
            # Each batch is a dict with "input_ids" and "labels" tensors.
            input_ids = batch["input_ids"].to(self.device)
            labels = batch["labels"].to(self.device)

            # Forward pass under autocast (no-op when use_amp is False).
            with autocast(enabled=self.use_amp):
                outputs = self.model(input_ids)
                # The model may return a raw logits tensor, an object with a
                # `.logits` attribute (HF-style), or a tuple with logits first.
                if isinstance(outputs, torch.Tensor):
                    logits = outputs
                elif hasattr(outputs, 'logits'):
                    logits = outputs.logits
                else:
                    logits = outputs[0]
                loss = self.loss_fn(logits, labels)
                # Scale down so that gradients accumulated over the window sum
                # to the mean, not the sum, of per-batch gradients.
                loss = loss / self.config.gradient_accumulation_steps

            # Backward: route through GradScaler when using fp16 AMP.
            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Track the un-divided loss for human-readable logging.
            unscaled_loss = loss.item() * self.config.gradient_accumulation_steps
            accumulated_loss += unscaled_loss
            num_accumulated_batches += 1
            total_loss += unscaled_loss
            num_batches += 1

            # Optimizer step at the end of each accumulation window.
            # NOTE(review): a trailing partial window at the end of the epoch
            # is never stepped nor zeroed, so its gradients leak into the next
            # epoch's first window — confirm whether this is intended.
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradients must be unscaled before clipping so the clip
                # threshold applies to true gradient magnitudes.
                if self.use_amp:
                    self.scaler.unscale_(self.optimizer)

                grad_norm = clip_grad_norm(self.model, self.config.max_grad_norm)

                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                # Scheduler ticks once per optimizer step (not per batch).
                self.scheduler.step()
                self.optimizer.zero_grad()

                self.global_step += 1

                # Periodic logging (progress bar, stdout, optional W&B).
                if self.global_step % self.config.logging_steps == 0:
                    avg_loss = accumulated_loss / max(num_accumulated_batches, 1)
                    lr = self.scheduler.get_last_lr()[0]

                    pbar.set_postfix({
                        'loss': f'{avg_loss:.4f}',
                        'lr': f'{lr:.2e}',
                        'step': f'{self.global_step}/{self.total_steps}'
                    })

                    # tqdm.write keeps the progress bar intact.
                    tqdm.write(
                        f"Step {self.global_step}/{self.total_steps} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"LR: {lr:.2e} | "
                        f"Grad: {grad_norm:.2f}"
                    )

                    if self.wandb:
                        self.wandb.log({
                            "train/loss": avg_loss,
                            "train/learning_rate": lr,
                            "train/grad_norm": grad_norm,
                            "train/epoch": self.epoch,
                        }, step=self.global_step)

                    # Reset the logging window.
                    accumulated_loss = 0.0
                    num_accumulated_batches = 0

                # Mid-epoch evaluation; updates the same best-model /
                # early-stopping state as the end-of-epoch path in train().
                if self.config.eval_steps > 0 and self.global_step % self.config.eval_steps == 0:
                    if self.val_dataloader is not None:
                        val_metrics = self.evaluate()
                        print(f" Eval - Loss: {val_metrics['loss']:.4f}, PPL: {val_metrics['perplexity']:.2f}")

                        if self.wandb:
                            self.wandb.log({
                                "eval/loss": val_metrics["loss"],
                                "eval/perplexity": val_metrics["perplexity"],
                            }, step=self.global_step)

                        if val_metrics["loss"] < self.best_val_loss - self.config.early_stopping_threshold:
                            self.best_val_loss = val_metrics["loss"]
                            self.early_stopping_counter = 0
                            self.save_checkpoint("best")
                            print(f" New best model! Loss: {self.best_val_loss:.4f}")
                        else:
                            self.early_stopping_counter += 1
                            if self.config.early_stopping_patience > 0:
                                print(f" No improvement ({self.early_stopping_counter}/{self.config.early_stopping_patience})")
                                if self.early_stopping_counter >= self.config.early_stopping_patience:
                                    print(f"\n Early stopping triggered!")
                                    self.should_stop = True
                                    break

                # Periodic "step_N" checkpoints (pruned by save_total_limit).
                if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(f"step_{self.global_step}")

            # Early stopping set mid-epoch exits the batch loop immediately.
            if self.should_stop:
                break

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on validation data.

        Returns:
            Dictionary with evaluation metrics
        """
        self.model.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        for batch in self.val_dataloader:
            input_ids = batch["input_ids"].to(self.device)
            labels = batch["labels"].to(self.device)

            with autocast(enabled=self.use_amp):
                outputs = self.model(input_ids)
                # Same output-unwrapping convention as _train_epoch.
                if isinstance(outputs, torch.Tensor):
                    logits = outputs
                elif hasattr(outputs, 'logits'):
                    logits = outputs.logits
                else:
                    logits = outputs[0]
                loss = self.loss_fn(logits, labels)

            total_loss += loss.item()
            total_accuracy += compute_accuracy(logits, labels).item()
            num_batches += 1

        # Restore training mode: evaluate() is called mid-training.
        self.model.train()

        avg_loss = total_loss / max(num_batches, 1)
        avg_accuracy = total_accuracy / max(num_batches, 1)

        return {
            "loss": avg_loss,
            # Perplexity is derived from the mean loss, not averaged per batch.
            "perplexity": compute_perplexity(torch.tensor(avg_loss)).item(),
            "accuracy": avg_accuracy,
        }

    def save_checkpoint(self, name: str) -> None:
        """Save a checkpoint.

        Args:
            name: Checkpoint name (e.g., "best", "epoch_1", "step_1000")
        """
        checkpoint_path = os.path.join(self.config.checkpoint_dir, name)
        os.makedirs(checkpoint_path, exist_ok=True)

        # Model weights.
        model_path = os.path.join(checkpoint_path, "model.pt")
        torch.save(self.model.state_dict(), model_path)

        # Optimizer/scheduler state plus progress counters, so training can
        # resume exactly where it left off (see load_checkpoint).
        optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
        torch.save({
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "epoch": self.epoch,
            "best_val_loss": self.best_val_loss,
            "early_stopping_counter": self.early_stopping_counter,
        }, optimizer_path)

        # Training config, for reproducibility.
        config_path = os.path.join(checkpoint_path, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        print(f"Saved checkpoint: {checkpoint_path}")

        # Prune old "step_*" checkpoints beyond save_total_limit.
        self._cleanup_checkpoints()

    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Load a checkpoint.

        Args:
            checkpoint_path: Path to checkpoint directory
        """
        # Model weights.
        model_path = os.path.join(checkpoint_path, "model.pt")
        state_dict = torch.load(model_path, map_location=self.device)

        # torch.compile wraps the model and prefixes every parameter key with
        # "_orig_mod."; strip it so the weights load into the plain model.
        if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
            print(" Detected compiled model checkpoint, removing _orig_mod. prefix...")
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

        self.model.load_state_dict(state_dict)

        # Optimizer/scheduler/progress state is optional (e.g. a weights-only
        # "best" export may lack it).
        optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
        if os.path.exists(optimizer_path):
            state = torch.load(optimizer_path, map_location=self.device)
            self.optimizer.load_state_dict(state["optimizer"])
            self.scheduler.load_state_dict(state["scheduler"])
            self.global_step = state["global_step"]
            self.epoch = state["epoch"]
            # .get() keeps compatibility with older checkpoints that predate
            # these two fields.
            self.best_val_loss = state.get("best_val_loss", float("inf"))
            self.early_stopping_counter = state.get("early_stopping_counter", 0)

            # End-of-epoch checkpoints store the finished epoch, so resume
            # from the next one.
            # NOTE(review): this is a substring match on the whole path — a
            # directory like "checkpoints_epoch_runs/step_100" would also
            # trigger it; confirm checkpoint naming is controlled.
            if "epoch_" in checkpoint_path:
                self.epoch += 1
                print(f" Checkpoint was end-of-epoch, will start from epoch {self.epoch + 1}")

        print(f"Loaded checkpoint: {checkpoint_path}")
        print(f" Resuming from step {self.global_step}, epoch {self.epoch}")
        print(f" Best val loss so far: {self.best_val_loss:.4f}")

    def _cleanup_checkpoints(self) -> None:
        """Remove old checkpoints to save disk space."""
        if self.config.save_total_limit <= 0:
            return

        # Only "step_*" checkpoints are pruned; "best" and "epoch_*" are kept.
        # Sorted numerically by step so the oldest are removed first.
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoints = sorted(
            [d for d in checkpoint_dir.iterdir() if d.is_dir() and d.name.startswith("step_")],
            key=lambda x: int(x.name.split("_")[1]),
        )

        while len(checkpoints) > self.config.save_total_limit:
            old_checkpoint = checkpoints.pop(0)
            print(f"Removing old checkpoint: {old_checkpoint}")
            import shutil
            shutil.rmtree(old_checkpoint)
| |
|