File size: 2,256 Bytes
78a947a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
"""
Contains an implementation of a ModelCheckpoint class that saves and loads the best
model weights and associated optimizer state based on a specified metric.
Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""
from pathlib import Path
import torch
class ModelCheckpoint:
    """
    Keep track of the best metric seen so far and persist the matching state.

    Whenever a metric passed to ``save_best_weights`` improves on the best
    value seen so far, the model's ``state_dict`` and the optimizer's
    ``state_dict`` are written to ``save_dir``. ``load_best_weights`` restores
    both from disk.

    Parameters
    ----------
    model:
        Module whose ``state_dict`` is saved and restored.
    optim:
        Optimizer whose ``state_dict`` is saved and restored.
    mode:
        ``"min"`` if a lower metric is better, ``"max"`` if a higher metric is
        better.
    save_dir:
        Directory for the checkpoint files. It is created (including parents)
        if it does not exist. Plain string paths are accepted as well.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        mode: str,
        save_dir: Path,
    ) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.model = model
        self.optim = optim
        self.mode = mode
        # Coerce so that callers passing a plain string do not fail on .mkdir().
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.model_path = self.save_dir / "best_weights.pt"
        self.optim_path = self.save_dir / "best_weights_optim_state.pt"
        # Start from the worst possible value so the first metric always wins.
        self.best_metric = torch.tensor(
            float("inf") if self.mode == "min" else float("-inf"),
            dtype=torch.float32,
            requires_grad=False,
        )

    def save_weights(self) -> None:
        """Write the current model and optimizer state dicts to disk."""
        torch.save(self.model.state_dict(), self.model_path)
        torch.save(self.optim.state_dict(), self.optim_path)

    def save_best_weights(self, metric: torch.Tensor) -> bool:
        """
        Save the current state if ``metric`` improves on the best one so far.

        Parameters
        ----------
        metric:
            Single-element tensor (or plain number) to compare against
            ``self.best_metric`` according to ``self.mode``.

        Returns
        -------
        bool
            ``True`` if ``metric`` was a strict improvement and the state was
            saved, ``False`` otherwise.
        """
        if self.mode == "min":
            improved = metric < self.best_metric
        else:
            improved = metric > self.best_metric
        if not improved:
            return False

        print(f"New best metric found {self.best_metric} -> {metric}.")
        # Store a detached copy: keeping the caller's tensor by reference would
        # let later in-place updates of it silently change the recorded best
        # value, and would keep any attached autograd graph alive.
        self.best_metric = torch.as_tensor(metric).detach().clone()
        self.save_weights()
        return True

    def load_best_weights(self, keep_current_lr: bool = True) -> None:
        """
        Restore the model and optimizer state written by ``save_weights``.

        Parameters
        ----------
        keep_current_lr:
            If ``True``, the learning rate each parameter group has right
            before loading is put back after the optimizer state is restored,
            so only the remaining optimizer state is taken from the checkpoint.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects. Only load checkpoint files that this process (or another
        # trusted source) has written.
        print(f"Loading best weights from {self.model_path}")
        self.model.load_state_dict(torch.load(self.model_path, weights_only=False))

        # Restore the state of the optimizer but keep the current learning rate
        print(f"Loading best optimizer state from {self.optim_path}")
        current_lrs = []
        if keep_current_lr:
            # Remember the learning rate of every parameter group, not just the
            # first one, because load_state_dict overwrites all of them.
            current_lrs = [group["lr"] for group in self.optim.param_groups]

        self.optim.load_state_dict(torch.load(self.optim_path, weights_only=False))

        if keep_current_lr:
            # With a single group, report the bare value as before.
            shown_lr = current_lrs[0] if len(current_lrs) == 1 else current_lrs
            print(f"Restoring current learning rate: {shown_lr}")
            for group, lr in zip(self.optim.param_groups, current_lrs):
                group["lr"] = lr
|