| |
|
|
| import logging |
| import math |
| import sys |
| import time |
| from datetime import timedelta |
|
|
| from core.distributed import get_global_rank, get_is_slurm_job |
|
|
|
|
class LogFormatter(logging.Formatter):
    """
    Custom logger for distributed jobs, displaying rank
    and preserving indent from the custom prefix format.

    Each record is rendered as ``[<rank>: ]<LEVEL> <date time.usec> - <elapsed> - <message>``.
    Continuation lines of multi-line messages, tracebacks and stack traces are
    indented so they line up with the text that follows the prefix.
    """

    def __init__(self):
        # Initialise the base Formatter so inherited attributes (_fmt, datefmt,
        # _style) exist for any base-class code that reads them.
        super().__init__()
        self.start_time = time.time()
        self.rank = get_global_rank()
        # Rank is omitted under SLURM, presumably because the launcher already
        # labels each output line with the task id (e.g. `srun --label`) - confirm.
        self.show_rank = not get_is_slurm_job()

    def formatTime(self, record, datefmt=None):
        """
        Return ``"<date time>.<microseconds> - <elapsed>"`` for *record*.

        Args:
            record: The log record being formatted.
            datefmt: Optional ``time.strftime`` format for the date/time part.
                Defaults to ``"%y-%m-%d %H:%M:%S"``. The parameter is accepted
                for signature compatibility with ``logging.Formatter.formatTime``.

        Returns:
            The local timestamp with six-digit microseconds, followed by the
            time elapsed since this formatter was created, rounded to whole
            seconds and rendered as a ``timedelta``.
        """
        if datefmt is None:
            datefmt = "%y-%m-%d %H:%M:%S"
        subsecond, seconds = math.modf(record.created)
        curr_date = (
            time.strftime(datefmt, time.localtime(seconds))
            + f".{int(subsecond * 1_000_000):06d}"
        )
        delta = timedelta(seconds=round(record.created - self.start_time))
        return f"{curr_date} - {delta}"

    def formatPrefix(self, record):
        """
        Build the per-record prefix ``"[<rank>: ]<LEVEL> <time> - "``.

        The level name is left-aligned in a 7-character column. The rank is
        included only when ``self.show_rank`` is true.
        """
        fmt_time = self.formatTime(record)
        if self.show_rank:
            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
        else:
            return f"{record.levelname:<7} {fmt_time} - "

    @staticmethod
    def _append_indented(content, text, indent):
        """Append every line of *text* to *content*, each preceded by *indent*."""
        # Start the appended block on a fresh, indented line.
        if content[-1:] != "\n":
            content = content + "\n" + indent
        content = content + indent.join([line + "\n" for line in text.splitlines()])
        # Drop the newline that terminates the last appended line.
        if content[-1:] == "\n":
            content = content[:-1]
        return content

    def formatMessage(self, record, indent: str):
        """
        Render the message body plus any traceback and stack info.

        Mirrors the exception/stack handling of ``logging.Formatter.format``,
        but indents every continuation line with *indent*. Unlike the base
        ``formatMessage``, this takes an extra *indent* argument and is only
        called from ``format``.

        Args:
            record: The log record being formatted.
            indent: Whitespace placed at the start of each continuation line.

        Returns:
            The indented message body (without the prefix).
        """
        content = record.getMessage()
        content = content.replace("\n", "\n" + indent)
        if record.exc_info:
            # Cache the traceback text on the record, as the base Formatter
            # does, so repeated formatting of the record reuses it.
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            content = self._append_indented(content, record.exc_text, indent)
        if record.stack_info:
            content = self._append_indented(
                content, self.formatStack(record.stack_info), indent
            )
        return content

    def format(self, record):
        """Return the complete log text: the prefix followed by the indented body."""
        prefix = self.formatPrefix(record)
        # Continuation lines are padded to the prefix width so they align.
        indent = " " * len(prefix)
        content = self.formatMessage(record, indent)
        return prefix + content
|
|
|
|
def set_root_log_level(log_level: str):
    """
    Set the level of the root logger from a string.

    Accepts either a level name (e.g. ``"info"``, ``"DEBUG"``; case-insensitive,
    surrounding whitespace ignored) or an integer given as a string (e.g. ``"10"``).
    If the logging module rejects the value, a warning is logged and the root
    level falls back to ``NOTSET``.

    Args:
        log_level: Level name or numeric level, as a string.
    """
    logger = logging.getLogger()
    level: int | str = log_level.strip().upper()
    try:
        # Numeric strings are used as integer levels directly.
        level = int(log_level)
    except ValueError:
        pass
    try:
        logger.setLevel(level)
    except (ValueError, TypeError):
        # Logger.setLevel raises ValueError for unknown level names and
        # TypeError for unsupported level types.
        logger.warning(
            "Failed to set logging level to %s, using default 'NOTSET'", log_level
        )
        logger.setLevel(logging.NOTSET)
|
|
|
|
def init_logger(
    log_file: str | None = None,
    *,
    name: str | None = None,
    level: str = "NOTSET",
):
    """
    Setup logging.

    Configures the logger ``name`` (the root logger by default) to write every
    record to stdout, and records of level WARNING and above additionally to
    stderr. Any handlers previously attached to that logger are removed and
    closed. When ``log_file`` is given, global rank 0 also appends every record
    to that file through a handler attached to the *root* logger. A file
    handler left on root for the same path by an earlier call is replaced.

    Args:
        log_file: A file name to save file logs to (only used on global rank 0).
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use. It is applied to the root logger.
    """
    set_root_log_level(level)
    logger = logging.getLogger(name)

    # stdout: every record that passes the logger's level.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.NOTSET)
    stdout_handler.setFormatter(LogFormatter())

    # stderr: WARNING and above (these records also appear on stdout).
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)
    stderr_handler.setFormatter(LogFormatter())

    # Replace existing handlers through the lock-protected public API. Close
    # the old ones so a previous FileHandler does not leak its file descriptor
    # (same approach as logging.basicConfig(force=True)).
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.addHandler(stdout_handler)
    logger.addHandler(stderr_handler)

    if log_file is not None and get_global_rank() == 0:
        file_handler = logging.FileHandler(log_file, "a")
        file_handler.setLevel(logging.NOTSET)
        file_handler.setFormatter(LogFormatter())

        root_logger = logging.getLogger()
        # Repeated calls with a non-root `name` would otherwise stack several
        # handlers for the same file on root and duplicate every line.
        for old_handler in list(root_logger.handlers):
            if (
                isinstance(old_handler, logging.FileHandler)
                and old_handler.baseFilename == file_handler.baseFilename
            ):
                root_logger.removeHandler(old_handler)
                old_handler.close()
        root_logger.addHandler(file_handler)
|
|