# File size: 1,601 Bytes
# 8bc3305
import os
import logging
import torch.distributed as dist
class Rank0Filter(logging.Filter):
    """Logging filter that lets records through only on the rank-0 process.

    When ``torch.distributed`` is unavailable or its default process group
    has not been initialized, there is no rank to query; the filter then
    lets every record through instead of raising from inside a logging call.
    """

    def filter(self, record):
        """Return True if *record* should be emitted by this process.

        Args:
            record: The ``logging.LogRecord`` being considered. It is not
                inspected; the decision depends only on the process rank.

        Returns:
            True on rank 0 (or when no process group is initialized),
            False on every other rank.
        """
        # dist.get_rank() raises when the default process group does not
        # exist, which would turn every log call into an exception.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0
def create_logger(log_path):
    """Configure the root logger so that only the rank-0 process emits logs.

    On rank 0 the root logger gets a file handler writing to *log_path*
    (its parent directory is created if needed) and a console handler, both
    using the ``'%(asctime)s - %(levelname)s - %(message)s'`` format. On
    every other rank the root logger only gets a ``NullHandler``.

    Any handlers previously attached to the root logger are removed and
    closed first, so calling this function repeatedly neither duplicates
    output nor leaks open log files.

    Args:
        log_path: Path of the log file written by rank 0. It is also stored
            on the returned logger as the ``log_path`` attribute.

    Returns:
        The root ``logging.Logger``, with its level set to ``INFO``.

    Raises:
        RuntimeError: If ``torch.distributed`` has not been initialized.
    """
    # The rank is needed below; fail early with a clear message rather than
    # letting get_rank() fail on a missing process group.
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed must be initialized before creating the logger!")

    rank = dist.get_rank()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close existing handlers: simply rebinding logger.handlers
    # would leave earlier FileHandlers holding their files open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    logger.log_path = log_path

    if rank == 0:
        # Only rank 0 touches the filesystem and the terminal.
        log_dir = os.path.dirname(log_path)
        if log_dir:
            # exist_ok=True already tolerates an existing directory.
            os.makedirs(log_dir, exist_ok=True)

        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        # File handler (writes to the log file)
        fh = logging.FileHandler(log_path)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        # Console handler (prints to the terminal)
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    else:
        # Non-rank-0 processes: a NullHandler discards every record.
        logger.addHandler(logging.NullHandler())

    return logger