File size: 4,200 Bytes
3193174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Context managers for callback handling.

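Provides context-local management of the active callback manager. The state is stored in a ``contextvars.ContextVar``, so each thread and each asyncio task sees its own value.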
"""

from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any
from uuid import UUID

from .base import AsyncCallbackHandler, BaseCallbackHandler
from .handlers.metrics import MetricsCallbackHandler
from .manager import AsyncCallbackManager, CallbackManager

# Type alias for handlers: either the sync or the async handler base class.
Handler = BaseCallbackHandler | AsyncCallbackHandler

# Public API of this module. ``configure_callbacks`` and
# ``configure_async_callbacks`` are public, documented helpers as well, so they
# are exported alongside the context managers (previously omitted, which made
# ``from ... import *`` silently drop them).
__all__ = [
    "collect_metrics",
    "configure_async_callbacks",
    "configure_callbacks",
    "get_callback_manager",
    "set_callback_manager",
    "trace_as_callback",
]

# Callback manager active in the current execution context. Being a ContextVar,
# its value is isolated per thread and per asyncio task; ``None`` (the default)
# means no manager is currently installed.
_current_callback_manager: ContextVar[CallbackManager | None] = ContextVar("current_callback_manager", default=None)


def get_callback_manager() -> CallbackManager | None:
    """Return the callback manager bound to the current execution context.

    Returns:
        The manager most recently stored in the module's context variable, or
        ``None`` when nothing has been set (the variable's default).
    """
    active = _current_callback_manager.get()
    return active


def set_callback_manager(manager: CallbackManager | None) -> None:
    """Bind ``manager`` as the current callback manager for this context.

    Args:
        manager: The manager to install, or ``None`` to clear the current one.

    Unlike the context managers in this module, this setter is not scoped: the
    previous value is not restored automatically.
    """
    # The reset token returned by ContextVar.set is intentionally discarded.
    _current_callback_manager.set(manager)
    return None


@contextmanager
def trace_as_callback(
    handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    parent_run_id: UUID | None = None,
) -> Generator[CallbackManager]:
    """
    Install a freshly configured callback manager for the duration of a ``with`` block.

    Args:
        handlers: Handlers forwarded to ``CallbackManager.configure``.
        tags: Tags forwarded to ``CallbackManager.configure``.
        metadata: Metadata forwarded to ``CallbackManager.configure``.
        parent_run_id: Assigned to the manager's ``parent_run_id`` attribute after
            configuration (assigned even when ``None``).

    Yields:
        The configured ``CallbackManager``. While the block runs it is also what
        ``get_callback_manager()`` returns; the previous value is restored on exit,
        including when the block raises.

    Example:
        from callbacks import trace_as_callback, StdoutCallbackHandler

        with trace_as_callback(handlers=[StdoutCallbackHandler()]) as manager:
            runner.run_round(graph)

    """
    configured = CallbackManager.configure(handlers=handlers, tags=tags, metadata=metadata)
    configured.parent_run_id = parent_run_id

    reset_token = _current_callback_manager.set(configured)
    try:
        yield configured
    finally:
        # Restore whatever manager (or None) was active before entering the block.
        _current_callback_manager.reset(reset_token)


@contextmanager
def collect_metrics() -> Generator[MetricsCallbackHandler]:
    """
    Gather metrics for everything run inside a ``with`` block.

    A new ``MetricsCallbackHandler`` is wrapped in a manager built by
    ``CallbackManager.configure`` and installed as the current callback manager
    for the block. The previous manager is restored on exit, even on error.

    Yields:
        The ``MetricsCallbackHandler`` itself (not the manager), so its results
        can still be read after the block ends.

    Example:
        from callbacks import collect_metrics

        with collect_metrics() as metrics:
            runner.run_round(graph)

        print(f"Total tokens: {metrics.total_tokens}")
        print(metrics.get_metrics())

    """
    metrics_handler = MetricsCallbackHandler()
    reset_token = _current_callback_manager.set(
        CallbackManager.configure(handlers=[metrics_handler])
    )
    try:
        yield metrics_handler
    finally:
        # Put back the manager that was active before entering the block.
        _current_callback_manager.reset(reset_token)


def configure_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> CallbackManager:
    """
    Build a callback manager via ``CallbackManager.configure``.

    Modeled on LangChain's ``CallbackManager.configure()``. Every argument is
    forwarded unchanged, by keyword, to ``CallbackManager.configure``. The result
    is returned without being installed as the current context manager.

    Example:
        manager = configure_callbacks(
            handlers=[StdoutCallbackHandler()],
            tags=["production"],
            metadata={"user_id": "123"},
        )

    """
    options = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return CallbackManager.configure(**options)


def configure_async_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> AsyncCallbackManager:
    """
    Build an async callback manager from the given handlers, tags, and metadata.

    Every argument is forwarded unchanged, by keyword, to the
    ``AsyncCallbackManager`` constructor. Note that ``configure_callbacks`` goes
    through a ``configure`` classmethod instead. The result is returned without
    being installed as the current context manager.

    Example:
        manager = configure_async_callbacks(
            handlers=[MyAsyncHandler()],
        )

    """
    options = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return AsyncCallbackManager(**options)