ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import logging
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from threading import local
from typing import Any
try:
from termcolor import colored
except ImportError:
colored = None
# Thread-local storage for logging context. `_get_context_stack` lazily
# attaches a `context_stack` list to this object for the calling thread;
# `log_context` pushes onto and pops from that list.
_context_storage = local()
class ActiveLearningLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that automatically includes active learning iteration context.

    On every logging call the adapter reads ``active_learning_step_idx``,
    ``run_id`` and ``current_phase`` from the driver reference (each only when
    the attribute exists and is not ``None``) and forwards them to the log
    record through ``extra`` under the keys ``iteration``, ``run_id`` and
    ``phase``.
    """

    # (driver attribute name, key written into the record's ``extra``)
    _DRIVER_FIELDS = (
        ("active_learning_step_idx", "iteration"),
        ("run_id", "run_id"),
        ("current_phase", "phase"),
    )

    def __init__(self, logger: logging.Logger, driver_ref: Any = None):
        """Initialize the adapter with a logger and optional driver reference.

        Parameters
        ----------
        logger : logging.Logger
            The underlying logger to adapt
        driver_ref : Any, optional
            Reference to the driver object to get iteration context from
        """
        super().__init__(logger, {})
        self.driver_ref = driver_ref

    def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Process the log message to add iteration, run ID, and phase context.

        Parameters
        ----------
        msg : str
            The log message
        kwargs : dict[str, Any]
            Keyword arguments of the logging call; ``kwargs["extra"]`` is
            replaced by an augmented copy when there is anything to add

        Returns
        -------
        tuple[str, dict[str, Any]]
            The unchanged message and the (possibly updated) kwargs
        """
        if self.driver_ref is None:
            return msg, kwargs

        # Work on a copy: the caller may reuse its ``extra`` dict across calls
        # and must not see driver fields leak into it. ``or {}`` also covers
        # an explicit ``extra=None``.
        extra = dict(kwargs.get("extra") or {})
        for attr_name, extra_key in self._DRIVER_FIELDS:
            value = getattr(self.driver_ref, attr_name, None)
            if value is not None:
                extra[extra_key] = value

        # Leave ``kwargs`` untouched when there is nothing to report.
        if extra:
            kwargs["extra"] = extra
        return msg, kwargs
class JSONFormatter(logging.Formatter):
    """JSON formatter for structured logging to files.

    Each record becomes one JSON object containing a fixed set of core
    fields, the active-learning context attributes when present, and every
    remaining attribute stored on the record.
    """

    # Context attributes copied right after the core fields when present.
    _CONTEXT_ATTRS = ("context", "caller_object", "iteration", "phase")

    def format(self, record: logging.LogRecord) -> str:
        """Format the log record as JSON.

        Parameters
        ----------
        record : logging.LogRecord
            The log record to format

        Returns
        -------
        str
            JSON-formatted log message
        """
        log_entry = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add contextual information if available
        for attr in self._CONTEXT_ATTRS:
            if hasattr(record, attr):
                log_entry[attr] = getattr(record, attr)

        # An exception tuple (type, value, traceback) is not JSON
        # serializable; store the rendered traceback text instead.
        exc_info = record.exc_info
        if isinstance(exc_info, tuple) and exc_info[0] is not None:
            log_entry["exc_info"] = self.formatException(exc_info)

        # Add any remaining record attributes not already emitted
        for key, value in record.__dict__.items():
            if key not in log_entry:
                log_entry[key] = value

        # Log arguments and user-supplied extras can be arbitrary objects;
        # fall back to ``str`` so one odd value cannot drop the whole record.
        return json.dumps(log_entry, default=str)
def _get_context_stack():
    """Return the calling thread's context stack, creating it on first use."""
    try:
        return _context_storage.context_stack
    except AttributeError:
        # First access from this thread: start with an empty stack.
        _context_storage.context_stack = []
        return _context_storage.context_stack
