"""Training callbacks for MANIFOLD."""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Dict, Any, List, TYPE_CHECKING
import json
import time

if TYPE_CHECKING:
    from manifold.training.trainer import MANIFOLDTrainer
class Callback(ABC):
    """Base class for training callbacks.

    Subclasses override any subset of the lifecycle hooks below. Every
    default implementation is a no-op, so overriding is strictly opt-in.
    """

    def on_train_start(self, trainer: "MANIFOLDTrainer") -> None:
        """Called once, before the first epoch."""

    def on_train_end(self, trainer: "MANIFOLDTrainer") -> None:
        """Called once, after training finishes."""

    def on_epoch_start(self, trainer: "MANIFOLDTrainer", epoch: int) -> None:
        """Called before each epoch begins."""

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Called after each epoch, with that epoch's aggregated info dict."""

    def on_batch_end(self, trainer: "MANIFOLDTrainer", batch_info: Dict[str, Any]) -> None:
        """Called after every training batch."""
class CheckpointCallback(Callback):
    """Save model checkpoints during training.

    Writes a periodic checkpoint every ``save_every_n_epochs`` epochs and,
    optionally, a ``best_model.pt`` checkpoint whenever the monitored metric
    improves.

    Args:
        save_dir: Directory checkpoints are written to (created if absent).
        save_every_n_epochs: Period of the periodic checkpoint; a value < 1
            disables periodic saving (the original raised ZeroDivisionError).
        save_best: Whether to track the monitored metric and keep the best.
        monitor: Metric to track, as ``"<split>_<name>"`` — e.g. ``"val_loss"``
            reads ``epoch_info["val"]["loss"]``.
        mode: ``"min"`` if lower metric values are better, ``"max"`` otherwise.
    """

    def __init__(
        self,
        save_dir: str | Path,
        save_every_n_epochs: int = 5,
        save_best: bool = True,
        monitor: str = "val_loss",
        mode: str = "min",
    ):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.save_every_n_epochs = save_every_n_epochs
        self.save_best = save_best
        self.monitor = monitor
        self.mode = mode
        self.best_value = float("inf") if mode == "min" else float("-inf")

    def _monitored_value(self, epoch_info: Dict[str, Any]) -> float:
        """Look up the monitored metric in ``epoch_info``.

        A missing metric returns a sentinel that can never be "best" for the
        configured mode (the original used +inf even in "max" mode, which
        would wrongly register a missing metric as the best value).
        """
        split, _, metric = self.monitor.partition("_")
        if not metric:
            # No "<split>_" prefix: assume a validation metric.
            split, metric = "val", self.monitor
        missing = float("inf") if self.mode == "min" else float("-inf")
        return epoch_info.get(split, {}).get(metric, missing)

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        epoch = epoch_info["epoch"]
        # Periodic checkpoint (guarded against non-positive periods).
        if self.save_every_n_epochs > 0 and (epoch + 1) % self.save_every_n_epochs == 0:
            trainer.save_checkpoint(self.save_dir / f"checkpoint_epoch_{epoch+1}.pt")
        # Best-so-far checkpoint. Bug fix: the original always read
        # epoch_info["val"]["loss"], silently ignoring ``monitor``.
        if self.save_best:
            current = self._monitored_value(epoch_info)
            is_best = (self.mode == "min" and current < self.best_value) or \
                      (self.mode == "max" and current > self.best_value)
            if is_best:
                self.best_value = current
                trainer.save_checkpoint(self.save_dir / "best_model.pt")
