File size: 1,878 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from concurrent.futures import ThreadPoolExecutor

from lightning.pytorch import callbacks as pl_callbacks
from typing_extensions import override

from src.utils import logger


class ModelCheckpointParallel(pl_callbacks.ModelCheckpoint):
    """``ModelCheckpoint`` variant that runs its checkpoint hooks on a background thread.

    The parent class' batch-end, epoch-end and validation-end hooks are
    submitted to a single-worker ``ThreadPoolExecutor`` instead of being run
    inline, so the calling thread is not blocked while they execute.
    ``wait()`` (called from ``on_train_end`` and ``on_test_start``) blocks until
    every submitted hook has finished.

    Exceptions raised by a background hook are reported through
    ``logger.print_error`` and are not re-raised.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent class and create the worker pool."""
        super().__init__(*args, **kwargs)
        # Futures of submitted hooks whose outcome has not been reported yet.
        self.threads = []
        # A single worker, so submitted hooks never run concurrently with each other.
        self.thread_pool = ThreadPoolExecutor(1, thread_name_prefix="ModelCheckpointParallel")

    @staticmethod
    def _get_trainer(args, kwargs):
        """Return the trainer passed to a hook, positionally or as ``trainer=``."""
        return args[0] if args else kwargs["trainer"]

    @staticmethod
    def _report_outcome(future):
        """Wait for ``future`` to finish and log the exception it raised, if any."""
        try:
            future.result()
        except Exception as e:
            logger.print_error(f"Exception during checkpoint saving in thread: {e}")

    def _submit(self, hook, *args, **kwargs):
        """Queue ``hook(*args, **kwargs)`` on the worker thread.

        Futures that have already finished are reported and dropped first.
        Without this, ``self.threads`` would gain one entry per submitted hook
        and only be emptied by ``wait()``.
        """
        # With one worker the futures finish in submission order, so the
        # finished ones form a prefix of the list. A finished future that is
        # not part of the prefix is simply reported on a later call.
        finished = 0
        for future in self.threads:
            if not future.done():
                break
            self._report_outcome(future)
            finished += 1
        del self.threads[:finished]
        self.threads.append(self.thread_pool.submit(hook, *args, **kwargs))

    @override
    def on_train_batch_end(self, *args, **kwargs):
        """Queue the parent's batch-end hook unless checkpoint saving should be skipped."""
        trainer = self._get_trainer(args, kwargs)
        if self._should_skip_saving_checkpoint(trainer):
            return
        self._submit(super().on_train_batch_end, *args, **kwargs)

    @override
    def on_train_epoch_end(self, *args, **kwargs):
        """Queue the parent's epoch-end hook when saving is due at the end of the train epoch."""
        trainer = self._get_trainer(args, kwargs)
        if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
            self._submit(super().on_train_epoch_end, *args, **kwargs)

    @override
    def on_validation_end(self, *args, **kwargs):
        """Queue the parent's validation-end hook when saving is not done at train-epoch end."""
        trainer = self._get_trainer(args, kwargs)
        if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
            self._submit(super().on_validation_end, *args, **kwargs)

    def wait(self):
        """Block until all queued hooks have finished, then start with a fresh pool.

        Failures of the queued hooks are logged, not raised.
        """
        for future in self.threads:
            self._report_outcome(future)
        self.thread_pool.shutdown(wait=True)
        # The old pool is shut down, so create a new one for later submissions.
        self.thread_pool = ThreadPoolExecutor(1, thread_name_prefix="ModelCheckpointParallel")
        self.threads = []

    @override
    def on_train_end(self, *args, **kwargs):
        """Wait for outstanding background hooks when training ends."""
        self.wait()

    @override
    def on_test_start(self, *args, **kwargs):
        """Wait for outstanding background hooks before testing starts."""
        self.wait()