import os
import logging

import torch.distributed as dist


class Rank0Filter(logging.Filter):
    """Logging filter that lets records through only on the rank-0 process.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, the process is treated as a single-process run and every
    record passes.
    """

    def filter(self, record):
        """Return True when *record* should be emitted by this process."""
        # dist.get_rank() raises without an initialized default process group;
        # an exception here would propagate into the caller's logging call,
        # so guard against it instead of crashing the log statement.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0


def create_logger(log_path: str) -> logging.Logger:
    """Configure the root logger so that only rank 0 produces output.

    Rank 0 gets a file handler writing to *log_path* (its parent directory is
    created if needed) and a console handler. Every other rank gets a single
    ``NullHandler`` so its records are discarded.

    Args:
        log_path: Path of the log file written by rank 0. Also stored on the
            returned logger as the ``log_path`` attribute.

    Returns:
        The root logger, with its level set to ``logging.INFO``.

    Raises:
        RuntimeError: If ``torch.distributed`` has not been initialized.
    """
    # The rank is required to decide which handlers to install, so refuse to
    # continue without an initialized process group. is_available() is checked
    # first because is_initialized() may be missing on builds without
    # distributed support.
    if not (dist.is_available() and dist.is_initialized()):
        raise RuntimeError("torch.distributed must be initialized before creating the logger!")

    rank = dist.get_rank()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove existing handlers to avoid duplicate output. Close each one so a
    # repeated call does not leak the previously opened log file descriptor.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    logger.log_path = log_path

    if rank != 0:
        # Non-rank-0 processes: a null handler swallows every record.
        logger.addHandler(logging.NullHandler())
        return logger

    # Only rank 0 creates the log directory, the log file and console output.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        # exist_ok=True already tolerates an existing directory, so no
        # separate (race-prone) existence check is needed.
        os.makedirs(log_dir, exist_ok=True)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler (writes to the log file).
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Console handler (prints to the terminal).
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    return logger