"""Timing decorator that reports the running average run time of a function."""

import functools
import time
from collections import defaultdict

import torch

# Accumulated elapsed seconds per "<name> : <function name>" key.
time_maps = defaultdict(float)
# Number of completed calls per key (kept as float, starting at 0.0).
count_maps = defaultdict(float)


def _cuda_sync():
    """Wait for pending CUDA work to finish; do nothing when CUDA is unavailable."""
    # Unconditional synchronize() raises on hosts without CUDA, which would
    # make every decorated function unusable there.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def run_time(name):
    """Build a decorator that times each call and prints the running average.

    Every completed call adds its elapsed time to ``time_maps`` and increments
    ``count_maps`` under the key ``"<name> : <function name>"``, then prints
    the average elapsed seconds over all calls recorded for that key.
    If the wrapped function raises, nothing is recorded or printed.

    Args:
        name: Label combined with the function's ``__name__`` to form the key.

    Returns:
        A decorator; the decorated function returns the wrapped function's
        result unchanged.
    """

    def middle(fn):
        # Same key format as before, built once instead of four times per call.
        key = "%s : %s" % (name, fn.__name__)

        @functools.wraps(fn)  # keep __name__/__doc__ of the wrapped function
        def wrapper(*args, **kwargs):
            # Synchronize before starting the clock so previously queued GPU
            # work is not attributed to this call.
            _cuda_sync()
            start = time.perf_counter()
            res = fn(*args, **kwargs)
            # Synchronize again so GPU work launched by fn is included.
            _cuda_sync()
            time_maps[key] += time.perf_counter() - start
            count_maps[key] += 1
            print("%s takes up %f " % (key, time_maps[key] / count_maps[key]))
            return res

        return wrapper

    return middle