LimmeDev's picture
Initial MANIFOLD upload - CS2 cheat detection training
454ecdd verified
"""Training callbacks for MANIFOLD."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Dict, Any, List, TYPE_CHECKING
import json
import time
if TYPE_CHECKING:
from manifold.training.trainer import MANIFOLDTrainer
class Callback(ABC):
    """Base interface for hooks the trainer invokes at training events.

    Every hook here does nothing and returns None, so a subclass only needs to
    override the events it actually cares about.
    """

    def on_train_start(self, trainer: "MANIFOLDTrainer") -> None:
        """Hook for the start of training; no-op by default."""
        return None

    def on_train_end(self, trainer: "MANIFOLDTrainer") -> None:
        """Hook for the end of training; no-op by default."""
        return None

    def on_epoch_start(self, trainer: "MANIFOLDTrainer", epoch: int) -> None:
        """Hook for the start of each epoch (``epoch`` as passed by the caller); no-op by default."""
        return None

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Hook for the end of each epoch, receiving that epoch's info dict; no-op by default."""
        return None

    def on_batch_end(self, trainer: "MANIFOLDTrainer", batch_info: Dict[str, Any]) -> None:
        """Hook for the end of each batch, receiving that batch's info dict; no-op by default."""
        return None
class CheckpointCallback(Callback):
    """Save model checkpoints periodically and when the monitored metric improves.

    Args:
        save_dir: Directory that receives checkpoint files; created (with
            parents) if it does not exist.
        save_every_n_epochs: Write ``checkpoint_epoch_{N}.pt`` whenever
            ``epoch + 1`` is a multiple of this value. ``0`` disables periodic
            checkpoints.
        save_best: If True, write ``best_model.pt`` each time the monitored
            metric improves on the best value seen so far.
        monitor: Metric to track, named ``"<section>_<key>"`` or
            ``"<section>/<key>"`` and read from ``epoch_info[section][key]``
            (the default ``"val_loss"`` reads ``epoch_info["val"]["loss"]``).
            A top-level ``epoch_info[monitor]`` entry is used as a fallback.
        mode: ``"min"`` if smaller values are better, ``"max"`` if larger are.

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        save_dir: str | Path,
        save_every_n_epochs: int = 5,
        save_best: bool = True,
        monitor: str = "val_loss",
        mode: str = "min",
    ):
        # Validate before touching the filesystem; an unknown mode used to
        # silently disable best-checkpoint saving.
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.save_every_n_epochs = save_every_n_epochs
        self.save_best = save_best
        self.monitor = monitor
        self.mode = mode
        # Start from the worst possible value so the first real metric wins.
        self.best_value = float("inf") if mode == "min" else float("-inf")

    @staticmethod
    def _lookup_metric(epoch_info: Dict[str, Any], monitor: str) -> Optional[Any]:
        """Return the metric named by ``monitor`` from ``epoch_info``, or None if absent."""
        for sep in ("/", "_"):
            section_name, found, key = monitor.partition(sep)
            if not found:
                continue
            section = epoch_info.get(section_name)
            if isinstance(section, dict) and section.get(key) is not None:
                return section[key]
        return epoch_info.get(monitor)

    def _is_better(self, value: Any) -> bool:
        """Return True if ``value`` strictly improves on ``best_value`` under ``mode``."""
        if self.mode == "min":
            return value < self.best_value
        return value > self.best_value

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Write the periodic and/or best checkpoint for the epoch that just ended."""
        epoch = epoch_info["epoch"]

        # Files are numbered ``epoch + 1``, which suggests ``epoch`` is 0-based.
        if self.save_every_n_epochs and (epoch + 1) % self.save_every_n_epochs == 0:
            trainer.save_checkpoint(self.save_dir / f"checkpoint_epoch_{epoch+1}.pt")

        if not self.save_best:
            return
        current = self._lookup_metric(epoch_info, self.monitor)
        if current is None:
            # No value this epoch (e.g. no validation ran): never a new best.
            # Previously a missing metric became +inf, which "won" in max mode.
            return
        if self._is_better(current):
            self.best_value = current
            trainer.save_checkpoint(self.save_dir / "best_model.pt")
class EarlyStoppingCallback(Callback):
    """Flag that training should stop once the monitored metric stops improving.

    After ``patience`` consecutive epochs without an improvement larger than
    ``min_delta``, ``should_stop`` is set to True. This callback only sets the
    flag and prints a message; it is presumably read by the trainer (confirm).

    Args:
        monitor: Metric to track, named ``"<section>_<key>"`` or
            ``"<section>/<key>"`` and read from ``epoch_info[section][key]``
            (the default ``"val_loss"`` reads ``epoch_info["val"]["loss"]``).
            A top-level ``epoch_info[monitor]`` entry is used as a fallback.
        patience: Number of non-improving epochs tolerated before stopping.
        min_delta: Minimum change that counts as an improvement.
        mode: ``"min"`` if smaller values are better, ``"max"`` if larger are.

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        monitor: str = "val_loss",
        patience: int = 10,
        min_delta: float = 0.0,
        mode: str = "min",
    ):
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Start from the worst possible value so the first real metric improves on it.
        self.best_value = float("inf") if mode == "min" else float("-inf")
        self.counter = 0  # consecutive epochs without improvement
        self.should_stop = False

    @staticmethod
    def _lookup_metric(epoch_info: Dict[str, Any], monitor: str) -> Optional[Any]:
        """Return the metric named by ``monitor`` from ``epoch_info``, or None if absent."""
        for sep in ("/", "_"):
            section_name, found, key = monitor.partition(sep)
            if not found:
                continue
            section = epoch_info.get(section_name)
            if isinstance(section, dict) and section.get(key) is not None:
                return section[key]
        return epoch_info.get(monitor)

    def _improved(self, value: Any) -> bool:
        """Return True if ``value`` beats ``best_value`` by more than ``min_delta``."""
        if self.mode == "min":
            return value < self.best_value - self.min_delta
        return value > self.best_value + self.min_delta

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Update the no-improvement counter and raise ``should_stop`` when patience runs out."""
        current = self._lookup_metric(epoch_info, self.monitor)
        # A missing metric counts as "no improvement" in both modes. Previously it
        # became +inf, which was a spurious improvement in max mode.
        if current is not None and self._improved(current):
            self.best_value = current
            self.counter = 0
        else:
            self.counter += 1
        # Announce only the transition, not every subsequent epoch.
        if self.counter >= self.patience and not self.should_stop:
            self.should_stop = True
            print(f"Early stopping triggered after {self.counter} epochs without improvement")
