Spaces:
Running
Running
File size: 4,200 Bytes
3193174 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | """
Context managers and factory helpers for callback handling.
The active callback manager is stored in a ContextVar, so it is isolated per thread and per asyncio task.
"""
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any
from uuid import UUID
from .base import AsyncCallbackHandler, BaseCallbackHandler
from .handlers.metrics import MetricsCallbackHandler
from .manager import AsyncCallbackManager, CallbackManager
# Type alias for any callback handler accepted by this module (sync or async).
Handler = BaseCallbackHandler | AsyncCallbackHandler

# Public API of this module (kept sorted). The configure_* factories are
# public, documented helpers, so they are exported alongside the context
# managers; previously they were missing and dropped by `from ... import *`.
__all__ = [
    "collect_metrics",
    "configure_async_callbacks",
    "configure_callbacks",
    "get_callback_manager",
    "set_callback_manager",
    "trace_as_callback",
]

# Context-local slot holding the active CallbackManager (None when none is
# active). Using a ContextVar means a value set in one thread or asyncio task
# does not leak into others.
_current_callback_manager: ContextVar[CallbackManager | None] = ContextVar(
    "current_callback_manager", default=None
)
def get_callback_manager() -> CallbackManager | None:
    """Return the callback manager active in the current context.

    Returns:
        The manager last installed in this context (via ``set_callback_manager``
        or one of this module's context managers), or ``None`` if none is active.
    """
    active = _current_callback_manager.get()
    return active
def set_callback_manager(manager: CallbackManager | None) -> None:
    """Install ``manager`` as the active callback manager for the current context.

    Passing ``None`` clears the active manager. The reset token produced by
    ``ContextVar.set`` is not returned, so the previous value cannot be restored
    through this function; use ``trace_as_callback`` for scoped activation.
    """
    _ = _current_callback_manager.set(manager)
@contextmanager
def trace_as_callback(
    handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    parent_run_id: UUID | None = None,
) -> Generator[CallbackManager]:
    """
    Activate a freshly configured CallbackManager for the duration of a ``with`` block.

    The manager is built by ``CallbackManager.configure`` from ``handlers``,
    ``tags`` and ``metadata``; its ``parent_run_id`` attribute is then set to
    ``parent_run_id`` (which may be ``None``). Inside the block it is the value
    returned by ``get_callback_manager()``; on exit, normal or via an exception,
    the previously active manager is restored.

    Example:
        from callbacks import trace_as_callback, StdoutCallbackHandler
        with trace_as_callback(handlers=[StdoutCallbackHandler()]) as manager:
            runner.run_round(graph)

    Yields:
        The configured CallbackManager.
    """
    traced = CallbackManager.configure(
        handlers=handlers,
        tags=tags,
        metadata=metadata,
    )
    traced.parent_run_id = parent_run_id
    restore_token = _current_callback_manager.set(traced)
    try:
        yield traced
    finally:
        # Undo our set() so the outer context's manager (or None) is active again.
        _current_callback_manager.reset(restore_token)
@contextmanager
def collect_metrics() -> Generator[MetricsCallbackHandler]:
    """
    Collect metrics for the code executed inside a ``with`` block.

    A new MetricsCallbackHandler is wrapped in a CallbackManager (built with
    ``CallbackManager.configure``), and that manager becomes the active one for
    the duration of the block. It replaces, rather than extends, any manager
    already active in the context; the previous one is restored on exit.

    Example:
        from callbacks import collect_metrics
        with collect_metrics() as metrics:
            runner.run_round(graph)
        print(f"Total tokens: {metrics.total_tokens}")
        print(metrics.get_metrics())

    Yields:
        The MetricsCallbackHandler registered on the active manager.
    """
    metrics_handler = MetricsCallbackHandler()
    token = _current_callback_manager.set(
        CallbackManager.configure(handlers=[metrics_handler])
    )
    try:
        yield metrics_handler
    finally:
        # Restore whatever manager (or None) was active before the block.
        _current_callback_manager.reset(token)
def configure_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> CallbackManager:
    """
    Build a CallbackManager from the given handlers, tags and metadata.

    Thin wrapper that forwards every argument unchanged to
    ``CallbackManager.configure`` (modelled on LangChain's
    ``CallbackManager.configure()``). Unlike ``trace_as_callback``, it does not
    install the resulting manager in the current context.

    Example:
        manager = configure_callbacks(
            handlers=[StdoutCallbackHandler()],
            tags=["production"],
            metadata={"user_id": "123"},
        )

    Returns:
        The manager produced by ``CallbackManager.configure``.
    """
    options: dict[str, Any] = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return CallbackManager.configure(**options)
def configure_async_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> AsyncCallbackManager:
    """
    Build an AsyncCallbackManager from the given handlers, tags and metadata.

    Every argument is passed unchanged to the ``AsyncCallbackManager``
    constructor (note: the constructor itself, not a ``configure`` classmethod
    as in ``configure_callbacks``). The manager is not installed in the
    current context.

    Example:
        manager = configure_async_callbacks(
            handlers=[MyAsyncHandler()],
        )

    Returns:
        The newly constructed AsyncCallbackManager.
    """
    options: dict[str, Any] = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return AsyncCallbackManager(**options)
|