class ContextFormatter(logging.Formatter):
    """Console formatter that appends active learning context and adds colors.

    The base message comes from the standard ``logging.Formatter``; a
    bracketed ``[key:value, ...]`` suffix is appended when the record carries
    context attributes. Coloring is applied only when ``termcolor`` could be
    imported.
    """

    def format(self, record):
        """Return the formatted (and optionally colored) message for ``record``."""
        # Collect the context tags in a fixed order.
        tags = []

        caller = getattr(record, "caller_object", None)
        if caller:
            tags.append(f"obj:{caller}")

        run_id = getattr(record, "run_id", None)
        if run_id:
            tags.append(f"run:{run_id}")

        # Iteration 0 is a valid value, so only ``None`` is skipped here.
        iteration = getattr(record, "iteration", None)
        if iteration is not None:
            tags.append(f"iter:{iteration}")

        phase = getattr(record, "phase", None)
        if phase:
            tags.append(f"phase:{phase}")

        free_form = getattr(record, "context", None)
        if free_form:
            tags.extend(f"{key}:{value}" for key, value in free_form.items())

        suffix = f"[{', '.join(tags)}]" if tags else ""

        # Standard formatting of the record itself.
        text = super().format(record)

        use_color = colored is not None
        if use_color:
            # Pick the message color from the record's severity.
            severity = record.levelno
            if severity >= logging.ERROR:
                tint = "red"
            elif severity >= logging.WARNING:
                tint = "yellow"
            elif severity >= logging.INFO:
                tint = "white"
            else:  # DEBUG and anything below INFO
                tint = "cyan"
            text = colored(text, tint)

        if suffix:
            if use_color:
                suffix = colored(suffix, "blue")
            text += f" {suffix}"
        return text
class ContextInjectingFilter(logging.Filter):
    """Filter that copies the innermost active log context onto each record."""

    def filter(self, record):
        """Annotate ``record`` from the top of the context stack; never drop it."""
        stack = _get_context_stack()
        if not stack:
            return True

        # Only the most recently entered context is applied.
        top = stack[-1]

        caller = top["caller_object"]
        if caller:
            record.caller_object = caller

        # Iteration 0 is meaningful, so compare against None explicitly.
        step = top["iteration"]
        if step is not None:
            record.iteration = step

        phase = top.get("phase")
        if phase:
            record.phase = phase

        extras = top["context"]
        if extras:
            record.context = extras

        return True
def setup_active_learning_logger(
    name: str,
    run_id: str,
    log_dir: str | Path = Path("active_learning_logs"),
    level: int = logging.INFO,
) -> logging.Logger:
    """Set up a logger with active learning-specific formatting and handlers.

    The logger gets a console handler using ``ContextFormatter`` and a file
    handler writing JSON lines to ``<log_dir>/<run_id>.log`` (the file is
    truncated on open). Handlers already attached to the logger are removed
    and closed first, and propagation to parent loggers is disabled.

    Parameters
    ----------
    name : str
        Logger name
    run_id : str
        Unique identifier for this run, used in log filename
    log_dir : str | Path, optional
        Directory to store log files, by default "active_learning_logs"
    level : int, optional
        Logging level, by default logging.INFO

    Returns
    -------
    logging.Logger
        Configured standard Python logger

    Example
    -------
    >>> logger = setup_active_learning_logger("experiment", "run_001")
    >>> logger.info("Starting experiment")
    >>> with log_context(caller_object="Trainer", iteration=5):
    ...     logger.info("Training step")
    """
    # Get standard logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove existing handlers to avoid duplicates, and close them so a
    # previous run's log file is not left open when setup is called again.
    # NOTE(review): this also closes a handler that is shared with another
    # logger — confirm no caller relies on that.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Disable propagation to prevent duplicate messages from parent loggers
    logger.propagate = False

    # Create log directory if it doesn't exist
    log_dir_path = Path(log_dir)
    log_dir_path.mkdir(parents=True, exist_ok=True)

    # Set up console handler with standard formatting
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        ContextFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    console_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(console_handler)

    # Set up file handler with JSON formatting
    log_file = log_dir_path / f"{run_id}.log"
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(JSONFormatter())
    file_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(file_handler)

    return logger
@contextmanager
def log_context(
    caller_object: str | None = None,
    iteration: int | None = None,
    phase: str | None = None,
    **kwargs: Any,
):
    """Attach contextual information to log records emitted inside the block.

    A context entry is pushed onto the calling thread's context stack on
    entry and popped on exit, including when the body raises.

    Parameters
    ----------
    caller_object : str, optional
        Name or identifier of the object making the log call
    iteration : int, optional
        Current iteration counter
    phase : str, optional
        Current phase of the active learning process
    **kwargs : Any
        Additional contextual key-value pairs

    Example
    -------
    >>> from logging import getLogger
    >>> from physicsnemo.active_learning.logger import log_context
    >>> logger = getLogger("my_logger")
    >>> with log_context(caller_object="Trainer", iteration=5, phase="training", epoch=2):
    ...     logger.info("Processing batch")
    """
    stack = _get_context_stack()
    stack.append(
        dict(
            caller_object=caller_object,
            iteration=iteration,
            phase=phase,
            context=kwargs,
        )
    )
    try:
        yield
    finally:
        # Always unwind, even if the managed block raised.
        stack.pop()