File size: 10,990 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import logging
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from threading import local
from typing import Any

try:
    from termcolor import colored
except ImportError:
    colored = None


# Thread-local storage for logging context. `_get_context_stack` lazily
# attaches a `context_stack` list to this object, so each thread works with
# its own, independent stack of context entries.
_context_storage = local()


class ActiveLearningLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that automatically includes active learning iteration context.

    On every logging call the adapter reads the current iteration, run ID and
    phase from the driver object (when one was supplied) and passes them on
    through the ``extra`` mapping of the log call.
    """

    # Pairs of (attribute read from the driver, key written into ``extra``).
    _CONTEXT_FIELDS = (
        ("active_learning_step_idx", "iteration"),
        ("run_id", "run_id"),
        ("current_phase", "phase"),
    )

    def __init__(self, logger: logging.Logger, driver_ref: Any = None):
        """Initialize the adapter with a logger and optional driver reference.

        Parameters
        ----------
        logger : logging.Logger
            The underlying logger to adapt
        driver_ref : Any, optional
            Reference to the driver object to get iteration context from
        """
        super().__init__(logger, {})
        self.driver_ref = driver_ref

    def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Process the log message to add iteration, run ID, and phase context.

        Driver attributes that are missing or ``None`` are skipped. Values
        taken from the driver override same-named keys already present in the
        caller's ``extra``.

        Parameters
        ----------
        msg : str
            The log message
        kwargs : dict[str, Any]
            Additional keyword arguments of the logging call

        Returns
        -------
        tuple[str, dict[str, Any]]
            The unchanged message and the keyword arguments, whose ``extra``
            entry is replaced by an enriched copy when there is anything to add
        """
        if self.driver_ref is None:
            return msg, kwargs

        # Work on a copy: the caller may reuse the dict it passed as ``extra``
        # and must not see context keys appear in it. ``or {}`` also covers an
        # explicit ``extra=None``.
        extra = dict(kwargs.get("extra") or {})

        for attr_name, extra_key in self._CONTEXT_FIELDS:
            value = getattr(self.driver_ref, attr_name, None)
            if value is not None:
                extra[extra_key] = value

        if extra:
            kwargs["extra"] = extra

        return msg, kwargs


class JSONFormatter(logging.Formatter):
    """JSON formatter for structured logging to files.

    Each log record becomes one JSON object holding a set of core fields
    followed by every other attribute found on the record (including values
    supplied through ``extra``).
    """

    # Context attributes emitted right after the core fields, when present.
    _CONTEXT_KEYS = ("context", "caller_object", "iteration", "phase")

    def format(self, record: logging.LogRecord) -> str:
        """Format the log record as JSON.

        Parameters
        ----------
        record : logging.LogRecord
            The log record to format

        Returns
        -------
        str
            JSON-formatted log message
        """
        log_entry = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add contextual information if available
        for key in self._CONTEXT_KEYS:
            if hasattr(record, key):
                log_entry[key] = getattr(record, key)

        # Add every remaining record attribute without overwriting the above
        for key, value in record.__dict__.items():
            if key in log_entry:
                continue
            if key == "exc_info" and isinstance(value, tuple):
                # The raw tuple holds a traceback object, which json cannot
                # encode; store the rendered traceback text instead.
                value = self.formatException(value)
            log_entry[key] = value

        # Deliberate fallback: log arguments and ``extra`` values may be
        # arbitrary objects. Emitting their string form keeps the entry
        # instead of losing it to a TypeError raised inside the handler.
        return json.dumps(log_entry, default=str)


def _get_context_stack():
    """Return the calling thread's context stack, creating it on first use."""
    try:
        return _context_storage.context_stack
    except AttributeError:
        # First access from this thread: start with an empty stack.
        _context_storage.context_stack = []
        return _context_storage.context_stack


