File size: 3,238 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from collections import defaultdict
import time
import datetime


class TimeCounter:
    """Estimate the remaining wall-clock time of an epoch/batch loop.

    The estimate is a linear extrapolation: the time elapsed since
    ``reset()`` is divided by the number of batches finished so far and
    multiplied by the number of batches planned in total.

    Args:
        start_epoch: Epoch index at which this run starts.
        num_epochs: Epoch index at which the run ends (exclusive bound used
            as ``num_epochs - start_epoch`` epochs to run).
        epoch_iters: Number of batches per epoch.
    """

    def __init__(self, start_epoch, num_epochs, epoch_iters):
        self.start_epoch = start_epoch
        self.num_epochs = num_epochs
        self.epoch_iters = epoch_iters
        # Populated by reset(); step() cannot be used before that.
        self.start_time = None

    def reset(self):
        """Record the current time as the start of the measured run."""
        self.start_time = time.time()

    def step(self, epoch, batch):
        """Return the estimated remaining time as a string.

        Args:
            epoch: Current epoch index (same numbering as ``start_epoch``).
            batch: Number of batches already finished within ``epoch``.

        Returns:
            ``str(datetime.timedelta)`` of the estimated remaining seconds,
            or ``'N/A'`` when no batch has finished yet and therefore no
            per-batch cost can be derived.

        Raises:
            TypeError: If ``reset()`` has not been called, because
                ``start_time`` is still ``None``.
        """
        used = time.time() - self.start_time
        finished_batch_nums = (epoch - self.start_epoch) * self.epoch_iters + batch
        # Without at least one finished batch the per-batch cost is undefined
        # (this used to raise ZeroDivisionError).
        if finished_batch_nums <= 0:
            return 'N/A'
        batch_time_cost = used / finished_batch_nums
        total = (self.num_epochs - self.start_epoch) * self.epoch_iters * batch_time_cost
        # Clamp so that overshooting the planned total never renders as a
        # negative timedelta such as "-1 day, 23:59:50".
        left = max(total - used, 0.0)
        return str(datetime.timedelta(seconds=left))


def format_table(table, padding=1):
    """Render rows of cells as a text table with box-drawing borders.

    Every cell is converted with ``str()``. Each column is as wide as its
    longest cell plus ``padding`` spaces on both sides; cells are centred,
    with any odd leftover space placed on the right. Rows shorter than the
    widest row are filled with empty cells. A horizontal rule is drawn above
    the first row, between rows, and below the last row.

    Args:
        table: Iterable of rows, each an iterable of cell values.
        padding: Number of spaces added on each side of every cell.

    Returns:
        The rendered table, each line terminated by a newline. An empty
        string is returned when the table has no columns at all.
    """
    rows = [[str(cell) for cell in row] for row in table]
    num_cols = max((len(row) for row in rows), default=0)
    if num_cols == 0:
        # Nothing to draw; previously an empty table raised ValueError.
        return ''

    cols_width = [0] * num_cols
    for row in rows:
        for col_idx, cell in enumerate(row):
            cols_width[col_idx] = max(cols_width[col_idx], len(cell))

    # Box-drawing glyphs are written as escapes so they survive re-encoding
    # of the source file (the literals had previously been corrupted).
    horizontal = '\u2500'  # light horizontal
    vertical = '\u2502'    # light vertical
    bars = [horizontal * (padding * 2 + width) for width in cols_width]

    def rule(left, mid, right):
        """Build one horizontal rule with the given corner/junction glyphs."""
        return left + mid.join(bars) + right + '\n'

    lines = [rule('\u250c', '\u252c', '\u2510')]  # top border
    last_row = len(rows) - 1
    for row_idx, row in enumerate(rows):
        cells = []
        for col_idx, width in enumerate(cols_width):
            word = row[col_idx] if col_idx < len(row) else ''
            left_pad = (width - len(word)) // 2
            right_pad = width - len(word) - left_pad
            cells.append(
                ' ' * (padding + left_pad) + word + ' ' * (padding + right_pad)
            )
        lines.append(vertical + vertical.join(cells) + vertical + '\n')

        if row_idx < last_row:
            lines.append(rule('\u251c', '\u253c', '\u2524'))  # row separator
        else:
            lines.append(rule('\u2514', '\u2534', '\u2518'))  # bottom border
    return ''.join(lines)


class TicTocCounter:
    """Collect named wall-clock intervals and summarise them as a table."""

    def __init__(self):
        # name -> timestamp of the most recent tic() for that name
        self.tics = {}
        # name -> list of measured durations, one per matched toc()
        self.seps = defaultdict(list)

    def tic(self, name):
        """Store the current time as the start mark for ``name``."""
        self.tics[name] = time.time()

    def toc(self, name):
        """Append the time elapsed since the last ``tic(name)``.

        A name that was never started with ``tic`` is ignored.
        """
        now = time.time()
        if name not in self.tics:
            return
        self.seps[name].append(now - self.tics[name])

    def __repr__(self):
        """Return a header line followed by a table of mean/total times."""
        rows = [['Name', 'Mean Time', 'Total Time']]
        rows.extend(
            [label, '%0.4f' % (sum(spans) / len(spans)), '%0.4f' % sum(spans)]
            for label, spans in self.seps.items()
        )
        return 'TicTocCount Result:\n' + format_table(rows)

    def reset(self):
        """Forget all start marks and all recorded durations."""
        for store in (self.tics, self.seps):
            store.clear()

# Module-level TicTocCounter instance, created when this module is imported.
# NOTE(review): presumably intended as a shared counter for importers of this
# module -- confirm against callers.
global_tictoc_counter = TicTocCounter()