class EarlyStoppingCallback(Callback):
    """Stop training when the monitored metric stops improving.

    Sets ``self.should_stop`` once the metric has failed to improve by more
    than ``min_delta`` for ``patience`` consecutive epochs; the trainer is
    expected to poll that flag.

    Args:
        monitor: Metric to track, as ``"<split>_<name>"`` — e.g. ``"val_loss"``
            reads ``epoch_info["val"]["loss"]``.
        patience: Number of non-improving epochs tolerated before stopping.
        min_delta: Minimum change that counts as an improvement.
        mode: ``"min"`` if lower metric values are better, ``"max"`` otherwise.
    """

    def __init__(
        self,
        monitor: str = "val_loss",
        patience: int = 10,
        min_delta: float = 0.0,
        mode: str = "min",
    ):
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_value = float("inf") if mode == "min" else float("-inf")
        self.counter = 0
        self.should_stop = False

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        # Bug fix: the original always read epoch_info["val"]["loss"],
        # silently ignoring ``monitor``.
        split, _, metric = self.monitor.partition("_")
        if not metric:
            # No "<split>_" prefix: assume a validation metric.
            split, metric = "val", self.monitor
        # Bug fix: a missing metric must never count as an improvement, so
        # the fallback depends on mode (the original used +inf even in "max"
        # mode, which would reset the patience counter every epoch).
        missing = float("inf") if self.mode == "min" else float("-inf")
        current = epoch_info.get(split, {}).get(metric, missing)
        if self.mode == "min":
            improved = current < self.best_value - self.min_delta
        else:
            improved = current > self.best_value + self.min_delta
        if improved:
            self.best_value = current
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
                print(f"Early stopping triggered after {self.counter} epochs without improvement")
class WandBCallback(Callback):
    """Log metrics to Weights & Biases.

    Degrades gracefully: if ``wandb`` cannot be imported, logging is skipped
    for the whole run and a notice is printed instead.
    """

    def __init__(
        self,
        project: str = "manifold",
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.project = project
        self.name = name
        self.config = config
        self._wandb = None  # wandb module handle, set lazily in on_train_start
        self._run = None    # active wandb run, if any

    def on_train_start(self, trainer: "MANIFOLDTrainer") -> None:
        try:
            import wandb

            self._wandb = wandb
            self._run = wandb.init(
                project=self.project,
                name=self.name,
                config=self.config or {},
            )
        except ImportError:
            print("wandb not installed, skipping WandB logging")

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        if self._wandb is None:
            return
        payload: Dict[str, Any] = {
            "epoch": epoch_info["epoch"],
            "lr": epoch_info.get("lr", 0),
            "stage": epoch_info.get("stage", {}).get("stage_name", ""),
        }
        # Namespace per-split metrics as "train/<metric>" and "val/<metric>".
        payload.update(
            (f"{split}/{name}", value)
            for split in ("train", "val")
            for name, value in epoch_info.get(split, {}).items()
        )
        self._wandb.log(payload)

    def on_train_end(self, trainer: "MANIFOLDTrainer") -> None:
        if self._run:
            self._run.finish()
class ProgressCallback(Callback):
    """Print a one-line training summary at the end of every epoch."""

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        val_metrics = epoch_info.get("val", {})
        # Assemble the pipe-separated summary; output is identical to a
        # single hand-written f-string.
        fields = (
            f"Epoch {epoch_info['epoch'] + 1}",
            epoch_info.get("stage", {}).get("stage_name", ""),
            f"Train Loss: {epoch_info.get('train', {}).get('loss', 0):.4f}",
            f"Val Loss: {val_metrics.get('loss', 0):.4f}",
            f"Val Acc: {val_metrics.get('accuracy', 0):.4f}",
            f"LR: {epoch_info.get('lr', 0):.2e}",
        )
        print(" | ".join(fields))
class CallbackManager:
    """Fan out trainer lifecycle events to a list of callbacks, in order."""

    def __init__(self, callbacks: Optional[List[Callback]] = None):
        # A falsy argument (None or []) gets a fresh list; a non-empty list
        # is aliased so later external mutation is visible here.
        self.callbacks = callbacks if callbacks else []

    def add(self, callback: Callback) -> None:
        self.callbacks.append(callback)

    def _fan_out(self, hook: str, *args: Any) -> None:
        # Invoke the named hook on every registered callback, in order.
        for callback in self.callbacks:
            getattr(callback, hook)(*args)

    def on_train_start(self, trainer: "MANIFOLDTrainer") -> None:
        self._fan_out("on_train_start", trainer)

    def on_train_end(self, trainer: "MANIFOLDTrainer") -> None:
        self._fan_out("on_train_end", trainer)

    def on_epoch_start(self, trainer: "MANIFOLDTrainer", epoch: int) -> None:
        self._fan_out("on_epoch_start", trainer, epoch)

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        self._fan_out("on_epoch_end", trainer, epoch_info)

    def on_batch_end(self, trainer: "MANIFOLDTrainer", batch_info: Dict[str, Any]) -> None:
        self._fan_out("on_batch_end", trainer, batch_info)