File size: 9,093 Bytes
557ee65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import logging
import json
import time
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Optional, Union
from . import components

# Context variables for trace and span propagation.
# They hold the active trace ID, span name, and component for the current
# execution context, so these values do not have to be passed through every
# function call. Each one defaults to None until something is bound.
_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
_span_name_var: ContextVar[Optional[str]] = ContextVar("span_name", default=None)
_component_var: ContextVar[Optional[str]] = ContextVar("component", default=None)


def set_trace_id(trace_id: str) -> None:
    """Bind ``trace_id`` as the active trace ID in the current context.

    Args:
        trace_id: Identifier stored in the ``trace_id`` context variable.
    """
    _trace_id_var.set(trace_id)


def get_trace_id() -> Optional[str]:
    """Return the trace ID bound in the current context.

    Returns:
        The bound trace ID, or ``None`` (the variable's default) when no
        trace ID has been set.
    """
    return _trace_id_var.get()


def clear_trace_id() -> None:
    """Clear the trace ID in the current context by setting it to ``None``."""
    _trace_id_var.set(None)


class StructuredFormatter(logging.Formatter):
    """
    Format logs as structured JSON records.

    Each record includes the standard fields (timestamp, level, logger,
    message) plus the trace ID, span name, component and event type.
    Optional additions:

    * ``duration_ms`` when the record carries that attribute,
    * ``exception`` (formatted traceback) when the record has ``exc_info``,
    * every key of ``record.fields`` when it is a dict (passed through
      ``extra={"fields": {...}}``); these are applied last and therefore
      override any of the keys above.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Serialize ``record`` to a single-line JSON string."""
        # A component set explicitly on the record wins over the ambient
        # context component, so callers can override it per event.
        component = (
            getattr(record, "component", None) or _component_var.get() or "unknown"
        )

        # Core fields
        payload = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "trace_id": get_trace_id(),
            "span_name": _span_name_var.get(),
            "component": component,
            "event": getattr(record, "event", "info"),
        }

        # Add duration if it was added to the record
        if hasattr(record, "duration_ms"):
            payload["duration_ms"] = record.duration_ms

        # Keep tracebacks from logger.exception(...) / exc_info=True instead
        # of silently dropping them.
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)

        # Add extra fields passed via extra={"fields": {...}}
        extra_fields = getattr(record, "fields", None)
        if isinstance(extra_fields, dict):
            payload.update(extra_fields)

        # default=str is deliberate: a formatter must never lose a record
        # because one field value (datetime, UUID, exception, ...) is not
        # natively JSON-serializable.
        return json.dumps(payload, default=str)


def setup_logging(level: int = logging.INFO):
    """
    Configure the root logger to use the StructuredFormatter.
    This should be called as early as possible in the application lifecycle.

    The function is idempotent: calling it again does not add a second
    console handler or a second file handler for the same log file, so log
    lines are not duplicated.

    Side effects: sets the root logger level, creates the ``logs`` directory
    (relative to the current working directory) if needed, attaches a
    ``FileHandler`` for ``logs/app.log``, and installs the structured
    formatter on every handler of the root logger.

    Args:
        level: Logging level applied to the root logger.
    """
    # Local import so this function does not depend on module-level import order.
    import os

    root_logger = logging.getLogger()
    root_logger.setLevel(level)

    formatter = StructuredFormatter()

    # Configure console handler (FileHandler subclasses StreamHandler, so it
    # has to be excluded explicitly).
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in root_logger.handlers
    )
    if not has_console:
        root_logger.addHandler(logging.StreamHandler())

    # Configure file handler
    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)

    # FileHandler stores its target as an absolute path in ``baseFilename``;
    # compare against that to detect a handler left by a previous call.
    log_file = os.path.abspath(os.path.join(log_dir, "app.log"))
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in root_logger.handlers
    )
    if not has_file:
        root_logger.addHandler(logging.FileHandler(log_file))

    # Apply our formatter to every handler, including pre-existing ones
    for handler in root_logger.handlers:
        handler.setFormatter(formatter)


# Initialize logging on import.
# NOTE: this is an import-time side effect. Importing this module configures
# the root logger with the default level and, through setup_logging(), creates
# the "logs" directory and attaches a file handler for logs/app.log.
setup_logging()


