File size: 2,256 Bytes
78a947a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Contains an implementation of a ModelCheckpoint class that saves and loads the best
model weights and associated optimizer state based on a specified metric.

Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""

from pathlib import Path

import torch


class ModelCheckpoint:
    """Track the best value of a monitored metric and checkpoint on improvement.

    Whenever :meth:`save_best_weights` receives a metric that strictly improves
    on the best value seen so far (lower for ``mode="min"``, higher for
    ``mode="max"``), the model's and optimizer's ``state_dict`` are written to
    ``save_dir``. :meth:`load_best_weights` restores both, optionally keeping
    the learning rate(s) the optimizer currently has.

    Files written:
        ``<save_dir>/best_weights.pt`` - model state dict.
        ``<save_dir>/best_weights_optim_state.pt`` - optimizer state dict.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        mode: str,
        save_dir: "Path | str",
    ) -> None:
        """
        Args:
            model: Module whose weights are checkpointed.
            optim: Optimizer whose state is checkpointed alongside the model.
            mode: ``"min"`` if lower metric values are better, ``"max"`` if
                higher values are better.
            save_dir: Directory for the checkpoint files; created (including
                parents) if it does not exist. A plain ``str`` is also accepted.

        Raises:
            ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
        """
        self.model = model
        self.optim = optim
        if mode not in ["min", "max"]:
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        # Normalise so callers may pass a plain string path as well as a Path.
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.model_path = self.save_dir / "best_weights.pt"
        self.optim_path = self.save_dir / "best_weights_optim_state.pt"
        # Start at the worst possible value so any finite metric is an improvement.
        self.best_metric = torch.tensor(
            float("inf") if self.mode == "min" else float("-inf"),
            dtype=torch.float32,
            requires_grad=False,
        )

    def save_weights(self) -> None:
        """Write the current model and optimizer state dicts to disk.

        Any previously saved checkpoint at the same paths is overwritten.
        """
        torch.save(self.model.state_dict(), self.model_path)
        torch.save(self.optim.state_dict(), self.optim_path)

    def save_best_weights(self, metric: torch.Tensor) -> bool:
        """Checkpoint the model if ``metric`` strictly improves on the best so far.

        Args:
            metric: Single-element metric value for the current step/epoch
                (a tensor or a Python number).

        Returns:
            ``True`` if ``metric`` was an improvement and the weights were
            saved, ``False`` otherwise. A NaN metric never counts as an
            improvement because every comparison with NaN is false.
        """
        if (self.mode == "min" and metric < self.best_metric) or (
            self.mode == "max" and metric > self.best_metric
        ):
            print(f"New best metric found {self.best_metric} -> {metric}.")
            # Keep a detached CPU copy rather than `metric` itself: holding the
            # caller's tensor would retain its autograd graph (if any) and
            # alias a tensor the caller might later modify in place.
            self.best_metric = torch.as_tensor(metric).detach().cpu().clone()
            self.save_weights()
            return True
        return False

    def load_best_weights(self, keep_current_lr: bool = True) -> None:
        """Restore the model and optimizer from the best checkpoint on disk.

        Args:
            keep_current_lr: If ``True``, every optimizer parameter group keeps
                the learning rate it has right now instead of the value stored
                in the checkpoint (e.g. so a decayed LR is not reset).
        """
        print(f"Loading best weights from {self.model_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load
        # files this class wrote. weights_only=True should suffice for plain
        # state dicts - consider switching once verified for the optimizers used.
        self.model.load_state_dict(torch.load(self.model_path, weights_only=False))
        print(f"Loading best optimizer state from {self.optim_path}")
        # Capture every group's LR before load_state_dict overwrites them.
        current_lrs = [group["lr"] for group in self.optim.param_groups]
        self.optim.load_state_dict(torch.load(self.optim_path, weights_only=False))
        if keep_current_lr:
            # Optimizer.load_state_dict rejects a differing number of param
            # groups, so the groups and the captured LRs line up one-to-one.
            shown = current_lrs[0] if len(current_lrs) == 1 else current_lrs
            print(f"Restoring current learning rate: {shown}")
            for group, lr in zip(self.optim.param_groups, current_lrs):
                group["lr"] = lr