# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import functools
import logging
import os
import sys
from typing import Optional
from termcolor import colored
from dinov3.distributed import TorchDistributedEnvironment
from dinov3.logging.helpers import MetricLogger, SmoothedValue
# Keyword arguments forwarded to `termcolor.colored` for each logging level.
# Levels missing from this table are rendered without any color.
_LEVEL_COLORED_KWARGS = {
    logging.DEBUG: dict(color="green", attrs=["bold"]),
    logging.INFO: dict(color="green"),
    logging.WARNING: dict(color="yellow"),
    logging.ERROR: dict(color="red"),
    logging.CRITICAL: dict(color="red", attrs=["bold"]),
}
class _LevelColoredFormatter(logging.Formatter):
    """
    Formatter that colors the prefix of each log line according to its level.

    Only the text that precedes the message (level, timestamp, location, ...)
    is colored; the message itself is emitted unchanged. Records whose level
    has no entry in `_LEVEL_COLORED_KWARGS` are formatted without colors.
    """

    def formatMessage(self, record):
        """
        Format `record` and color everything that precedes its message.

        Args:
            record: The log record being formatted.

        Returns:
            The formatted line. The prefix is colored when the record level has
            a color configuration and the line ends with the rendered message;
            otherwise the plain formatted line is returned.
        """
        log = super().formatMessage(record)
        colored_kwargs = _LEVEL_COLORED_KWARGS.get(record.levelno)
        if colored_kwargs is None:
            return log

        # Work on the fully rendered message rather than on `record.msg`:
        # `record.msg` may be a non-string object (e.g. an exception) or may
        # still contain `%` placeholders that are expanded in the final line.
        msg = record.getMessage()

        # The prefix can only be isolated when the message terminates the
        # line, which depends on the format string. Fallback to no colors.
        if not log.endswith(msg):
            return log

        # Slice with an explicit end index so an empty message keeps the
        # whole line as prefix (`log[:-0]` would yield an empty string).
        prefix = log[: len(log) - len(msg)]
        return colored(prefix, **colored_kwargs) + msg
# So that calling _configure_logger multiple times won't add many handlers
@functools.lru_cache()
def _configure_logger(
    name: Optional[str] = None,
    *,
    level: int = logging.DEBUG,
    output: Optional[str] = None,
    color: bool = True,
    log_to_stdout_only_in_main_process: bool = True,
):
    """
    Configure a logger.

    Adapted from Detectron2.

    The result is memoized on the argument values, so repeating a call with
    the same arguments returns the already configured logger instead of
    attaching a second set of handlers.

    Args:
        name: The name of the logger to configure.
        level: The logging level to use.
        output: A file name or a directory to save log. If None or empty, will
            not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
            On processes other than the main one, `.rank<rank>` is appended to
            the file name.
        color: Whether stdout output should be colored (ignored if stdout is not a terminal).
        log_to_stdout_only_in_main_process: The main process (rank 0) always logs to stdout,
            regardless of this flag. If False, other ranks will also log to their stdout.

    Returns:
        The configured logger.
    """
    # Disable colored output if the stdout is not a terminal. Ask the stream
    # itself instead of its file descriptor: a replaced stdout (for instance
    # an `io.StringIO`) has no usable `fileno()` and would raise here.
    stdout_isatty = getattr(sys.stdout, "isatty", None)
    color = bool(color and stdout_isatty is not None and stdout_isatty())

    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Records are handled only by the handlers attached below.
    logger.propagate = False

    # Loosely match Google glog format:
    #   [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
    # but use a shorter timestamp, the process id instead of the thread id,
    # and include the logger name:
    #   [DIWEC]yyyymmdd hh:mm:ss pid logger file:line] msg
    fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
    fmt_message = "%(message)s"
    fmt = fmt_prefix + fmt_message
    datefmt = "%Y%m%d %H:%M:%S"
    plain_formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)

    torch_env = TorchDistributedEnvironment()

    # rank 0 always logs to stdout, for other ranks it depends on log_to_stdout_only_in_main_process
    should_log_to_stdout = torch_env.is_main_process or not log_to_stdout_only_in_main_process
    if should_log_to_stdout:
        stdout_handler = logging.StreamHandler(stream=sys.stdout)
        stdout_handler.setLevel(logging.DEBUG)
        formatter: logging.Formatter
        if color:
            formatter = _LevelColoredFormatter(
                fmt=fmt,
                datefmt=datefmt,
            )
        else:
            formatter = plain_formatter
        stdout_handler.setFormatter(formatter)
        logger.addHandler(stdout_handler)

    # file logging for all workers
    if output:
        if os.path.splitext(output)[-1] in (".txt", ".log"):
            filename = output
        else:
            filename = os.path.join(output, "logs", "log.txt")

        # Every non-main process writes to its own file.
        if not torch_env.is_main_process:
            filename = filename + f".rank{torch_env.rank}"

        # A bare file name has an empty dirname, which `os.makedirs` rejects.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # `FileHandler` owns the file it opens, so closing the handler also
        # closes the file; a `StreamHandler` wrapping `open(...)` would leave
        # the file descriptor open after `close()`.
        file_handler = logging.FileHandler(filename, mode="a")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(plain_formatter)
        logger.addHandler(file_handler)

    logger.debug(f"PyTorch distributed environment: {torch_env}")
    return logger
def setup_logging(
    output: Optional[str] = None,
    *,
    name: Optional[str] = None,
    level: int = logging.DEBUG,
    color: bool = True,
    capture_warnings: bool = True,
    log_to_stdout_only_in_main_process: bool = True,
) -> None:
    """
    Setup logging.

    Args:
        output: A file name or a directory to save log files. If None or
            empty, log files will not be saved. If output ends with ".txt" or
            ".log", it is assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
        color: Whether stdout output should be colored (ignored if stdout is not a terminal).
        capture_warnings: Whether warnings should be captured as logs.
        log_to_stdout_only_in_main_process: The main process (rank 0) always logs to stdout,
            regardless of this flag. If False, other ranks will also log to their stdout.
    """
    logging.captureWarnings(capture_warnings)

    # Ensure the path is canonical to properly use the cache of `_configure_logger`.
    # An empty path is normalized to None: `os.path.realpath("")` resolves to
    # the current working directory, which would silently enable file logging
    # there even though `_configure_logger` treats an empty output as "no file".
    output = os.path.realpath(output) if output else None

    _configure_logger(
        name,
        level=level,
        output=output,
        color=color,
        log_to_stdout_only_in_main_process=log_to_stdout_only_in_main_process,
    )
def cleanup_logging(*, name: Optional[str] = None) -> None:
    """
    Flush, close and detach every handler of a logger.

    Also clears the cache of `_configure_logger`, so that a later call to it
    configures the logger again instead of returning the memoized result.

    Args:
        name: The name of the logger to clean up, by default the root logger.
    """
    logger = logging.getLogger(name)

    # Iterate over a snapshot: `removeHandler` mutates `logger.handlers`, and
    # removing items from the list being iterated skips every other handler,
    # leaving some of them attached and never flushed or closed.
    for handler in list(logger.handlers):
        handler.flush()
        handler.close()
        logger.removeHandler(handler)

    # clears the cache of `_configure_logger` to allow re-initialization
    _configure_logger.cache_clear()
|