| import os |
| import logging |
| import torch.distributed as dist |
|
|
class Rank0Filter(logging.Filter):
    """Logging filter that lets records through only on the rank-0 process."""

    def filter(self, record):
        """Return True if the current process is rank 0, False otherwise.

        When the ``torch.distributed`` process group is initialized, its rank
        is used. Before initialization, ``dist.get_rank()`` would raise. A
        filter that raises propagates out of the caller's ``logger.*`` call,
        so in that case the rank is read from the ``RANK`` environment
        variable instead, defaulting to 0 when it is unset.
        """
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        # Process group not up yet: fall back to the launcher-provided rank
        # (absent => treat as a single, rank-0 process).
        return int(os.environ.get("RANK", "0")) == 0
|
|
def create_logger(log_path):
    """Configure and return the root logger for a torch.distributed job.

    Rank 0 logs at INFO level to both ``log_path`` and the console (stderr via
    ``logging.StreamHandler``). Every other rank gets only a
    ``logging.NullHandler``, so its records are discarded. Any handlers
    already attached to the root logger are removed and closed first.

    Args:
        log_path: File that rank 0 writes log records to. Missing parent
            directories are created on rank 0.

    Returns:
        The root ``logging.Logger``, with a ``log_path`` attribute set to the
        given path.

    Raises:
        RuntimeError: If ``torch.distributed`` has not been initialized.
    """
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed must be initialized before creating the logger!")

    rank = dist.get_rank()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close any previously installed handlers. Just rebinding
    # ``logger.handlers = []`` would leave an earlier FileHandler's file open
    # every time this function is called again.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    logger.log_path = log_path

    if rank == 0:
        log_dir = os.path.dirname(log_path)
        if log_dir:
            # exist_ok makes a separate existence check unnecessary.
            os.makedirs(log_dir, exist_ok=True)

        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    else:
        # Non-zero ranks: a handler is present, so logging's "last resort"
        # stderr output never fires and these records are silently dropped.
        logger.addHandler(logging.NullHandler())

    return logger