File size: 882 Bytes
663494c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import functools
import time
from collections import defaultdict
import torch
# Cumulative elapsed seconds per timed function, keyed by "<name> : <fn.__name__>".
time_maps = defaultdict(float)
# Number of successfully completed calls recorded per key of ``time_maps``.
count_maps = defaultdict(int)


def _cuda_sync():
    """Wait for pending CUDA kernels on the current device; no-op without CUDA."""
    # torch.cuda.synchronize() raises when CUDA is unavailable (CPU-only
    # machine or build), which would make the decorator unusable there.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def run_time(name):
    """Build a decorator that measures and prints the average call time.

    Each call of the decorated function is timed with ``time.perf_counter``.
    When CUDA is available, the device is synchronized before the clock starts
    and again before it stops. This makes asynchronously launched GPU kernels
    count toward the measured time. After every call, the cumulative time and
    call count are updated in the module-level ``time_maps`` / ``count_maps``
    under the key ``"<name> : <fn.__name__>"``. The running average in seconds
    is then printed.

    Args:
        name: Label used as the first part of the statistics key.

    Returns:
        A decorator. The wrapped function keeps the original's metadata
        (``__name__``, ``__doc__``, ...) and returns its result unchanged.
        If a call raises, the exception propagates and the call is not
        recorded.
    """

    def middle(fn):
        # The key depends only on decoration-time values, so build it once.
        key = "%s : %s" % (name, fn.__name__)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            _cuda_sync()
            start = time.perf_counter()  # monotonic, high-resolution clock
            res = fn(*args, **kwargs)
            _cuda_sync()
            time_maps[key] += time.perf_counter() - start
            count_maps[key] += 1
            print("%s takes up %f " % (key, time_maps[key] / count_maps[key]))
            return res

        return wrapper

    return middle
|