| import torch | |
| from time import time | |
| __all__ = ['timer'] | |
| def timer(f, *args, text='', text_size=64, **kwargs): | |
| if isinstance(text, str) and len(text) > 0: | |
| text = text | |
| elif hasattr(f, '__name__'): | |
| text = f.__name__ | |
| elif hasattr(f, '__class__'): | |
| text = f.__class__.__name__ | |
| else: | |
| text = '' | |
| torch.cuda.synchronize() | |
| start = time() | |
| out = f(*args, **kwargs) | |
| torch.cuda.synchronize() | |
| padding = '.' * (text_size - len(text)) | |
| print(f'{text}{padding}: {time() - start:0.3f}s') | |
| return out | |