File size: 557 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from time import perf_counter
from time import time

import torch
__all__ = ['timer']


def timer(f, *args, text='', text_size=64, **kwargs):
    """Call ``f(*args, **kwargs)``, print how long it took, and return its result.

    When CUDA is available the device is synchronized right before the clock
    starts and right before it stops, so asynchronously launched GPU work is
    counted in the measurement rather than only the time to enqueue it. On
    machines without CUDA the synchronization is skipped and plain host time
    is measured.

    The printed line has the form ``<label><dots>: <seconds>s``, where the
    label is padded with ``'.'`` up to ``text_size`` characters.

    Args:
        f: Callable to time.
        *args: Positional arguments forwarded to ``f``.
        text: Label for the printed line. If it is not a non-empty string,
            the label falls back to ``f.__name__`` or, when ``f`` has no
            ``__name__`` (e.g. a ``functools.partial``), to the name of
            ``f``'s class.
        text_size: Width the label is padded to; labels longer than this are
            printed without padding.
        **kwargs: Keyword arguments forwarded to ``f``. ``text`` and
            ``text_size`` are consumed by ``timer`` and never reach ``f``.

    Returns:
        Whatever ``f`` returns.
    """
    if not (isinstance(text, str) and text):
        # Every object has ``__class__``, so this fallback always yields a label.
        text = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__

    # torch.cuda.synchronize() raises on CPU-only machines / non-CUDA builds,
    # so only synchronize when a CUDA device can actually be used.
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is not meant for measuring durations.
    start = perf_counter()
    out = f(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()
    # Capture the elapsed time before doing any formatting work.
    elapsed = perf_counter() - start

    padding = '.' * (text_size - len(text))  # empty when the label is too long
    print(f'{text}{padding}: {elapsed:0.3f}s')
    return out
|