class WandBCallback(Callback):
    """Log per-epoch metrics to Weights & Biases.

    ``wandb`` is imported lazily in ``on_train_start``. If it is not installed,
    a notice is printed and the logging hooks become no-ops.

    Args:
        project: Project name passed to ``wandb.init``.
        name: Optional run name passed to ``wandb.init``.
        config: Optional run config passed to ``wandb.init`` (``{}`` if None).
    """

    def __init__(
        self,
        project: str = "manifold",
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.project = project
        self.name = name
        self.config = config
        # Both are set only after a successful wandb.init(); None means logging is off.
        self._wandb = None
        self._run = None

    def on_train_start(self, trainer: "MANIFOLDTrainer") -> None:
        """Import wandb and start a run; print a notice and skip if wandb is missing."""
        try:
            import wandb
        except ImportError:
            print("wandb not installed, skipping WandB logging")
            return
        run = wandb.init(
            project=self.project,
            name=self.name,
            config=self.config or {},
        )
        # Publish module and run only once init succeeded, so a failing init
        # cannot leave the callback half-configured.
        self._wandb = wandb
        self._run = run

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Log epoch, lr, stage name and every ``train/*`` and ``val/*`` metric."""
        if self._wandb is None:
            return
        # ``or {}`` also covers keys that are present but set to None.
        stage = epoch_info.get("stage") or {}
        metrics: Dict[str, Any] = {
            "epoch": epoch_info["epoch"],
            "lr": epoch_info.get("lr", 0),
            "stage": stage.get("stage_name", ""),
        }
        for prefix in ("train", "val"):
            for key, value in (epoch_info.get(prefix) or {}).items():
                metrics[f"{prefix}/{key}"] = value
        self._wandb.log(metrics)

    def on_train_end(self, trainer: "MANIFOLDTrainer") -> None:
        """Finish the active run (if any) and reset state so later hooks are no-ops."""
        if self._run:
            self._run.finish()
        # Clearing state prevents a double finish() and logging into a finished run.
        self._run = None
        self._wandb = None
class ProgressCallback(Callback):
    """Print a one-line summary of each finished epoch.

    Metrics absent from ``epoch_info`` (or set to None) are shown as ``n/a``
    rather than a fabricated ``0.0000``. This way a missing validation pass
    cannot be mistaken for a perfect score.
    """

    @staticmethod
    def _fmt(value: Any, spec: str) -> str:
        """Format ``value`` with format spec ``spec``, or return ``"n/a"`` for None."""
        return "n/a" if value is None else format(value, spec)

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Print epoch number (``epoch + 1``), stage, losses, val accuracy and LR."""
        epoch = epoch_info["epoch"]
        # ``or {}`` also covers keys that are present but set to None.
        stage = (epoch_info.get("stage") or {}).get("stage_name", "")
        train = epoch_info.get("train") or {}
        val = epoch_info.get("val") or {}
        print(f"Epoch {epoch+1} | {stage} | "
              f"Train Loss: {self._fmt(train.get('loss'), '.4f')} | "
              f"Val Loss: {self._fmt(val.get('loss'), '.4f')} | "
              f"Val Acc: {self._fmt(val.get('accuracy'), '.4f')} | "
              f"LR: {self._fmt(epoch_info.get('lr'), '.2e')}")
class CallbackManager:
    """Forward each training event to every registered callback, in list order."""

    def __init__(self, callbacks: Optional[List[Callback]] = None):
        # Copy the input so add() never mutates a list the caller still owns
        # (``callbacks or []`` used to alias it).
        self.callbacks = list(callbacks or [])

    def add(self, callback: Callback) -> None:
        """Register ``callback``; it runs after all previously registered ones."""
        self.callbacks.append(callback)

    def on_train_start(self, trainer: "MANIFOLDTrainer") -> None:
        """Call ``on_train_start`` on every callback."""
        for cb in self.callbacks:
            cb.on_train_start(trainer)

    def on_train_end(self, trainer: "MANIFOLDTrainer") -> None:
        """Call ``on_train_end`` on every callback."""
        for cb in self.callbacks:
            cb.on_train_end(trainer)

    def on_epoch_start(self, trainer: "MANIFOLDTrainer", epoch: int) -> None:
        """Call ``on_epoch_start`` on every callback."""
        for cb in self.callbacks:
            cb.on_epoch_start(trainer, epoch)

    def on_epoch_end(self, trainer: "MANIFOLDTrainer", epoch_info: Dict[str, Any]) -> None:
        """Call ``on_epoch_end`` on every callback with the same ``epoch_info`` dict."""
        for cb in self.callbacks:
            cb.on_epoch_end(trainer, epoch_info)

    def on_batch_end(self, trainer: "MANIFOLDTrainer", batch_info: Dict[str, Any]) -> None:
        """Call ``on_batch_end`` on every callback with the same ``batch_info`` dict."""
        for cb in self.callbacks:
            cb.on_batch_end(trainer, batch_info)