class ContextFormatter(logging.Formatter):
    """Standard formatter that appends active learning context, with optional colors."""

    def format(self, record):
        """Format ``record`` and append a bracketed context suffix.

        The suffix lists, in order, the caller object, run ID, iteration,
        phase and any additional context key/value pairs found on the record.
        When ``termcolor`` is importable the message is colored by level and
        the suffix is colored blue.

        Parameters
        ----------
        record : logging.LogRecord
            The log record to format

        Returns
        -------
        str
            The formatted (and possibly colored) log line
        """
        parts = []

        caller = getattr(record, "caller_object", None)
        if caller:
            parts.append(f"obj:{caller}")

        run_id = getattr(record, "run_id", None)
        if run_id:
            parts.append(f"run:{run_id}")

        # Iteration 0 is meaningful, so only a missing/None value is skipped
        iteration = getattr(record, "iteration", None)
        if iteration is not None:
            parts.append(f"iter:{iteration}")

        phase = getattr(record, "phase", None)
        if phase:
            parts.append(f"phase:{phase}")

        extra_context = getattr(record, "context", None)
        if extra_context:
            parts.extend(f"{key}:{value}" for key, value in extra_context.items())

        suffix = f"[{', '.join(parts)}]" if parts else ""

        # Let the base class render the configured format string
        text = super().format(record)

        if colored is not None:
            # Pick the message color from the severity, highest first
            severity = record.levelno
            if severity >= logging.ERROR:
                color = "red"
            elif severity >= logging.WARNING:
                color = "yellow"
            elif severity >= logging.INFO:
                color = "white"
            else:  # DEBUG
                color = "cyan"
            text = colored(text, color)

            if suffix:
                suffix = colored(suffix, "blue")

        if suffix:
            text += f" {suffix}"

        return text


class ContextInjectingFilter(logging.Filter):
    """Filter that copies the active thread-local context onto log records."""

    def filter(self, record):
        """Attach the innermost context entry to ``record``; never drops it.

        Parameters
        ----------
        record : logging.LogRecord
            The record about to be handled

        Returns
        -------
        bool
            Always ``True``, so the record is emitted
        """
        stack = _get_context_stack()
        if not stack:
            return True

        # Only the most recently pushed entry is applied
        top = stack[-1]

        caller = top["caller_object"]
        if caller:
            record.caller_object = caller

        iteration = top["iteration"]
        if iteration is not None:
            record.iteration = iteration

        phase = top.get("phase")
        if phase:
            record.phase = phase

        extra_context = top["context"]
        if extra_context:
            record.context = extra_context

        return True


def setup_active_learning_logger(
    name: str,
    run_id: str,
    log_dir: str | Path = Path("active_learning_logs"),
    level: int = logging.INFO,
) -> logging.Logger:
    """Set up a logger with active learning-specific formatting and handlers.

    The logger gets a console handler and a JSON file handler writing to
    ``<log_dir>/<run_id>.log``. The file is opened in write mode, so an
    existing log file for the same ``run_id`` is overwritten. Handlers
    previously attached to the logger are removed and closed, and propagation
    to parent loggers is disabled.

    Parameters
    ----------
    name : str
        Logger name
    run_id : str
        Unique identifier for this run, used in log filename
    log_dir : str | Path, optional
        Directory to store log files, by default "active_learning_logs"
    level : int, optional
        Logging level, by default logging.INFO

    Returns
    -------
    logging.Logger
        Configured standard Python logger

    Example
    -------
    >>> logger = setup_active_learning_logger("experiment", "run_001")
    >>> logger.info("Starting experiment")
    >>> with log_context(caller_object="Trainer", iteration=5):
    ...     logger.info("Training step")
    """
    # Get standard logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Remove existing handlers to avoid duplicates, and close them as well:
    # merely clearing the list would leave the log file of an earlier call
    # open for the lifetime of the process.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Disable propagation to prevent duplicate messages from parent loggers
    logger.propagate = False

    # Create log directory if it doesn't exist
    log_dir_path = Path(log_dir)
    log_dir_path.mkdir(parents=True, exist_ok=True)

    # Set up console handler with standard formatting
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        ContextFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    console_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(console_handler)

    # Set up file handler with JSON formatting
    log_file = log_dir_path / f"{run_id}.log"
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(JSONFormatter())
    file_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(file_handler)

    return logger


@contextmanager
def log_context(
    caller_object: str | None = None,
    iteration: int | None = None,
    phase: str | None = None,
    **kwargs: Any,
):
    """Attach contextual information to log messages emitted inside the block.

    An entry is pushed onto the current thread's context stack on entry and
    popped again on exit, including when the block raises.

    Parameters
    ----------
    caller_object : str, optional
        Name or identifier of the object making the log call
    iteration : int, optional
        Current iteration counter
    phase : str, optional
        Current phase of the active learning process
    **kwargs : Any
        Additional contextual key-value pairs

    Example
    -------
    >>> from logging import getLogger
    >>> from physicsnemo.active_learning.logger import log_context
    >>> logger = getLogger("my_logger")
    >>> with log_context(caller_object="Trainer", iteration=5, phase="training", epoch=2):
    ...     logger.info("Processing batch")
    """
    stack = _get_context_stack()
    stack.append(
        dict(
            caller_object=caller_object,
            iteration=iteration,
            phase=phase,
            context=kwargs,
        )
    )

    try:
        yield
    finally:
        # Always unwind, so a failing block cannot leave stale context behind
        stack.pop()