Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import logging | |
| import math | |
| import sys | |
| import time | |
| from datetime import timedelta | |
| import fsspec | |
| from bytelatent.distributed import get_global_rank, get_is_slurm_job | |
class LogFormatter(logging.Formatter):
    """
    Custom logger for distributed jobs, displaying rank
    and preserving indent from the custom prefix format.

    Every record is rendered as ``<prefix><message>``. The prefix holds the
    (optional) rank, the level name, the wall-clock time of the record and
    the time elapsed since this formatter was created. Continuation lines of
    the message, the traceback and the stack trace are padded with spaces so
    that they line up under the first character of the message.
    """

    def __init__(self):
        # Initialise the base Formatter so that inherited helpers relying on
        # its state (e.g. usesTime()) work on instances of this class.
        super().__init__()
        # Reference point for the "elapsed" part of the prefix.
        self.start_time = time.time()
        self.rank = get_global_rank()
        self.show_rank = not get_is_slurm_job()  # srun has --label

    def formatTime(self, record):
        """Return ``"<yy-mm-dd HH:MM:SS.micros> - <elapsed>"`` for ``record``."""
        subsecond, seconds = math.modf(record.created)
        curr_date = (
            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
            + f".{int(subsecond * 1_000_000):06d}"
        )
        # Elapsed time since the formatter was created, rounded to seconds.
        delta = timedelta(seconds=round(record.created - self.start_time))
        return f"{curr_date} - {delta}"

    def formatPrefix(self, record):
        """Return the line prefix: optional rank, padded level name, time."""
        fmt_time = self.formatTime(record)
        if self.show_rank:
            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
        else:
            return f"{record.levelname:<7} {fmt_time} - "

    @staticmethod
    def _append_indented(content: str, extra: str, indent: str) -> str:
        """Append ``extra`` to ``content`` on new lines, each led by ``indent``."""
        if content[-1:] != "\n":
            content = content + "\n" + indent
        content = content + indent.join(
            [line + "\n" for line in extra.splitlines()]
        )
        # Drop the trailing newline added after the last line.
        if content[-1:] == "\n":
            content = content[:-1]
        return content

    def formatMessage(self, record, indent: str):
        """
        Return the message of ``record`` followed by its traceback and stack
        trace (when present), with every line after the first prefixed by
        ``indent``.
        """
        content = record.getMessage()
        content = content.replace("\n", "\n" + indent)

        # Exception handling as in the default formatter, albeit with indenting
        # according to our custom prefix
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            content = self._append_indented(content, record.exc_text, indent)
        if record.stack_info:
            stack_text = self.formatStack(record.stack_info)
            content = self._append_indented(content, stack_text, indent)
        return content

    def format(self, record):
        """Return the fully formatted log line(s) for ``record``."""
        prefix = self.formatPrefix(record)
        # Continuation lines are padded to the width of the prefix.
        indent = " " * len(prefix)
        content = self.formatMessage(record, indent)
        return prefix + content
| def set_root_log_level(log_level: str): | |
| logger = logging.getLogger() | |
| level: int | str = log_level.upper() | |
| try: | |
| level = int(log_level) | |
| except ValueError: | |
| pass | |
| try: | |
| logger.setLevel(level) # type: ignore | |
| except Exception: | |
| logger.warning( | |
| f"Failed to set logging level to {log_level}, using default 'NOTSET'" | |
| ) | |
| logger.setLevel(logging.NOTSET) | |
def init_logger(
    log_file: str | None = None,
    *,
    name: str | None = None,
    level: str = "INFO",
    fs: fsspec.AbstractFileSystem | None = None,
):
    """
    Configure console logging and, optionally, logging to a file.

    The root logger level is set from ``level``. The handlers of the logger
    selected by ``name`` are replaced by two stream handlers that use
    ``LogFormatter``: one on stdout without a level filter and one on stderr
    for WARNING and above. When ``log_file`` is given and the global rank is
    0, a handler appending to that file is added to the root logger.

    Args:
        log_file: A file name to save file logs to.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
        fs: Optional filesystem object; when given, ``log_file`` is opened in
            append mode through ``fs.open`` instead of the local filesystem.
    """
    set_root_log_level(level)
    target = logging.getLogger(name)

    # (stream, minimum level): stdout gets everything, stderr gets
    # warnings / errors and above.
    console_handlers = []
    for stream, threshold in (
        (sys.stdout, logging.NOTSET),
        (sys.stderr, logging.WARNING),
    ):
        handler = logging.StreamHandler(stream)
        handler.setLevel(threshold)
        handler.setFormatter(LogFormatter())
        console_handlers.append(handler)

    # Replace whatever handlers were installed before.
    target.handlers.clear()
    target.handlers.extend(console_handlers)

    # Only rank 0 writes the log file.
    if log_file is None or get_global_rank() != 0:
        return

    if fs is not None:
        file_handler = logging.StreamHandler(fs.open(log_file, mode="a"))
    else:
        file_handler = logging.FileHandler(log_file, "a")
    file_handler.setLevel(logging.NOTSET)
    file_handler.setFormatter(LogFormatter())

    # The file handler always goes on the root logger, whatever `name` is.
    logging.getLogger().addHandler(file_handler)