Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from contextlib import contextmanager | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import local | |
| from typing import Any | |
| try: | |
| from termcolor import colored | |
| except ImportError: | |
| colored = None | |
| # Thread-local storage for context information | |
| _context_storage = local() | |
class ActiveLearningLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that enriches log calls with active learning driver state.

    When a driver reference is supplied, every logging call made through the
    adapter reads the driver's current iteration, run ID and phase and passes
    whichever of them are set along in the ``extra`` mapping of the call.
    """

    def __init__(self, logger: logging.Logger, driver_ref: Any = None):
        """Wrap ``logger`` and remember the driver to read context from.

        Parameters
        ----------
        logger : logging.Logger
            The underlying logger to adapt
        driver_ref : Any, optional
            Object whose ``active_learning_step_idx``, ``run_id`` and
            ``current_phase`` attributes (when present) supply the context
        """
        super().__init__(logger, {})
        self.driver_ref = driver_ref

    def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Attach iteration, run ID and phase to the ``extra`` of a log call.

        Parameters
        ----------
        msg : str
            The log message; returned unchanged
        kwargs : dict[str, Any]
            Keyword arguments of the logging call; its ``extra`` entry is
            updated in place when there is context to add

        Returns
        -------
        tuple[str, dict[str, Any]]
            The message and the (possibly updated) keyword arguments
        """
        driver = self.driver_ref
        if driver is None:
            # Nothing to read context from: pass the call through untouched.
            return msg, kwargs

        extra = kwargs.get("extra", {})
        # (driver attribute, key written into ``extra``), in emission order.
        for attr_name, extra_key in (
            ("active_learning_step_idx", "iteration"),
            ("run_id", "run_id"),
            ("current_phase", "phase"),
        ):
            value = getattr(driver, attr_name, None)
            if value is not None:
                extra[extra_key] = value

        # Only install ``extra`` when it actually carries something.
        if extra:
            kwargs["extra"] = extra
        return msg, kwargs
class JSONFormatter(logging.Formatter):
    """JSON formatter for structured logging to files.

    Each log record is rendered as a single JSON object holding a fixed set
    of core fields, the active learning context attributes found on the
    record, and every remaining attribute of the record.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Format the log record as JSON.

        Parameters
        ----------
        record : logging.LogRecord
            The log record to format

        Returns
        -------
        str
            JSON-formatted log message. Values that JSON cannot represent
            natively (exception tuples, arbitrary objects passed through
            ``args`` or ``extra``) are emitted as their ``str()`` form.
        """
        # Render the traceback once and cache it on the record, mirroring
        # logging.Formatter.format; the text then reaches the output through
        # the record's ``exc_text`` attribute picked up below.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)

        log_entry = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Contextual attributes go right after the core fields when present.
        for name in ("context", "caller_object", "iteration", "phase"):
            if hasattr(record, name):
                log_entry[name] = getattr(record, name)

        # Every other record attribute is copied without overriding the
        # fields set above.
        for key, value in record.__dict__.items():
            if key not in log_entry:
                log_entry[key] = value

        # The record dict routinely holds non-JSON values (``exc_info`` holds
        # a class, an exception and a traceback). Without a fallback the dump
        # raises TypeError and the entry never reaches the file.
        return json.dumps(log_entry, default=str)
def _get_context_stack():
    """Return the calling thread's context stack, creating it on first use."""
    try:
        return _context_storage.context_stack
    except AttributeError:
        # First access from this thread: install an empty stack.
        fresh_stack = []
        _context_storage.context_stack = fresh_stack
        return fresh_stack
class ContextFormatter(logging.Formatter):
    """Standard formatter that appends active learning context, with colors.

    The record is first rendered by ``logging.Formatter``. When ``colored``
    is available the rendered text is tinted by level, and any context found
    on the record is appended as a bracketed, comma-separated suffix.
    """

    def format(self, record):
        """Render ``record`` and append its context suffix, if any."""
        # Collect the context pieces in a fixed order.
        pieces = []
        caller = getattr(record, "caller_object", None)
        if caller:
            pieces.append(f"obj:{caller}")
        run_id = getattr(record, "run_id", None)
        if run_id:
            pieces.append(f"run:{run_id}")
        # Iteration 0 is a valid value, so only ``None`` counts as absent.
        iteration = getattr(record, "iteration", None)
        if iteration is not None:
            pieces.append(f"iter:{iteration}")
        phase = getattr(record, "phase", None)
        if phase:
            pieces.append(f"phase:{phase}")
        free_form = getattr(record, "context", None)
        if free_form:
            pieces.extend(f"{key}:{value}" for key, value in free_form.items())

        suffix = ""
        if pieces:
            suffix = f"[{', '.join(pieces)}]"

        # Use standard formatting for the message itself.
        rendered = super().format(record)

        # Tint the message by severity when termcolor is available.
        if colored is not None:
            severity = record.levelno
            if severity >= logging.ERROR:
                tint = "red"
            elif severity >= logging.WARNING:
                tint = "yellow"
            elif severity >= logging.INFO:
                tint = "white"
            else:  # DEBUG
                tint = "cyan"
            rendered = colored(rendered, tint)

        if not suffix:
            return rendered

        # The context suffix gets its own color.
        if colored is not None:
            suffix = colored(suffix, "blue")
        return rendered + f" {suffix}"
