File size: 2,013 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import sys

from .comm import get_rank


_default_logger = None


def __init_logger():
    """Set up the module-wide ``'default'`` logger on the master process.

    When ``get_rank()`` is 0, this does the following:

    * fetches the logger named ``'default'`` and sets it to DEBUG;
    * attaches a DEBUG-level stdout handler if the logger has no
      ``StreamHandler`` yet;
    * publishes the logger as ``_default_logger``.

    On any other rank it returns immediately and leaves ``_default_logger``
    unchanged.
    """
    global _default_logger
    if get_rank() != 0:
        return

    default = logging.getLogger('default')
    default.setLevel(logging.DEBUG)

    # logging.FileHandler is a StreamHandler subclass, so an attached file
    # handler also counts as "already has a stream handler" here.
    already_streaming = any(isinstance(h, logging.StreamHandler) for h in default.handlers)
    if not already_streaming:
        console = logging.StreamHandler(stream=sys.stdout)
        console.setLevel(logging.DEBUG)
        console.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
        )
        default.addHandler(console)

    _default_logger = default


# Configure the default logger as soon as the module is imported.
__init_logger()


def setup_logger(name, save_dir, filename="log.txt"):
    """Configure logger ``name`` on the master process and make it the default.

    When ``get_rank()`` is 0, this does the following:

    * sets the logger to DEBUG;
    * closes and removes any ``FileHandler`` left from a previous call;
    * adds a DEBUG-level stdout handler if the logger has no console
      ``StreamHandler``;
    * if ``save_dir`` is truthy, attaches a new DEBUG-level ``FileHandler``
      writing to ``os.path.join(save_dir, filename)``.

    The configured logger is then stored as the module's ``_default_logger``.
    On every other rank the function does nothing.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        save_dir: Directory for the log file; a falsy value disables file
            logging. Missing directories (including parents) are created.
        filename: Name of the log file, joined onto ``save_dir``.
    """
    global _default_logger
    # don't log results for the non-master process
    if get_rank() != 0:
        return

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    # Remove file handlers from an earlier call *before* looking for a console
    # handler. FileHandler subclasses StreamHandler, so checking first would
    # mistake a stale file handler for stdout output. Close each removed
    # handler so its file descriptor is released rather than leaked.
    stale = [h for h in logger.handlers if isinstance(h, logging.FileHandler)]
    for handler in stale:
        logger.removeHandler(handler)
        handler.close()

    # No FileHandler remains at this point, so any StreamHandler found here is
    # a genuine console handler.
    if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if save_dir:
        log_path = os.path.join(save_dir, filename)
        log_dir = os.path.dirname(log_path)
        if log_dir:
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(log_dir, exist_ok=True)
        fh = logging.FileHandler(log_path)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    _default_logger = logger


def info(*args, **kwargs):
    """Log an INFO message through ``_default_logger`` on rank 0 only.

    All positional and keyword arguments are forwarded unchanged to
    ``_default_logger.info``. On any other rank the call does nothing.
    """
    if get_rank() != 0:
        return
    _default_logger.info(*args, **kwargs)


def error(*args, **kwargs):
    """Log an ERROR message through ``_default_logger`` on rank 0 only.

    All positional and keyword arguments are forwarded unchanged to
    ``_default_logger.error``. On any other rank the call does nothing.
    """
    if get_rank() != 0:
        return
    _default_logger.error(*args, **kwargs)