| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | import inspect |
| | import logging |
| | import os |
| | import warnings |
| | from typing import Optional |
| |
|
| | from torch.distributed.elastic.utils.log_level import get_log_level |
| |
|
| |
|
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return a logger whose level has been configured by ``_setup_logger``.

    The level is read from the ``LOGLEVEL`` environment variable when it is
    set. Otherwise it comes from ``get_log_level()`` (presumably WARNING,
    per the original documentation; confirm in ``log_level``).

    Args:
        name: Name of the logger. If empty or omitted, the name is derived
            from the call stack, i.e. the module that called this function.
            If that derivation fails, ``None`` is passed on, which selects
            the root logger.
    """
    logger_name = name
    if not logger_name:
        # depth=2 skips _derive_module_name's own frame and this function's
        # frame, landing on whoever called get_logger().
        logger_name = _derive_module_name(depth=2)
    return _setup_logger(logger_name)
| |
|
| |
|
def _setup_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Fetch the logger registered under ``name`` and set its level.

    The level comes from the ``LOGLEVEL`` environment variable when present.
    Otherwise it comes from the value returned by ``get_log_level()``, which
    is evaluated on every call whether or not ``LOGLEVEL`` is set.

    NOTE(review): ``Logger.setLevel`` accepts only registered level names
    for string values, so e.g. ``LOGLEVEL=debug`` (lowercase) raises
    ``ValueError``.

    Args:
        name: Logger name; ``None`` selects the root logger.

    Returns:
        The configured ``logging.Logger``.
    """
    named_logger = logging.getLogger(name)
    default_level = get_log_level()
    chosen_level = os.environ.get("LOGLEVEL", default_level)
    named_logger.setLevel(chosen_level)
    return named_logger
| |
|
| |
|
def _derive_module_name(depth: int = 1) -> Optional[str]:
    """
    Derives the name of the caller module from the stack frames.

    Walks ``depth`` frames outward from this function's own frame.
    ``depth=0`` is this function, ``depth=1`` is its caller, and so on.
    Returns the ``__name__`` of the module that frame belongs to. If the
    module cannot be resolved, returns the frame's source file name without
    its extension instead.

    This is best-effort. On any failure, such as a negative ``depth`` or one
    deeper than the current stack, a ``RuntimeWarning`` is emitted and
    ``None`` is returned.

    Args:
        depth: The position of the frame in the stack, counted from this
            function's own frame.

    Returns:
        The derived module name, or ``None`` if it could not be derived.
    """
    frame = None
    try:
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        # Follow frame links directly instead of calling inspect.stack().
        # inspect.stack() builds a FrameInfo for *every* frame on the stack,
        # including source-context lines read from disk, while only a single
        # frame is needed here.
        frame = inspect.currentframe()
        for _ in range(depth):
            if frame is None:
                break
            frame = frame.f_back
        if frame is None:
            # An explicit check rather than `assert`, which is stripped
            # under `python -O`.
            raise ValueError(f"call stack is not deep enough for depth={depth}")

        module = inspect.getmodule(frame)
        if module is not None:
            return module.__name__
        # getmodule() could not resolve a module; fall back to the source
        # file's base name without its extension.
        filename = frame.f_code.co_filename
        return os.path.splitext(os.path.basename(filename))[0]
    except Exception as e:
        # Deliberate best-effort: a logger name is not worth crashing over.
        warnings.warn(
            f"Error deriving logger module name, using <None>. Exception: {e}",
            RuntimeWarning,
        )
        return None
    finally:
        # Release the frame reference to avoid keeping a reference cycle
        # alive (see the inspect docs' note on the frame interpreter stack).
        del frame
| |
|