File size: 882 Bytes
663494c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import functools
import time
from collections import defaultdict
import torch

# Accumulated elapsed seconds per "<name> : <function name>" key.
time_maps = defaultdict(float)
# Number of completed (non-raising) calls per "<name> : <function name>" key.
count_maps = defaultdict(int)


def run_time(name):
    """Return a decorator that times every call of the wrapped function.

    Each call's elapsed time is added to ``time_maps`` and the call is
    counted in ``count_maps``, both under the key
    ``"<name> : <fn.__name__>"``. After each call the running average
    (total seconds / call count) is printed. Calls that raise propagate the
    exception and are not recorded.

    When CUDA is available, the device is synchronized before and after the
    call so that queued asynchronous GPU work is included in the measured
    time. Without CUDA the synchronization is skipped.

    Args:
        name: Label prefixed to the wrapped function's name in the key.

    Returns:
        A decorator that wraps a function, preserving its metadata.
    """

    def middle(fn):
        # The key only depends on the label and the function, so build it once.
        key = "%s : %s" % (name, fn.__name__)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # torch.cuda.synchronize() raises when no CUDA device/build is
            # present; only synchronize when CUDA is actually usable.
            use_cuda = torch.cuda.is_available()
            if use_cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, suited to intervals.
            start = time.perf_counter()
            res = fn(*args, **kwargs)
            if use_cuda:
                torch.cuda.synchronize()
            time_maps[key] += time.perf_counter() - start
            count_maps[key] += 1
            print("%s takes up %f " % (key, time_maps[key] / count_maps[key]))
            return res

        return wrapper

    return middle