class ContextInjectingFilter(logging.Filter):
    """Filter that copies the innermost logging context onto each record.

    The filter never rejects a record; it only sets attributes on it.
    """

    def filter(self, record):
        """Set context attributes on ``record`` and always return ``True``."""
        stack = _get_context_stack()
        if not stack:
            # No context is active on this thread.
            return True

        # Only the most recently pushed context is applied.
        innermost = stack[-1]

        caller = innermost["caller_object"]
        if caller:
            record.caller_object = caller

        # Iteration 0 is meaningful, so compare against ``None`` explicitly.
        iteration = innermost["iteration"]
        if iteration is not None:
            record.iteration = iteration

        phase = innermost.get("phase")
        if phase:
            record.phase = phase

        free_form = innermost["context"]
        if free_form:
            record.context = free_form

        return True
def setup_active_learning_logger(
    name: str,
    run_id: str,
    log_dir: str | Path = Path("active_learning_logs"),
    level: int = logging.INFO,
) -> logging.Logger:
    """Set up a logger with active learning-specific formatting and handlers.

    The logger gets a console handler using ``ContextFormatter`` and a file
    handler writing JSON lines to ``<log_dir>/<run_id>.log``. The file is
    opened in ``"w"`` mode, so an existing log for the same run ID is
    truncated. Handlers already attached to the logger are removed and
    closed, and propagation to parent loggers is turned off.

    Parameters
    ----------
    name : str
        Logger name
    run_id : str
        Unique identifier for this run, used in log filename
    log_dir : str | Path, optional
        Directory to store log files, by default "active_learning_logs";
        created (including parents) if it does not exist
    level : int, optional
        Logging level, by default logging.INFO

    Returns
    -------
    logging.Logger
        Configured standard Python logger

    Example
    -------
    >>> logger = setup_active_learning_logger("experiment", "run_001")
    >>> logger.info("Starting experiment")
    >>> with log_context(caller_object="Trainer", iteration=5):
    ...     logger.info("Training step")
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove existing handlers to avoid duplicate output, and close them too:
    # merely clearing the list would leave a previous FileHandler's file
    # descriptor open every time the logger is set up again.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Disable propagation to prevent duplicate messages from parent loggers
    logger.propagate = False

    # Path() accepts both str and Path, so no type switch is needed.
    log_dir_path = Path(log_dir)
    log_dir_path.mkdir(parents=True, exist_ok=True)

    # Console handler: human-readable line with the context suffix.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        ContextFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    console_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(console_handler)

    # File handler: one JSON document per record.
    log_file = log_dir_path / f"{run_id}.log"
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(JSONFormatter())
    file_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(file_handler)

    return logger
@contextmanager
def log_context(
    caller_object: str | None = None,
    iteration: int | None = None,
    phase: str | None = None,
    **kwargs: Any,
):
    """Context manager for adding contextual information to log messages.

    On entry the supplied values are pushed onto the current thread's context
    stack; on exit, including when the body raises, they are popped again.

    Parameters
    ----------
    caller_object : str, optional
        Name or identifier of the object making the log call
    iteration : int, optional
        Current iteration counter
    phase : str, optional
        Current phase of the active learning process
    **kwargs : Any
        Additional contextual key-value pairs, stored under ``"context"``

    Example
    -------
    >>> from logging import getLogger
    >>> from physicsnemo.active_learning.logger import log_context
    >>> logger = getLogger("my_logger")
    >>> with log_context(caller_object="Trainer", iteration=5, phase="training", epoch=2):
    ...     logger.info("Processing batch")
    """
    # The decorator above is what makes this generator usable in a ``with``
    # statement; a bare generator object has no __enter__/__exit__.
    context_info = {
        "caller_object": caller_object,
        "iteration": iteration,
        "phase": phase,
        "context": kwargs,
    }
    context_stack = _get_context_stack()
    context_stack.append(context_info)
    try:
        yield
    finally:
        # Always unwind so an exception in the body cannot leave stale
        # context on the stack.
        context_stack.pop()