| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Decorators. |
| | """ |
| |
|
| | import functools |
| | import threading |
| | import time |
| | from typing import Callable |
| | import torch |
| |
|
| | from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank |
| | from common.logger import get_logger |
| |
|
# Module-level logger; used by ``log_on_entry`` and ``log_runtime`` below.
logger = get_logger(__name__)
| |
|
| |
|
def log_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will log the function name at entry.

    The wrapper uses ``functools.wraps``, so the decorated function keeps its
    ``__name__`` / ``__doc__`` for any decorators applied on top of it.
    The logged name is read from the function passed to this decorator.
    When stacking with decorators that do NOT use ``functools.wraps``, apply
    this one innermost so the real name is logged.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper that logs ``"Entering <name>"`` via ``logger.info`` and then
        calls ``func`` with the same arguments, returning its result.
    """

    @functools.wraps(func)
    def log_on_entry_wrapper(*args, **kwargs):
        logger.info(f"Entering {func.__name__}")
        return func(*args, **kwargs)

    return log_on_entry_wrapper
| |
|
| |
|
def barrier_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will start executing when all ranks are ready to enter.

    Each call first invokes ``barrier_if_distributed()`` and only then runs
    ``func``. The wrapper uses ``functools.wraps``, so the decorated function
    keeps its name and docstring.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper that synchronizes before calling ``func`` and returns its result.
    """

    @functools.wraps(func)
    def barrier_on_entry_wrapper(*args, **kwargs):
        barrier_if_distributed()
        return func(*args, **kwargs)

    return barrier_on_entry_wrapper
| |
|
| |
|
| | def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable: |
| | """ |
| | Helper function for local_rank_zero_only and global_rank_zero_only. |
| | """ |
| |
|
| | def conditional_execute_wrapper(*args, **kwargs): |
| | |
| | result = func(*args, **kwargs) if execute else None |
| | |
| | barrier_if_distributed() |
| | |
| | return result |
| |
|
| | return conditional_execute_wrapper |
| |
|
| |
|
| | def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable: |
| | """ |
| | Helper function for some functions with special constraints, |
| | especially functions called by other global_rank_zero_only / local_rank_zero_only ones, |
| | in case they are wrongly invoked in other scenarios. |
| | """ |
| |
|
| | def asserted_execute_wrapper(*args, **kwargs): |
| | assert condition, err_msg |
| | result = func(*args, **kwargs) |
| | return result |
| |
|
| | return asserted_execute_wrapper |
| |
|
| |
|
def local_rank_zero_only(func: Callable) -> Callable:
    """
    Decorate ``func`` so it only runs on the process whose local rank is zero.

    The rank check is done once, when the decorator is applied. Processes that
    skip the call get ``None`` back. Every process still goes through the
    barrier added by ``_conditional_execute_wrapper_factory``.
    """
    is_local_rank_zero = get_local_rank() == 0
    return _conditional_execute_wrapper_factory(is_local_rank_zero, func)
| |
|
| |
|
def global_rank_zero_only(func: Callable) -> Callable:
    """
    Decorate ``func`` so it only runs on the process whose global rank is zero.

    The rank check is done once, when the decorator is applied. Processes that
    skip the call get ``None`` back. Every process still goes through the
    barrier added by ``_conditional_execute_wrapper_factory``.
    """
    is_global_rank_zero = get_global_rank() == 0
    return _conditional_execute_wrapper_factory(is_global_rank_zero, func)
| |
|
| |
|
def assert_only_global_rank_zero(func: Callable) -> Callable:
    """
    Guard ``func`` so that only the process with global rank zero may call it.

    The rank is checked once, at decoration time. Callers on any other rank
    hit an ``AssertionError`` when they invoke the decorated function.
    """
    message = "Not accessible to processes with global_rank != 0"
    on_global_rank_zero = get_global_rank() == 0
    return _asserted_wrapper_factory(on_global_rank_zero, func, err_msg=message)
| |
|
| |
|
def assert_only_local_rank_zero(func: Callable) -> Callable:
    """
    Guard ``func`` so that only the process with local rank zero may call it.

    The rank is checked once, at decoration time. Callers on any other local
    rank hit an ``AssertionError`` when they invoke the decorated function.
    """
    message = "Not accessible to processes with local_rank != 0"
    on_local_rank_zero = get_local_rank() == 0
    return _asserted_wrapper_factory(on_local_rank_zero, func, err_msg=message)
| |
|
| |
|
def new_thread(func: Callable) -> Callable:
    """
    Functions with this decorator will run in a new thread.

    The wrapper uses ``functools.wraps``, so the decorated function keeps its
    name and docstring.

    Args:
        func: The function to run in the background.

    Returns:
        A wrapper that starts a ``threading.Thread`` running
        ``func(*args, **kwargs)`` and returns the started thread. Call
        ``.join()`` on it to wait for completion. ``func``'s return value is
        not propagated.
    """

    @functools.wraps(func)
    def new_thread_wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return new_thread_wrapper
| |
|
| |
|
def log_runtime(func: Callable) -> Callable:
    """
    Functions with this decorator will log their runtime.

    Ranks are synchronized with ``barrier_if_distributed()`` both before
    starting the timer and after ``func`` returns. The logged duration
    therefore covers the slowest rank, measured with ``time.perf_counter``.
    The message is written via ``logger.info``.

    Args:
        func: The function to time.

    Returns:
        A ``functools.wraps``-preserving wrapper that returns ``func``'s result.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # Use the same guarded barrier as the rest of this module. A bare
        # ``torch.distributed.barrier()`` raises when no process group exists.
        barrier_if_distributed()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        barrier_if_distributed()
        logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
        return result

    return wrapped
| |
|