# File size: 1,601 Bytes
# 8bc3305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import logging
import torch.distributed as dist

class Rank0Filter(logging.Filter):
    """Logging filter that lets records through only on the rank-0 process.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized (for example in a single-process run), the process is treated
    as rank 0 and records pass. Without this guard ``dist.get_rank()`` would
    raise from inside the logging machinery on every log call.
    """

    def filter(self, record):
        """Decide whether *record* may be emitted by this process.

        Args:
            record: The ``logging.LogRecord`` being considered. Its content
                is not inspected; only the process rank matters.

        Returns:
            bool: True on rank 0 or outside an initialized process group,
            False on every other rank.
        """
        # No process group means there is no rank to query; behave like the
        # sole (rank-0) process instead of raising.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

def create_logger(log_path):
    """Configure the root logger so that only the rank-0 process emits logs.

    On rank 0 the root logger gets a ``FileHandler`` writing to ``log_path``
    and a ``StreamHandler`` (stderr by default), both formatted as
    ``time - level - message``. On every other rank it gets a single
    ``NullHandler``, so records are discarded.

    Any handlers previously attached to the root logger are removed and
    closed first, so calling this function repeatedly neither duplicates
    output nor leaks open log files.

    Args:
        log_path: Path of the log file. Its parent directory is created on
            rank 0 when it does not exist. The path is also stored on the
            returned logger as ``logger.log_path`` on every rank.

    Returns:
        logging.Logger: The root logger, set to ``INFO`` level.

    Raises:
        RuntimeError: If ``torch.distributed`` has not been initialized.
    """
    # get_rank() needs an initialized process group; fail early with a clear
    # message instead.
    if not dist.is_initialized():
        raise RuntimeError("torch.distributed must be initialized before creating the logger!")

    rank = dist.get_rank()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close old handlers. Merely resetting the list would leave
    # any previous FileHandler's underlying file open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    logger.log_path = log_path

    if rank != 0:
        # Non-rank-0 processes: attach a null handler so nothing is output.
        logger.addHandler(logging.NullHandler())
        return logger

    # Only rank 0 creates the log directory, the log file and console output.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        # exist_ok=True already covers the "directory exists" case.
        os.makedirs(log_dir, exist_ok=True)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler (writes to the log file)
    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Console handler (prints to the terminal)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger