shunliwang
update
8bc3305
import os
import logging
import torch.distributed as dist
class Rank0Filter(logging.Filter):
    """Logging filter that lets records through only on the rank-0 process.

    When a ``torch.distributed`` default process group is initialized, its
    rank is used. Otherwise the rank is read from the ``RANK`` environment
    variable (as exported by launchers such as ``torchrun``), defaulting to 0.
    This means a plain single-process run still logs, instead of failing
    inside ``dist.get_rank()``.
    """

    def filter(self, record):
        """Return True if ``record`` should be emitted on this process."""
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            # get_rank() cannot be used without a default process group.
            rank = int(os.environ.get("RANK", 0))
        return rank == 0
def create_logger(log_path):
    """Configure the root logger so that only the rank-0 process emits output.

    Rank 0 gets a file handler writing to ``log_path`` plus a console
    (stderr) handler, both using the same timestamped format. Every other
    rank gets only a ``NullHandler``, so its records are discarded.

    Args:
        log_path: Path of the log file written by rank 0. Missing parent
            directories are created.

    Returns:
        The root logger, set to INFO level, with ``log_path`` stored on it
        as the ``log_path`` attribute.

    Raises:
        RuntimeError: If ``torch.distributed`` has not been initialized.
    """
    # dist.get_rank() needs an initialized default process group.
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed must be initialized before creating the logger!")
    rank = dist.get_rank()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Detach AND close existing handlers. This avoids duplicate output, and it
    # avoids leaking the file descriptor of a FileHandler left over from an
    # earlier call. Iterate over a copy because removeHandler mutates the list.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.log_path = log_path

    if rank == 0:
        log_dir = os.path.dirname(log_path)
        if log_dir:
            # exist_ok=True already tolerates an existing directory.
            os.makedirs(log_dir, exist_ok=True)

        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        # File output first, then console, matching the original handler order.
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    else:
        # Non-zero ranks: a NullHandler swallows all records.
        logger.addHandler(logging.NullHandler())
    return logger