|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect
|
|
|
import logging
|
|
|
import os
|
|
|
import warnings
|
|
|
from typing import Optional
|
|
|
|
|
|
from torch.distributed.elastic.utils.log_level import get_log_level
|
|
|
|
|
|
|
|
|
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return a logger whose level is configured from the environment.

    The level is read from the ``LOGLEVEL`` environment variable, with the
    value of ``get_log_level()`` used when that variable is not set.

    Args:
        name: Name of the logger to return. When omitted (or empty), the
            name of the module that called ``get_logger`` is derived from
            the call stack and used instead.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    if not name:
        # depth=2 skips the frame of _derive_module_name itself (0) and the
        # frame of get_logger (1), landing on get_logger's caller.
        name = _derive_module_name(depth=2)
    return _setup_logger(name)
|
|
|
|
|
|
|
|
|
def _setup_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Fetch the logger called ``name`` and set its level from the environment.

    The level comes from the ``LOGLEVEL`` environment variable when it holds a
    non-empty value; otherwise ``get_log_level()`` supplies the default. The
    environment value is upper-cased so that ``LOGLEVEL=info`` behaves the
    same as ``LOGLEVEL=INFO`` (``Logger.setLevel`` only recognises the
    upper-case level names and raises for anything else).

    Args:
        name: Logger name; ``None`` selects the root logger.

    Returns:
        The ``logging.Logger`` with its level set.

    Raises:
        ValueError: If ``LOGLEVEL`` names a level unknown to ``logging``.
    """
    logger = logging.getLogger(name)
    env_level = os.environ.get("LOGLEVEL", "").strip()
    # An unset or blank LOGLEVEL falls back to the default provider, which is
    # only consulted when the environment does not override it.
    level = env_level.upper() if env_level else get_log_level()
    logger.setLevel(level)
    return logger
|
|
|
|
|
|
|
|
|
def _derive_module_name(depth: int = 1) -> Optional[str]:
    """
    Derive the name of a calling module by inspecting the call stack.

    Args:
        depth: How many frames to walk up from this function's own frame.
            ``0`` refers to ``_derive_module_name`` itself, ``1`` to its
            direct caller, ``2`` to the caller's caller, and so on. Must be
            non-negative.

    Returns:
        The ``__name__`` of the module that owns the selected frame. If no
        module object can be resolved for that frame, the base name of the
        frame's source file without its extension is returned instead. If the
        name cannot be derived at all (invalid ``depth``, stack too shallow,
        or frame introspection unavailable), a ``RuntimeWarning`` is issued
        and ``None`` is returned.
    """
    frame = None
    try:
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, and a negative index would silently pick the
        # outermost frame.
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")

        # Walk ``f_back`` directly rather than calling ``inspect.stack()``,
        # which materialises every frame on the stack and loads source
        # context lines for each one even though only one frame is needed.
        frame = inspect.currentframe()
        for _ in range(depth):
            if frame is None:
                break
            frame = frame.f_back
        if frame is None:
            # Stack shallower than ``depth``, or the interpreter does not
            # support frame introspection (currentframe() may return None).
            raise ValueError(f"no stack frame at depth {depth}")

        module = inspect.getmodule(frame)
        if module:
            return module.__name__

        # No module object (e.g. code not backed by an imported module):
        # fall back to the stem of the frame's source file name.
        filename = frame.f_code.co_filename
        return os.path.splitext(os.path.basename(filename))[0]
    except Exception as e:
        warnings.warn(
            f"Error deriving logger module name, using <None>. Exception: {e}",
            RuntimeWarning,
        )
        return None
    finally:
        # Frames reference their locals (including ``frame`` itself), so
        # drop the reference explicitly to avoid a reference cycle, as the
        # ``inspect`` documentation recommends.
        del frame
|
|
|
|