import logging
from contextlib import contextmanager
from functools import wraps
from time import perf_counter
from time import time

import torch.distributed

from .misc import SingletonMeta
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "Timer",
    "timer",
]

# Module-level logger; Timer emits its start/end messages through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
class Timer(metaclass=SingletonMeta):
    """Accumulate elapsed time (in seconds) per timer name.

    ``start(name)`` / ``end(name)`` bracket a timed region; ``end`` adds the
    elapsed seconds to ``timers[name]``, so repeated regions sharing a name
    accumulate into one total. ``context(name)`` wraps the pair as a ``with``
    block. When ``torch.distributed`` is initialized, rank 0 logs every
    start and end.

    NOTE(review): ``SingletonMeta`` presumably makes ``Timer()`` return one
    shared instance, so state is shared across call sites -- confirm in
    ``.misc``.
    """

    def __init__(self):
        # name -> total elapsed seconds accumulated by completed start/end pairs.
        self.timers = {}
        # name -> perf_counter() reading taken when a currently running timer started.
        self.start_time = {}

    @staticmethod
    def _should_log():
        """Return True only on rank 0 of an initialized process group."""
        return torch.distributed.is_initialized() and torch.distributed.get_rank() == 0

    def start(self, name):
        """Start timer ``name``.

        Raises:
            AssertionError: if ``name`` is already running.
        """
        assert name not in self.start_time, f"Timer {name} already started."
        # perf_counter is monotonic: unlike time(), it cannot jump backwards when
        # the system clock is adjusted, so elapsed times are never negative.
        self.start_time[name] = perf_counter()
        if self._should_log():
            logger.info("Timer %s start", name)

    def end(self, name):
        """Stop timer ``name`` and add its elapsed seconds to ``timers[name]``.

        Raises:
            AssertionError: if ``name`` was not started.
        """
        assert name in self.start_time, f"Timer {name} not started."
        elapsed_time = perf_counter() - self.start_time.pop(name)
        self.add(name, elapsed_time)
        if self._should_log():
            logger.info("Timer %s end (elapsed: %.1fs)", name, elapsed_time)

    def reset(self, name=None):
        """Clear accumulated totals: all of them, or only ``name`` if given.

        Running timers (``start_time``) are deliberately left alone, so calling
        this inside an active ``context`` does not break the matching ``end``.
        Resetting an unknown name is a no-op.
        """
        if name is None:
            self.timers = {}
        else:
            self.timers.pop(name, None)

    def add(self, name, elapsed_time):
        """Add ``elapsed_time`` seconds to the total for ``name`` (starting from 0)."""
        self.timers[name] = self.timers.get(name, 0) + elapsed_time

    def log_dict(self):
        """Return a snapshot ``{name: total_seconds}`` of the accumulated totals.

        A copy is returned so callers cannot mutate the timer's internal state.
        """
        return dict(self.timers)

    @contextmanager
    def context(self, name):
        """Time the enclosed ``with`` body under ``name``.

        The timer is stopped even if the body raises.
        """
        self.start(name)
        try:
            yield
        finally:
            self.end(name)
|
|
|
|
def timer(name_or_func):
    """Time a named code region, or every call of a function, with ``Timer``.

    With a string argument this returns ``Timer().context(name)`` for use as a
    context manager::

        with timer("block_name"):
            ...

    Any other argument is treated as the function to decorate; each call is
    then timed under the function's ``__name__``::

        @timer
        def func():
            ...

    Because ``Timer.start`` refuses a name that is already running, a
    decorated function that re-enters itself fails its "already started"
    check.
    """
    if not isinstance(name_or_func, str):
        target = name_or_func

        @wraps(target)
        def timed_call(*args, **kwargs):
            # The timer name is looked up at call time, as in a plain closure.
            with Timer().context(target.__name__):
                return target(*args, **kwargs)

        return timed_call

    return Timer().context(name_or_func)
|
|
|
|
@contextmanager
def inverse_timer(name):
    """Pause the running timer ``name`` for the duration of the ``with`` body.

    On entry the timer is ended, which banks the time elapsed so far into its
    total. On exit the timer is started again, even if the body raised. The
    timer must already be running when this is entered, or ``Timer.end``
    fails.
    """
    # Stop and bank the time measured up to this point.
    Timer().end(name)
    try:
        yield
    finally:
        # Resume timing regardless of how the body exited.
        Timer().start(name)
|
|
|
|
def with_defer(deferred_func):
    """Build a decorator that calls ``deferred_func()`` after every call.

    ``deferred_func`` (taking no arguments) runs once the wrapped function
    returns or raises. The wrapped function's result or exception is passed
    through unchanged. If ``deferred_func`` itself raises, that exception
    propagates instead.
    """
    def decorate(target):
        @wraps(target)
        def run_then_defer(*args, **kwargs):
            try:
                outcome = target(*args, **kwargs)
            finally:
                deferred_func()
            return outcome

        return run_then_defer

    return decorate
|
|