| | from collections import deque |
| | from datetime import datetime |
| | import io |
| | import logging |
| | import sys |
| | import threading |
| |
|
# Shared buffer of captured log entries; remains None until setup_logger() runs.
logs = None
# Wrappers that setup_logger() installs over sys.stdout and sys.stderr.
stdout_interceptor = None
stderr_interceptor = None
| |
|
| |
|
class LogInterceptor(io.TextIOWrapper):
    """Text stream wrapper that records every write before forwarding it.

    The interceptor is built on the binary buffer of an existing text stream
    and keeps that stream's encoding and line-buffering mode. Every write is
    timestamped and stored both in the module-level ``logs`` buffer and in a
    per-instance list that is handed to registered callbacks on ``flush()``.
    """

    def __init__(self, stream, *args, **kwargs):
        """Wrap ``stream.buffer`` using ``stream``'s encoding and buffering mode.

        Args:
            stream: Text stream exposing ``buffer``, ``encoding`` and
                ``line_buffering`` attributes.
            *args: Extra positional arguments passed to ``io.TextIOWrapper``.
            **kwargs: Extra keyword arguments passed to ``io.TextIOWrapper``.
        """
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record ``data`` as a log entry, then write it to the wrapped buffer.

        Args:
            data: Text to write.

        Returns:
            Whatever ``io.TextIOWrapper.write`` returns (the number of
            characters written).
        """
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)

            # A chunk starting with "\r" overwrites the previous entry when
            # that entry was not a finished line, so repeated carriage-return
            # redraws do not fill the buffer. The emptiness check guards the
            # logs[-1] lookup, which raised IndexError on an empty buffer.
            if (
                isinstance(data, str)
                and data.startswith("\r")
                and logs
                and not logs[-1]["m"].endswith("\n")
            ):
                logs.pop()
            logs.append(entry)
        return super().write(data)

    def flush(self):
        """Flush the stream and hand the entries written since the last flush
        to every registered callback.

        The pending list is swapped out once, under the same lock ``write()``
        uses, so no entry is lost to a concurrent write and every callback
        receives the same batch. Callbacks run without the lock held, so they
        may write to this stream themselves.
        """
        super().flush()
        with self._lock:
            batch, self._logs_since_flush = self._logs_since_flush, []
        for cb in self._flush_callbacks:
            cb(batch)

    def on_flush(self, callback):
        """Register ``callback`` to be called with the entry list on each flush."""
        self._flush_callbacks.append(callback)
| |
|
| |
|
def get_logs():
    """Return the module-level ``logs`` object (``None`` until it has been set)."""
    return logs
| |
|
| |
|
def on_flush(callback):
    """Register ``callback`` on each interceptor that is currently installed.

    An interceptor that is still ``None`` is skipped; stdout is registered
    before stderr.
    """
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is not None:
            interceptor.on_flush(callback)
| |
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    """Install the stdout/stderr interceptors and configure the root logger.

    Only the first call does anything; later calls return immediately.

    Args:
        log_level: Level set on the root logger.
        capacity: Maximum number of entries kept in the module-level ``logs``
            deque.
        use_stdout: When true, records below ERROR go to stdout and only
            ERROR and above go to the default stream handler; otherwise every
            record goes to the default stream handler.
    """
    global logs, stdout_interceptor, stderr_interceptor

    # Compare against None rather than relying on truthiness: a freshly
    # created (or zero-capacity) deque is empty and therefore falsy, which
    # would let a second call wrap the streams and add the handlers again.
    if logs is not None:
        return

    # Override output streams and log to buffer
    logs = deque(maxlen=capacity)

    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Setup default global logger
    logger = logging.getLogger()
    logger.setLevel(log_level)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    if use_stdout:
        # Only ERROR and above stay on the default stream handler.
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

        # Everything below ERROR goes to stdout.
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)
| |
|
| |
|
# Messages recorded by log_startup_warning() until print_startup_warnings() re-emits them.
STARTUP_WARNINGS = []
| |
|
| |
|
def log_startup_warning(msg):
    """Log ``msg`` as a warning immediately and append it to STARTUP_WARNINGS."""
    logging.warning(msg)
    STARTUP_WARNINGS.append(msg)
| |
|
| |
|
def print_startup_warnings():
    """Log every message in STARTUP_WARNINGS as a warning, then empty the list."""
    pending = STARTUP_WARNINGS
    for message in pending:
        logging.warning(message)
    pending.clear()
| |
|