unet_chemical_map / src /callbacks /model_checkpoint.py
Sm00thix's picture
Initial upload of source and weights
78a947a
"""
Contains an implementation of a ModelCheckpoint class that saves and loads the best
model weights and associated optimizer state based on a specified metric.
Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""
from pathlib import Path
import torch
class ModelCheckpoint:
    """Track the best value of a metric and checkpoint the model/optimizer at it.

    Each time ``save_best_weights`` receives a metric that improves on the best
    seen so far (lower for ``mode="min"``, higher for ``mode="max"``), the model
    and optimizer state dicts are written to ``save_dir``, overwriting the
    previous best checkpoint. ``load_best_weights`` restores them.

    Args:
        model: Model whose ``state_dict`` is saved and restored.
        optim: Optimizer whose ``state_dict`` is saved and restored.
        mode: ``"min"`` if smaller metric values are better, ``"max"`` if
            larger values are better.
        save_dir: Directory for the checkpoint files; created (with parents)
            if missing. A ``str`` is also accepted and converted to ``Path``.

    Raises:
        ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        mode: str,
        save_dir: Path,
    ) -> None:
        self.model = model
        self.optim = optim
        if mode not in ["min", "max"]:
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        # Accept plain strings as well as Path objects.
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.model_path = self.save_dir / "best_weights.pt"
        self.optim_path = self.save_dir / "best_weights_optim_state.pt"
        # Start at the worst possible value so the first comparable metric
        # always counts as an improvement.
        self.best_metric = torch.tensor(
            float("inf") if self.mode == "min" else float("-inf"),
            dtype=torch.float32,
            requires_grad=False,
        )

    def save_weights(self) -> None:
        """Write the current model and optimizer state dicts to ``save_dir``."""
        torch.save(self.model.state_dict(), self.model_path)
        torch.save(self.optim.state_dict(), self.optim_path)

    def save_best_weights(self, metric: torch.Tensor) -> bool:
        """Checkpoint the model/optimizer if ``metric`` beats the best so far.

        Args:
            metric: Scalar metric value (a one-element tensor or a Python
                number). A NaN metric never compares as an improvement.

        Returns:
            True if ``metric`` was a new best and the checkpoint was written,
            otherwise False.
        """
        improved = (self.mode == "min" and metric < self.best_metric) or (
            self.mode == "max" and metric > self.best_metric
        )
        if not improved:
            return False
        # Keep a private copy: storing the caller's tensor would alias it (an
        # in-place reset of a reused metric tensor would silently change
        # best_metric) and, if it requires grad, keep its autograd graph alive.
        if isinstance(metric, torch.Tensor):
            metric = metric.detach().clone()
        print(f"New best metric found {self.best_metric} -> {metric}.")
        self.best_metric = metric
        self.save_weights()
        return True

    def load_best_weights(self, keep_current_lr: bool = True) -> None:
        """Restore the model and optimizer from the best checkpoint.

        Args:
            keep_current_lr: If True, every optimizer parameter group keeps its
                current learning rate instead of the one stored in the
                checkpoint.

        Raises:
            FileNotFoundError: If no checkpoint has been written yet.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects. These
        # files are produced by save_weights(), so they are trusted here; do
        # not point save_dir at checkpoints from untrusted sources.
        print(f"Loading best weights from {self.model_path}")
        self.model.load_state_dict(torch.load(self.model_path, weights_only=False))
        print(f"Loading best optimizer state from {self.optim_path}")
        # Capture every group's lr before load_state_dict replaces param_groups.
        current_lrs = [group["lr"] for group in self.optim.param_groups]
        self.optim.load_state_dict(torch.load(self.optim_path, weights_only=False))
        if keep_current_lr:
            shown = current_lrs[0] if len(current_lrs) == 1 else current_lrs
            print(f"Restoring current learning rate: {shown}")
            # load_state_dict rejects a mismatched number of param groups, so
            # the groups line up one-to-one with the captured learning rates.
            for group, lr in zip(self.optim.param_groups, current_lrs):
                group["lr"] = lr