#!/usr/bin/env python
# -*- coding:utf-8 _*-
"""Logging helpers for (possibly multi-process) runs.

Provides a cached logger factory plus helpers that only emit output on the
process whose ``LOCAL_RANK`` environment variable is unset or ``'0'``.
"""
import os
import logging
import sys
import typing

# -------- log setting ---------
DEFAULT_LOGGER = "time_moe_logger"
DEFAULT_FORMATTER = logging.Formatter(
    '%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s] - %(levelname)s: %(message)s'
)
_ch = logging.StreamHandler(stream=sys.stdout)
_ch.setFormatter(DEFAULT_FORMATTER)
_DEFAULT_HANDLERS = [_ch]

# Loggers already configured by get_logger(), keyed by logger name.
_LOGGER_CACHE = {}  # type: typing.Dict[str, logging.Logger]


def is_local_rank_0():
    """Return True if this process is local rank 0.

    A process counts as local rank 0 when the ``LOCAL_RANK`` environment
    variable is either unset or equal to the string ``'0'``.
    """
    local_rank = os.getenv('LOCAL_RANK')
    # Unset LOCAL_RANK is treated the same as rank 0 (e.g. a single-process launch).
    return local_rank is None or local_rank == '0'


def get_logger(name, level="INFO", handlers=None, update=False):
    """Return a configured, non-propagating logger, cached by name.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level passed to ``Logger.setLevel`` (name string or int).
        handlers: Handlers to install; falls back to the module default
            stdout handler when falsy.
        update: If True, reconfigure the logger even if it is already cached.

    Returns:
        The configured ``logging.Logger``.
    """
    if name in _LOGGER_CACHE and not update:
        return _LOGGER_CACHE[name]

    new_logger = logging.getLogger(name)
    new_logger.setLevel(level)
    # Copy the list so addHandler()/removeHandler() on this logger cannot
    # mutate the shared default list (and thus every other default logger).
    new_logger.handlers = list(handlers or _DEFAULT_HANDLERS)
    new_logger.propagate = False
    # Store the result so later calls with update=False return it unchanged.
    _LOGGER_CACHE[name] = new_logger
    return new_logger


def log_in_local_rank_0(*msg, type='info', used_logger=None):
    """Log a message only on local rank 0.

    Args:
        *msg: Objects to log. Each is converted with ``str`` and the
            results are joined by a single space.
        type: Level selector. ``'warn'``/``'warning'`` log at WARNING,
            ``'error'`` logs at ERROR, and anything else logs at INFO.
            The name is kept for backward compatibility; it shadows the
            builtin ``type`` inside this function.
        used_logger: Logger to use; defaults to the module-level ``logger``.
    """
    if not is_local_rank_0():
        # Skip building the message entirely on ranks that never log.
        return
    text = ' '.join(str(s) for s in msg)
    if used_logger is None:
        used_logger = logger
    if type in ('warn', 'warning'):
        used_logger.warning(text)
    elif type == 'error':
        used_logger.error(text)
    else:
        used_logger.info(text)


def adaptive_print(*args, **kwargs):
    """A ``print`` replacement that only prints in the main process (local rank 0)."""
    if is_local_rank_0():
        print(*args, **kwargs)


# -------------------------- Singleton Object --------------------------
logger = get_logger(DEFAULT_LOGGER)