File size: 1,100 Bytes
import logging
import torch.distributed as dist
def setup_root_logger(level: int = logging.INFO) -> None:
    """
    Configure the root logger so that only the rank 0 process emits output.

    This allows using simple logging.info() calls throughout the code
    without every distributed worker printing the same message.

    Args:
        level: Logging level applied to the root logger on rank 0.
            Other ranks are set above ``logging.CRITICAL`` instead.
    """
    # Get the rank (default to 0 if distributed is unavailable or not
    # initialized, e.g. single-process runs).
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()

    root = logging.getLogger()

    # Clear any existing handlers. Iterate over a snapshot: removeHandler()
    # mutates root.handlers, and removing while iterating the live list
    # skips every other handler, leaving stale ones attached.
    for existing in list(root.handlers):
        root.removeHandler(existing)

    if rank != 0:
        # No handler is attached on non-zero ranks, so raise the level above
        # every standard level to suppress output entirely.
        root.setLevel(logging.CRITICAL + 1)
        return

    # Rank 0: honor the requested level and log to the default stream.
    root.setLevel(level)
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            '%(asctime)s | %(levelname)s | %(message)s',
            datefmt='%m-%d,%H:%M',
        )
    )
    root.addHandler(handler)
|