import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from typing import Any, Dict, List, Optional, Tuple

from data.dataloader import DataLoader


class EPUTrainer:
    """Training loop with callback hooks, optional validation, and
    pluggable metrics for an EPU model."""

    def __init__(self,
                 model: nn.Module,
                 device: torch.device,
                 optimizer: optim.Optimizer,
                 criterion: nn.Module,
                 epochs: int,
                 train_loader: DataLoader,
                 val_loader: Optional[DataLoader] = None,
                 callbacks: Optional[List[Any]] = None,
                 metrics: Optional[Any] = None,
                 checkpoint_dir: Optional[str] = None,
                 ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.epochs = epochs
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.callbacks = callbacks or []
        self.checkpoint_dir = checkpoint_dir

        # Metrics object; expected to expose compute(y_true=..., y_pred=...)
        # and return a dict mapping metric names to floats.
        self.metrics_fun = metrics

        # Best-model bookkeeping; reserved for checkpointing logic, the
        # training loop itself does not update these.
        self.best_metric = float("inf")
        self.best_model_path = None
        self.history = []

        # Mutable state shared with every callback hook; a callback may set
        # state["early_stop"] = True to halt training after the current epoch.
        self.state = {"model": self.model,
                      "epoch": 0,
                      "early_stop": False,
                      }

    def train(self):
        """Run the full training loop across all epochs."""
        self.model.to(self.device)
        self._on_training_begin()

        for epoch in range(self.epochs):
            self.state["epoch"] = epoch
            self._on_epoch_begin()

            train_loss, train_metrics = self._train_one_epoch()
            val_loss, val_metrics = self._validate_epoch()

            self.history.append({"epoch": epoch,
                                 "train_loss": train_loss,
                                 "val_loss": val_loss,
                                 "train_metrics": train_metrics,
                                 "val_metrics": val_metrics,
                                 })

            self._on_epoch_end(train_loss, train_metrics, val_loss, val_metrics)
            self._on_validation_end()

            # Callbacks request early stopping through the shared state.
            if self.state.get("early_stop", False):
                print("Early stopping triggered.")
                break

        self._on_training_end()

    def _train_one_epoch(self) -> Tuple[float, Dict[str, float]]:
        """Train for one epoch; return the mean batch loss and metrics."""
        self.model.train()
        running_loss = 0.0
        predictions, ground_truths = [], []

        for i, sample in enumerate(tqdm(self.train_loader,
                                        desc=f"Training Epoch {self.state['epoch'] + 1}")):
            x, y = sample
            x = x.to(self.device)
            # Targets are cast to float and given a trailing channel dim to
            # match the model's (batch, 1) logits.
            y = y.to(self.device, dtype=torch.float32).unsqueeze(1)

            self.optimizer.zero_grad()

            # ret_raw_logits=True asks the model for pre-activation outputs,
            # so a logits-based criterion (e.g. nn.BCEWithLogitsLoss) can
            # apply its own activation.
            y_hat = self.model(x, ret_raw_logits=True)
            loss = self.criterion(y_hat, y)

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            predictions.append(y_hat.detach().cpu())
            ground_truths.append(y.detach().cpu())

            for callback in self.callbacks:
                if hasattr(callback, "on_batch_end"):
                    callback.on_batch_end({**self.state,
                                           "batch": i,
                                           "loss": loss.item()})

        avg_loss = running_loss / len(self.train_loader)

        metrics = {}
        if self.metrics_fun is not None:
            metrics = self.metrics_fun.compute(
                y_true=torch.cat(ground_truths, dim=0),
                y_pred=torch.cat(predictions, dim=0),
            )
        return avg_loss, metrics

    def _validate_epoch(self) -> Tuple[float, Dict[str, float]]:
        """Evaluate on the validation set; return the mean loss and metrics.

        Returns (0.0, {}) when no validation loader was provided.
        """
        if self.val_loader is None:
            return 0.0, {}

        self.model.eval()
        total_loss = 0.0
        predictions, ground_truths = [], []

        # Gradients are not needed for evaluation.
        with torch.no_grad():
            for sample in tqdm(self.val_loader, desc="Validating"):
                x, y = sample
                x = x.to(self.device)
                y = y.to(self.device, dtype=torch.float32).unsqueeze(1)
                y_hat = self.model(x, ret_raw_logits=True)
                loss = self.criterion(y_hat, y)

                total_loss += loss.item()
                predictions.append(y_hat.detach().cpu())
                ground_truths.append(y.detach().cpu())

        avg_loss = total_loss / len(self.val_loader)
        metrics = {}
        if self.metrics_fun is not None:
            metrics = self.metrics_fun.compute(
                y_true=torch.cat(ground_truths, dim=0),
                y_pred=torch.cat(predictions, dim=0),
            )
        return avg_loss, metrics

    def _on_training_begin(self):
        for callback in self.callbacks:
            if hasattr(callback, "on_training_begin"):
                callback.on_training_begin(self.state)

    def _on_epoch_begin(self):
        for callback in self.callbacks:
            if hasattr(callback, "on_epoch_begin"):
                callback.on_epoch_begin(self.state)

    def _on_epoch_end(self, train_loss, train_metrics, val_loss, val_metrics):
        self.state.update({"train_loss": train_loss,
                           "val_loss": val_loss,
                           "train_metrics": train_metrics,
                           "val_metrics": val_metrics,
                           })

        print(f"Epoch {self.state['epoch'] + 1} | "
              f"Train loss: {train_loss:.4f} | Validation loss: {val_loss:.4f}")

        # Metrics dicts may be empty when no metrics object was supplied.
        if train_metrics:
            train_metrics_str = " | ".join(f"{k}: {v:.4f}" for k, v in train_metrics.items())
            print(f"Train metrics:\t\t {train_metrics_str}")
        if val_metrics:
            val_metrics_str = " | ".join(f"{k}: {v:.4f}" for k, v in val_metrics.items())
            print(f"Validation metrics:\t {val_metrics_str}")

        for callback in self.callbacks:
            if hasattr(callback, "on_epoch_end"):
                callback.on_epoch_end(self.state)

    def _on_validation_end(self):
        for callback in self.callbacks:
            if hasattr(callback, "on_validation_end"):
                callback.on_validation_end(self.state)

    def _on_training_end(self):
        for callback in self.callbacks:
            if hasattr(callback, "on_training_end"):
                callback.on_training_end(self.state)

    def get_model(self) -> torch.nn.Module:
        return self.model

    def get_metrics(self):
        return self.metrics_fun
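

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the trainer: the callback protocol.
# EPUTrainer discovers hooks via hasattr, so a callback implements any subset
# of on_training_begin / on_epoch_begin / on_batch_end / on_epoch_end /
# on_validation_end / on_training_end, each receiving the shared state dict.
# The class name and the patience logic below are assumptions for
# illustration only.
# ---------------------------------------------------------------------------
class EarlyStoppingCallback:
    """Sets state["early_stop"] when val_loss stops improving."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def on_epoch_end(self, state: Dict[str, Any]) -> None:
        val_loss = state.get("val_loss")
        if val_loss is None:
            return
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
            if self.epochs_without_improvement >= self.patience:
                # EPUTrainer.train checks this flag after every epoch.
                state["early_stop"] = True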
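

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the trainer: the metrics interface.
# The `metrics` object only needs a compute(y_true=..., y_pred=...) method
# returning a dict of metric name -> float. This example computes binary
# accuracy from the raw logits the trainer collects; the class name is an
# assumption.
# ---------------------------------------------------------------------------
class BinaryAccuracy:
    def compute(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> Dict[str, float]:
        # y_pred holds raw logits of shape (N, 1); thresholding at 0 is
        # equivalent to sigmoid(logit) > 0.5.
        preds = (y_pred > 0).float()
        return {"accuracy": (preds == y_true).float().mean().item()}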
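

# ---------------------------------------------------------------------------
# Illustrative usage sketch under stated assumptions: any nn.Module whose
# forward accepts ret_raw_logits=True and returns (batch, 1) logits will do,
# and the loaders are assumed to yield (inputs, targets) batches like a
# torch.utils.data.DataLoader. TinyEPU and the toy data are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset

    class TinyEPU(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 1)

        def forward(self, x, ret_raw_logits=False):
            logits = self.fc(x)
            return logits if ret_raw_logits else torch.sigmoid(logits)

    # Random toy data: 128 samples, 16 features, binary labels.
    features = torch.randn(128, 16)
    labels = torch.randint(0, 2, (128,))
    loader = TorchDataLoader(TensorDataset(features, labels),
                             batch_size=32, shuffle=True)

    model = TinyEPU()
    trainer = EPUTrainer(
        model=model,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        optimizer=optim.Adam(model.parameters(), lr=1e-3),
        criterion=nn.BCEWithLogitsLoss(),  # expects raw logits, float targets
        epochs=2,
        train_loader=loader,
        val_loader=loader,
        callbacks=[EarlyStoppingCallback(patience=2)],
        metrics=BinaryAccuracy(),
    )
    trainer.train()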