"""Training callbacks for MANIFOLD.""" from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Dict, Any, List, TYPE_CHECKING import json import time if TYPE_CHECKING: from manifold.training.trainer import MANIFOLDTrainer class Callback(ABC): """Base class for training callbacks.""" def on_train_start(self, trainer: "MANIFOLDTrainer") -> None: """Called at the start of training.""" pass def on_train_end(self, trainer: "MANIFOLDTrainer") -> None: """Called at the end of training.""" pass def on_epoch_start(self, trainer: "MANIFOLDTrainer", epoch: int) -> None: """Called at the start of each epoch.""" pass def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None: """Called at the end of each epoch.""" pass def on_batch_end(self, trainer: "MANIFOLDTrainer", batch_info: Dict[str, Any]) -> None: """Called at the end of each batch.""" pass class CheckpointCallback(Callback): """Save model checkpoints during training.""" def __init__( self, save_dir: str | Path, save_every_n_epochs: int = 5, save_best: bool = True, monitor: str = "val_loss", mode: str = "min", ): self.save_dir = Path(save_dir) self.save_dir.mkdir(parents=True, exist_ok=True) self.save_every_n_epochs = save_every_n_epochs self.save_best = save_best self.monitor = monitor self.mode = mode self.best_value = float("inf") if mode == "min" else float("-inf") def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None: epoch = epoch_info["epoch"] # Save periodic checkpoint if (epoch + 1) % self.save_every_n_epochs == 0: path = self.save_dir / f"checkpoint_epoch_{epoch+1}.pt" trainer.save_checkpoint(path) # Save best checkpoint if self.save_best: current = epoch_info.get("val", {}).get("loss", float("inf")) is_best = (self.mode == "min" and current < self.best_value) or \ (self.mode == "max" and current > self.best_value) if is_best: self.best_value = current path = self.save_dir / "best_model.pt" trainer.save_checkpoint(path) class EarlyStoppingCallback(Callback): """Stop training when metric stops improving.""" def __init__( self, monitor: str = "val_loss", patience: int = 10, min_delta: float = 0.0, mode: str = "min", ): self.monitor = monitor self.patience = patience self.min_delta = min_delta self.mode = mode self.best_value = float("inf") if mode == "min" else float("-inf") self.counter = 0 self.should_stop = False def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None: current = epoch_info.get("val", {}).get("loss", float("inf")) if self.mode == "min": improved = current < self.best_value - self.min_delta else: improved = current > self.best_value + self.min_delta if improved: self.best_value = current self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.should_stop = True print(f"Early stopping triggered after {self.counter} epochs without improvement") class WandBCallback(Callback): """Log metrics to Weights & Biases.""" def __init__( self, project: str = "manifold", name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, ): self.project = project self.name = name self.config = config self._wandb = None self._run = None def on_train_start(self, trainer: "MANIFOLDTrainer") -> None: try: import wandb self._wandb = wandb self._run = wandb.init( project=self.project, name=self.name, config=self.config or {}, ) except ImportError: print("wandb not installed, skipping WandB logging") def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None: if self._wandb is None: return metrics = { "epoch": epoch_info["epoch"], "lr": epoch_info.get("lr", 0), "stage": epoch_info.get("stage", {}).get("stage_name", ""), } for prefix in ["train", "val"]: for k, v in epoch_info.get(prefix, {}).items(): metrics[f"{prefix}/{k}"] = v self._wandb.log(metrics) def on_train_end(self, trainer: "MANIFOLDTrainer") -> None: if self._run: self._run.finish() class ProgressCallback(Callback): """Print training progress.""" def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None: epoch = epoch_info["epoch"] stage = epoch_info.get("stage", {}).get("stage_name", "") train_loss = epoch_info.get("train", {}).get("loss", 0) val_loss = epoch_info.get("val", {}).get("loss", 0) val_acc = epoch_info.get("val", {}).get("accuracy", 0) lr = epoch_info.get("lr", 0) print(f"Epoch {epoch+1} | {stage} | " f"Train Loss: {train_loss:.4f} | " f"Val Loss: {val_loss:.4f} | " f"Val Acc: {val_acc:.4f} | " f"LR: {lr:.2e}") class CallbackManager: """Orchestrate multiple callbacks.""" def __init__(self, callbacks: Optional[List[Callback]] = None): self.callbacks = callbacks or [] def add(self, callback: Callback) -> None: self.callbacks.append(callback) def on_train_start(self, trainer: "MANIFOLDTrainer") -> None: for cb in self.callbacks: cb.on_train_start(trainer) def on_train_end(self, trainer: "MANIFOLDTrainer") -> None: for cb in self.callbacks: cb.on_train_end(trainer) def on_epoch_start(self, trainer: "MANIFOLDTrainer", epoch: int) -> None: for cb in self.callbacks: cb.on_epoch_start(trainer, epoch) def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None: for cb in self.callbacks: cb.on_epoch_end(trainer, epoch_info) def on_batch_end(self, trainer: "MANIFOLDTrainer", batch_info: Dict[str, Any]) -> None: for cb in self.callbacks: cb.on_batch_end(trainer, batch_info)