File size: 6,402 Bytes
bc90483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import functools
import logging
import os
import sys
from typing import Optional

from termcolor import colored

from dinov3.distributed import TorchDistributedEnvironment

from dinov3.logging.helpers import MetricLogger, SmoothedValue

_LEVEL_COLORED_KWARGS = {
    logging.DEBUG: {"color": "green", "attrs": ["bold"]},
    logging.INFO: {"color": "green"},
    logging.WARNING: {"color": "yellow"},
    logging.ERROR: {"color": "red"},
    logging.CRITICAL: {"color": "red", "attrs": ["bold"]},
}


class _LevelColoredFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        log = super().formatMessage(record)

        colored_kwargs = _LEVEL_COLORED_KWARGS.get(record.levelno)
        if colored_kwargs is None:
            return log

        msg = record.msg % record.args if record.msg == "%s" else record.msg
        index = log.rfind(msg, len(log) - len(msg))
        # Can happen in some cases, like if the msg contains `%s` which
        # have been replaced in `formatMessage`. Fallback to no colors
        if index == -1:
            return log
        prefix = log[:index]
        prefix = colored(prefix, **colored_kwargs)
        return prefix + msg


# So that calling _configure_logger multiple times won't add many handlers
@functools.lru_cache()
def _configure_logger(
    name: Optional[str] = None,
    *,
    level: int = logging.DEBUG,
    output: Optional[str] = None,
    color: bool = True,
    log_to_stdout_only_in_main_process: bool = True,
) -> logging.Logger:
    """
    Configure a logger.

    Adapted from Detectron2.

    Args:
        name: The name of the logger to configure (None selects the root logger).
        level: The logging level to use.
        output: A file name or a directory to save log. If None (or empty), will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
            Processes other than the main one append `.rank{rank}` to the file name.
        color: Whether stdout output should be colored (ignored if stdout is not a terminal).
        log_to_stdout_only_in_main_process: The main process (rank 0) always logs to stdout,
            regardless of this flag. If False, other ranks will also log to their stdout.

    Returns:
        The configured logger.
    """
    # Disable colored output if stdout is not a terminal. Ask the stream itself
    # rather than calling `fileno()`: `sys.stdout` may be None (e.g. pythonw),
    # closed, or an in-memory stream without a file descriptor (e.g. pytest
    # capture), in which case `fileno()` raises.
    if color:
        try:
            color = sys.stdout.isatty()
        except (AttributeError, ValueError):
            color = False

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Loosely match Google glog format:
    #   [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
    # but use a shorter timestamp and include the logger name:
    #   [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg
    fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
    fmt_message = "%(message)s"
    fmt = fmt_prefix + fmt_message
    datefmt = "%Y%m%d %H:%M:%S"
    plain_formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)

    torch_env = TorchDistributedEnvironment()

    # rank 0 always logs to stdout, for other ranks it depends on log_to_stdout_only_in_main_process
    should_log_to_stdout = torch_env.is_main_process or not log_to_stdout_only_in_main_process
    if should_log_to_stdout:
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)

        formatter: logging.Formatter
        if color:
            formatter = _LevelColoredFormatter(
                fmt=fmt,
                datefmt=datefmt,
            )
        else:
            formatter = plain_formatter

        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # file logging for all workers
    if output:
        if os.path.splitext(output)[-1] in (".txt", ".log"):
            filename = output
        else:
            filename = os.path.join(output, "logs", "log.txt")

        if not torch_env.is_main_process:
            filename = filename + f".rank{torch_env.rank}"

        # `dirname` is empty for a bare relative file name such as "log.txt";
        # `os.makedirs("")` would raise, and the current directory already exists.
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        # A FileHandler owns its stream and closes it in `close()`; a
        # StreamHandler wrapping a bare `open(...)` never closes the file.
        handler = logging.FileHandler(filename, mode="a")
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(plain_formatter)
        logger.addHandler(handler)

    logger.debug(f"PyTorch distributed environment: {torch_env}")
    return logger


def setup_logging(
    output: Optional[str] = None,
    *,
    name: Optional[str] = None,
    level: int = logging.DEBUG,
    color: bool = True,
    capture_warnings: bool = True,
    log_to_stdout_only_in_main_process: bool = True,
) -> None:
    """
    Setup logging.

    Args:
        output: A file name or a directory to save log files. If None (or empty), log
            files will not be saved. If output ends with ".txt" or ".log", it
            is assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
        color: Whether stdout output should be colored (ignored if stdout is not a terminal).
        capture_warnings: Whether warnings should be captured as logs.
        log_to_stdout_only_in_main_process: The main process (rank 0) always logs to stdout,
            regardless of this flag. If False, other ranks will also log to their stdout.
    """
    logging.captureWarnings(capture_warnings)
    # Ensure the path is canonical to properly use the cache of `_configure_logger`.
    # An empty string means "no log file", just like None: without this check
    # `os.path.realpath("")` would resolve to the current working directory.
    output = os.path.realpath(output) if output else None
    _configure_logger(
        name,
        level=level,
        output=output,
        color=color,
        log_to_stdout_only_in_main_process=log_to_stdout_only_in_main_process,
    )


def cleanup_logging(*, name: Optional[str] = None) -> None:
    """
    Flush, close and detach every handler of a logger.

    Also clears the whole cache of `_configure_logger`, so that a subsequent
    `setup_logging` call configures the logger again instead of being a no-op.

    Args:
        name: The name of the logger to clean up, by default the root logger.
    """
    logger = logging.getLogger(name)
    # Iterate over a snapshot: `removeHandler` mutates `logger.handlers`, and
    # removing from the list being iterated would skip every other handler.
    for handler in list(logger.handlers):
        handler.flush()
        handler.close()
        logger.removeHandler(handler)

    # clears the cache of `_configure_logger` to allow re-initialization
    _configure_logger.cache_clear()