File size: 2,652 Bytes
fbb20ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from collections import deque
from time import perf_counter, time

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class TrainSpeedTimer(pl.Callback):
    """Callback that logs running averages of per-batch training timings.

    Two metrics are logged through ``pl_module.log`` on every step:

    1. ``train_timer/data_waiting``: time elapsed between the end of one
       batch and the start of the next one, i.e. the time spent waiting for
       data. This should be small, otherwise the data loading should be
       improved. This metric is also shown on the progress bar.
    2. ``train_timer/single_batch``: time elapsed between the start and the
       end of one training batch (excluding data waiting).

    Both values are averaged over the most recent ``N_avg`` measurements and
    all timing state is reset at the end of every training epoch. Every hook
    is decorated with ``rank_zero_only``.
    """

    def __init__(self, N_avg=5):
        """
        Args:
            N_avg: Number of most recent measurements to average over.
                ``None`` keeps every measurement since the last reset.

        Raises:
            ValueError: If ``N_avg`` is smaller than 1. An empty window
                cannot be averaged and would otherwise fail with a
                ``ZeroDivisionError`` on the first logged batch.
        """
        super().__init__()
        if N_avg is not None and N_avg < 1:
            raise ValueError(f"N_avg must be a positive integer, got {N_avg!r}")

        # Timestamps from ``perf_counter``; ``None`` means "not measured yet".
        self.last_batch_end = None
        self.this_batch_start = None

        # time queues for averaging
        self.data_waiting_time_queue = deque(maxlen=N_avg)
        self.single_batch_time_queue = deque(maxlen=N_avg)

    @staticmethod
    def _log_running_average(pl_module, name, queue, sample, prog_bar):
        """Append ``sample`` to ``queue`` and log the queue's mean as ``name``."""
        queue.append(sample)
        pl_module.log(
            name,
            sum(queue) / len(queue),
            on_step=True,
            on_epoch=False,
            prog_bar=prog_bar,
            logger=True,
        )

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Log the data waiting time and start the batch timer."""
        # The first batch after a reset has no previous batch end to
        # measure from, so nothing is logged for it.
        if self.last_batch_end is not None:
            # perf_counter is monotonic, so the interval cannot be distorted
            # by system clock adjustments the way time() can.
            self._log_running_average(
                pl_module,
                "train_timer/data_waiting",
                self.data_waiting_time_queue,
                perf_counter() - self.last_batch_end,
                prog_bar=True,
            )

        # Taken after logging so the logging overhead is not counted as
        # batch time.
        self.this_batch_start = perf_counter()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log the single batch time and start the data waiting timer."""
        # Guard against a missing start timestamp (no start hook seen since
        # the last reset) instead of failing on ``float - None``.
        if self.this_batch_start is not None:
            # Effective training time elapsed (excluding data waiting)
            self._log_running_average(
                pl_module,
                "train_timer/single_batch",
                self.single_batch_time_queue,
                perf_counter() - self.this_batch_start,
                prog_bar=False,
            )

        # Set timer for counting data waiting
        self.last_batch_end = perf_counter()

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        """Reset the timers and drop all collected measurements."""
        # Reset the timer
        self.last_batch_end = None
        self.this_batch_start = None
        # Clear the queue
        self.data_waiting_time_queue.clear()
        self.single_batch_time_queue.clear()