import logging

import torch.distributed as dist


def setup_root_logger(level: int = logging.INFO) -> None:
    """Configure the root logger so that only the rank 0 process emits output.

    This allows plain ``logging.info()`` calls to be used throughout the code
    without every process of a distributed job printing duplicate lines.

    Any handlers already attached to the root logger are removed first.

    Args:
        level: Logging level applied to the root logger on rank 0.
            Defaults to ``logging.INFO``.

    On rank 0 a ``StreamHandler`` with a ``time | level | message`` format is
    installed. On every other rank the root level is raised above
    ``CRITICAL`` and a ``NullHandler`` is attached, so nothing is emitted.
    """
    # Rank defaults to 0 when torch.distributed is unavailable or has not been
    # initialized (e.g. a single-process run).
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()

    root = logging.getLogger()

    # Iterate over a snapshot: removing from ``root.handlers`` while iterating
    # it directly skips every other handler and leaves stale ones attached.
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)

    root.setLevel(level)

    if rank == 0:
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            '%(asctime)s | %(levelname)s | %(message)s',
            datefmt='%m-%d,%H:%M',
        )
        handler.setFormatter(formatter)
        root.addHandler(handler)
    else:
        # Higher than any standard level, so root-level records are dropped.
        root.setLevel(logging.CRITICAL + 1)
        # Child loggers with an explicit level of their own still propagate
        # records to the root. With no handler anywhere in the hierarchy,
        # logging falls back to ``logging.lastResort`` and prints WARNING+
        # to stderr. A NullHandler prevents that fallback on non-zero ranks.
        root.addHandler(logging.NullHandler())