def get_logger(name: str) -> logging.Logger:
    """Look up (or create) the standard-library logger registered as ``name``."""
    named_logger = logging.getLogger(name)
    return named_logger


def bind_trace(trace_id: str) -> None:
    """Bind a trace ID to the current context.

    Delegates to ``set_trace_id``.

    Args:
        trace_id: Identifier to make the active trace ID.
    """
    set_trace_id(trace_id)


def bind_new_trace_id():
    """Generate a random UUID4 string, bind it as the current trace ID, and return it."""
    from uuid import uuid4

    fresh_id = str(uuid4())
    set_trace_id(fresh_id)
    return fresh_id


def ensure_trace():
    """
    Return the active trace ID, creating one first when necessary.

    When the current context has no trace ID (``None`` or empty), a new one
    is generated and bound; otherwise the existing ID is returned unchanged.
    """
    existing = get_trace_id()
    if existing:
        return existing
    # Nothing bound yet: create, bind and hand back a fresh ID.
    return bind_new_trace_id()


import os

class SpanContext:
    """
    Context manager for a span that supports structured metadata.

    On entry it records the previously active span/component, makes this span
    the active one and logs a ``start`` event. On exit it measures the
    elapsed time, attaches exception details (if any) and delegates to
    ``obs_module._pop_span`` to log the end event and restore the previous
    context. Exceptions raised inside the ``with`` body are never suppressed.
    """

    def __init__(self, obs_module, name: str, component: str):
        """
        Args:
            obs_module: Object exposing ``_pop_span(...)``; called on exit.
            name: Span name.
            component: Component the span belongs to.
        """
        self.obs = obs_module
        self.name = name
        self.component = component
        # Wall-clock start (epoch seconds); informational only.
        self.start_time = None
        # Monotonic start used for the duration, immune to system clock changes.
        self._started_at = None
        self.fields: Dict[str, Any] = {}
        self.previous_span = None
        self.previous_component = None

    def set_field(self, key: str, value: Any):
        """Add a metadata field to be included in the span end log."""
        self.fields[key] = value

    def __enter__(self):
        """Activate the span, log its start and return ``self``.

        Raises:
            RuntimeError: If a span with the same name is already active and
                the ``DEV_MODE`` environment variable is ``"true"``. Outside
                dev mode a warning event is logged instead.
        """
        # Guard for duplicates
        active_span = _span_name_var.get()
        if active_span == self.name:
            if os.getenv("DEV_MODE", "false").lower() == "true":
                raise RuntimeError(f"Duplicate nested span detected: {self.name} in {self.component}")
            log_event(
                "warning",
                f"Duplicate nested span detected: {self.name}",
                event="duplicate_span_detected",
                component=self.component,
                span_name=self.name,
            )

        # Store previous context so it can be restored on exit
        self.previous_span = active_span
        self.previous_component = _component_var.get()

        _span_name_var.set(self.name)
        _component_var.set(self.component)

        log_event("info", f"Started span: {self.name}", event="start", component=self.component)
        # Start timing after the start event so its logging cost is not counted.
        self.start_time = time.time()
        self._started_at = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log the span end (with duration and error info) and restore context."""
        if self._started_at is not None:
            duration_ms = (time.perf_counter() - self._started_at) * 1000
        else:
            # Exited without ever being entered; there is nothing to measure.
            duration_ms = None

        # Merge exception info if any. Test exc_type rather than the truthiness
        # of exc_val: an exception class may define __bool__/__len__.
        if exc_type is not None:
            self.fields["error"] = str(exc_val)
            self.fields["error_type"] = exc_type.__name__

        # Pop and log
        self.obs._pop_span(
            span_name=self.name,
            component=self.component,
            duration_ms=duration_ms,
            extra_fields=self.fields,
            previous_span=self.previous_span,
            previous_component=self.previous_component,
        )


def _classify_span(span_name: str, component: str) -> str:
    """Map a span to the category reported as ``span_type`` in analytics.

    Rules are checked in order: the span-name suffix ``.llm`` first, then a
    ``.run`` span owned by the agent component, then a direct match on the
    service, domain or orchestrator component. Anything else is "internal".
    """
    if span_name.endswith(".llm"):
        category = components.LLM
    elif span_name.endswith(".run") and component == components.AGENT:
        category = components.AGENT
    elif component == components.SERVICE:
        category = components.SERVICE
    elif component == components.DOMAIN:
        category = components.DOMAIN
    elif component == components.ORCHESTRATOR:
        category = components.ORCHESTRATOR
    else:
        category = "internal"
    return category


def _pop_span(
    span_name: str,
    component: str,
    duration_ms: Optional[float] = None,
    extra_fields: Optional[Dict[str, Any]] = None,
    previous_span: Optional[str] = None,
    previous_component: Optional[str] = None
):
    """
    Internal logic to close a span, classify it, and emit the structured log.

    Args:
        span_name: Name of the span being closed.
        component: Component the span belongs to.
        duration_ms: Elapsed time in milliseconds, if measured.
        extra_fields: Additional fields for the ``span_end`` event. They take
            precedence over the derived ``duration_ms``, ``span_type`` and
            ``feature`` values. Keys that clash with ``log_event``'s own
            parameters (``level``, ``message``, ``event``, ``component``) are
            emitted with a ``field_`` prefix instead of raising.
        previous_span: Span name to restore as active afterwards.
        previous_component: Component to restore as active afterwards.
    """
    fields = extra_fields or {}

    # Feature extraction: everything before the first dot (or the whole name)
    feature = span_name.split(".", 1)[0]

    # Classification
    span_type = _classify_span(span_name, component)

    # Build one keyword dict so a caller-supplied field such as "duration_ms"
    # overrides the derived value instead of triggering
    # "got multiple values for keyword argument".
    payload: Dict[str, Any] = {
        "duration_ms": duration_ms,
        "span_type": span_type,
        "feature": feature,
    }
    reserved = ("level", "message", "event", "component")
    for key, value in fields.items():
        payload[f"field_{key}" if key in reserved else key] = value

    log_event(
        "info",
        f"Ended span: {span_name}",
        event="span_end",
        component=component,
        **payload
    )

    # Restore previous context if provided, otherwise clear if matching
    if previous_span is not None or previous_component is not None:
        _span_name_var.set(previous_span)
        _component_var.set(previous_component)
    elif _span_name_var.get() == span_name:
        _span_name_var.set(None)
        _component_var.set(None)


def start_span(name: str, component: str) -> SpanContext:
    """
    Create a span for ``name`` owned by ``component``.

    The returned SpanContext is meant to be used in a ``with`` statement; it
    receives this module object so it can call back into it when the span ends.
    """
    import sys

    this_module = sys.modules[__name__]
    return SpanContext(this_module, name, component)


def end_span(name: str, component: str, **fields):
    """
    Manually close the span ``name`` and emit its end event.

    Not needed when the span was opened via ``with start_span(...)``, which
    closes itself. Any keyword ``fields`` are attached to the end event.
    """
    # No duration or previous context is known here; _pop_span handles that case.
    _pop_span(name, component, extra_fields=fields)


def log_event(
    level: Union[str, int], message: str, event: str = "info", component: Optional[str] = None, **fields
):
    """
    Log a structured event with optional extra fields.

    The record is emitted on the fixed ``"observability"`` logger.

    Args:
        level: Log level, either a name (e.g., "info", "error", "debug";
            case-insensitive) or a numeric ``logging`` level. Unknown names
            fall back to INFO.
        message: The log message
        event: The type of event (e.g., "start", "span_end", "info")
        component: Overwrite the current component
        **fields: Additional key-value pairs to include in the log
    """
    logger = logging.getLogger("observability")

    if isinstance(level, int) and not isinstance(level, bool):
        lvl = level
    else:
        # Only accept attributes that really are numeric levels: the logging
        # module also exposes upper-case non-level names (e.g. BASIC_FORMAT)
        # that would make logger.log() raise.
        candidate = getattr(logging, str(level).upper(), None)
        is_level = isinstance(candidate, int) and not isinstance(candidate, bool)
        lvl = candidate if is_level else logging.INFO

    extra = {"event": event, "fields": fields}
    if component:
        extra["component"] = component

    logger.log(lvl, message, extra=extra)