import logging
import os

import torch.distributed as dist

# Default level for loggers returned by ``get_logger`` (default of its ``level`` argument).
LOG_LEVEL: int = logging.INFO
# Logger level ``get_logger`` applies on processes whose torch.distributed rank
# is not 0; with the default ERROR only error-and-above records get through there.
SUBPROCESS_LOG_LEVEL: int = logging.ERROR
# Format *string* (not a ``logging.Formatter`` instance) that ``get_logger``
# wraps in a Formatter for the handlers it creates.
LOG_FORMATTER: str = '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s'
def get_logger(name, level=LOG_LEVEL, log_file=None, file_mode='w'):
    """Return the logger ``name``, configured for single- or multi-process runs.

    Rank 0 covers two cases: the process whose ``torch.distributed`` rank is 0,
    and any process where ``torch.distributed`` is unavailable or uninitialised.
    On rank 0 the logger level is ``level``. If ``log_file`` is given, records
    are also written to that file. On every other rank the logger level is
    ``SUBPROCESS_LOG_LEVEL``.

    Every rank gets a console ``StreamHandler`` at ``level``. The function can
    safely be called repeatedly for the same ``name``. Handlers it created on
    an earlier call are reused and get their level updated, so no duplicates
    are stacked. Duplicates would emit each record several times, and with
    ``file_mode='w'`` they would truncate a file that is still being written.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level for the logger on rank 0 and for the handlers set up here.
        log_file: Optional path of a log file; only used on rank 0.
        file_mode: Mode used when opening ``log_file`` (``'w'`` truncates).

    Returns:
        logging.Logger: The configured logger.
    """
    # Attribute set on handlers this function creates, so later calls can find them.
    marker = '_added_by_get_logger'

    formatter = logging.Formatter(LOG_FORMATTER)
    logger = logging.getLogger(name)

    # Raise plain console handlers on the root logger to ERROR. Presumably this
    # stops records that propagate up from this logger being printed twice
    # below ERROR. ``type(...) is`` deliberately skips StreamHandler subclasses
    # (e.g. FileHandler).
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    if rank == 0 and log_file is not None:
        # FileHandler stores the absolute path in ``baseFilename``; compare on that.
        file_path = os.path.abspath(log_file)
        file_handler = next(
            (h for h in logger.handlers
             if isinstance(h, logging.FileHandler)
             and getattr(h, marker, False)
             and h.baseFilename == file_path),
            None,
        )
        if file_handler is None:
            file_handler = logging.FileHandler(log_file, file_mode)
            file_handler.setFormatter(formatter)
            setattr(file_handler, marker, True)
            logger.addHandler(file_handler)
        file_handler.setLevel(level)

    if rank == 0:
        logger.setLevel(level)
    else:
        logger.setLevel(SUBPROCESS_LOG_LEVEL)

    # Reuse the console handler from a previous call rather than adding another.
    stream_handler = next(
        (h for h in logger.handlers
         if type(h) is logging.StreamHandler and getattr(h, marker, False)),
        None,
    )
    if stream_handler is None:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        setattr(stream_handler, marker, True)
        logger.addHandler(stream_handler)
    stream_handler.setLevel(level)

    return logger