| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import functools |
| | import logging |
| | import os |
| | from typing import Optional |
| |
|
| | from .state import PartialState |
| |
|
| |
|
class MultiProcessAdapter(logging.LoggerAdapter):
    """
    An adapter to assist with logging in multiprocess.

    `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
    or only the main executed one. Default is `main_process_only=True`.

    Does not require an `Accelerator` object to be created first, but the accelerate state must already be initialized
    (e.g. by calling `PartialState()` or `Accelerator()`); otherwise `log` raises a `RuntimeError`.
    """

    @staticmethod
    def _should_log(main_process_only):
        "Check if log should be performed: always when `main_process_only` is falsy, else only on the main process."
        state = PartialState()
        return not main_process_only or state.is_main_process

    def log(self, level, msg, *args, **kwargs):
        """
        Delegates logger call after checking if we should log.

        Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
        or only the main executed one. Default is `True` if not passed.

        Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
        read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
        break with the previous behavior. Passing `in_order=True` without `main_process_only` (or together with
        `main_process_only=False`) logs on every process in rank order.

        `in_order` is ignored if `main_process_only=True` is passed explicitly.

        Raises:
            RuntimeError: if the accelerate state (`PartialState`) has not been initialized yet.
        """
        if PartialState._shared_state == {}:
            raise RuntimeError(
                "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
            )
        # Remember whether the caller chose `main_process_only` explicitly: `in_order=True` on its own means
        # "log on every process, in order".
        main_process_only_passed = "main_process_only" in kwargs
        main_process_only = kwargs.pop("main_process_only", True)
        in_order = kwargs.pop("in_order", False)
        if in_order and not main_process_only_passed:
            main_process_only = False

        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        if in_order and not main_process_only:
            # This branch is chosen from the call's kwargs only, so every process takes it and the barriers below
            # stay matched. NOTE(review): this still assumes all processes share the same effective log level,
            # since the `isEnabledFor` gate above is evaluated per process.
            state = PartialState()
            for rank in range(state.num_processes):
                if rank == state.process_index:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                state.wait_for_everyone()
        elif self._should_log(main_process_only):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)

    @functools.lru_cache(None)
    def warning_once(self, *args, **kwargs):
        """
        This method is identical to `logger.warning()`, but will emit the warning with the same message only once

        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
        cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
        switch to another type of cache that includes the caller frame information in the hashing function.
        """
        # NOTE(review): `lru_cache` on a method includes `self` in the key and keeps every adapter instance alive for
        # the lifetime of the (unbounded) cache; all arguments must also be hashable or a `TypeError` is raised.
        self.warning(*args, **kwargs)
| |
|
| |
|
def get_logger(name: str, log_level: Optional[str] = None) -> "MultiProcessAdapter":
    """
    Returns a `MultiProcessAdapter` wrapping the `logging.Logger` for `name`, so that it can handle multiprocessing.

    If a log should be called on all processes, pass `main_process_only=False`. If a log should be called on all
    processes and in order, also pass `in_order=True`.

    Args:
        name (`str`):
            The name for the logger, such as `__file__`
        log_level (`str`, *optional*):
            The log level to use (case-insensitive, e.g. `"debug"`). If not passed, falls back to the
            `ACCELERATE_LOG_LEVEL` environment variable. If that is unset or empty, no level is set and the logger
            keeps its current (possibly inherited) level. Note that when a level is applied it is set on the root
            logger as well.

    Example:

    ```python
    >>> from accelerate.logging import get_logger
    >>> from accelerate import Accelerator

    >>> logger = get_logger(__name__)

    >>> accelerator = Accelerator()
    >>> logger.info("My log", main_process_only=False)
    >>> logger.debug("My log", main_process_only=True)

    >>> logger = get_logger(__name__, log_level="DEBUG")
    >>> logger.info("My log")
    >>> logger.debug("My second log")

    >>> array = ["a", "b", "c", "d"]
    >>> letter_at_rank = array[accelerator.process_index]
    >>> logger.info(letter_at_rank, in_order=True)
    ```
    """
    if log_level is None:
        # Treat an empty / whitespace-only variable as unset; passing "" to `setLevel` raises `ValueError`.
        log_level = os.environ.get("ACCELERATE_LOG_LEVEL", "").strip() or None
    logger = logging.getLogger(name)
    if log_level is not None:
        log_level = log_level.upper()
        logger.setLevel(log_level)
        # Also changes the root logger's level, which affects every logger that has no explicit level of its own.
        logger.root.setLevel(log_level)
    return MultiProcessAdapter(logger, {})
| |
|