Spaces:
Running
Running
| """ | |
| Context managers for callback handling. | |
| Provides thread-safe context management for callbacks. | |
| """ | |
| from collections.abc import Generator | |
| from contextlib import contextmanager | |
| from contextvars import ContextVar | |
| from typing import Any | |
| from uuid import UUID | |
| from .base import AsyncCallbackHandler, BaseCallbackHandler | |
| from .handlers.metrics import MetricsCallbackHandler | |
| from .manager import AsyncCallbackManager, CallbackManager | |
# Type alias for handlers: either a sync or an async callback handler.
Handler = BaseCallbackHandler | AsyncCallbackHandler

# Public API of this module. Every public helper defined below is listed so
# that ``from <module> import *`` exports the configure_* factories as well.
__all__ = [
    "collect_metrics",
    "configure_async_callbacks",
    "configure_callbacks",
    "get_callback_manager",
    "set_callback_manager",
    "trace_as_callback",
]

# Context variable for current callback manager (None when nothing is set).
_current_callback_manager: ContextVar[CallbackManager | None] = ContextVar(
    "current_callback_manager",
    default=None,
)
def get_callback_manager() -> CallbackManager | None:
    """
    Return the callback manager stored in the current context.

    Returns:
        The manager held by the module's context variable, or ``None`` when
        no manager has been set (the variable's default).
    """
    active_manager = _current_callback_manager.get()
    return active_manager
def set_callback_manager(manager: CallbackManager | None) -> None:
    """
    Store ``manager`` as the callback manager for the current context.

    Args:
        manager: The manager to install, or ``None`` to clear it.

    The reset token returned by the context variable is intentionally
    dropped, so this call is not paired with an automatic restore.
    """
    _ = _current_callback_manager.set(manager)
    return None
@contextmanager
def trace_as_callback(
    handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    parent_run_id: UUID | None = None,
) -> Generator[CallbackManager, None, None]:
    """
    Context manager for tracing with callbacks.

    Builds a manager with ``CallbackManager.configure``, installs it as the
    current callback manager for the duration of the ``with`` block, and
    resets the context variable to its previous value on exit, including
    when the block raises.

    Args:
        handlers: Handlers passed to ``CallbackManager.configure``.
        tags: Tags passed to ``CallbackManager.configure``.
        metadata: Metadata passed to ``CallbackManager.configure``.
        parent_run_id: Assigned to the new manager's ``parent_run_id``.

    Yields:
        The newly configured ``CallbackManager``.

    Example:
        from callbacks import trace_as_callback, StdoutCallbackHandler

        with trace_as_callback(handlers=[StdoutCallbackHandler()]) as manager:
            runner.run_round(graph)
    """
    manager = CallbackManager.configure(
        handlers=handlers,
        tags=tags,
        metadata=metadata,
    )
    manager.parent_run_id = parent_run_id

    # Keep the token so the previous manager is restored on exit.
    token = _current_callback_manager.set(manager)
    try:
        yield manager
    finally:
        _current_callback_manager.reset(token)
@contextmanager
def collect_metrics() -> Generator[MetricsCallbackHandler, None, None]:
    """
    Context manager for collecting metrics.

    Creates a fresh ``MetricsCallbackHandler``, wraps it in a manager built
    with ``CallbackManager.configure``, and installs that manager as the
    current callback manager for the duration of the ``with`` block. The
    context variable is reset to its previous value on exit, including when
    the block raises.

    Yields:
        The ``MetricsCallbackHandler`` attached to the installed manager.

    Example:
        from callbacks import collect_metrics

        with collect_metrics() as metrics:
            runner.run_round(graph)
            print(f"Total tokens: {metrics.total_tokens}")
            print(metrics.get_metrics())
    """
    handler = MetricsCallbackHandler()
    manager = CallbackManager.configure(handlers=[handler])

    # Keep the token so the previous manager is restored on exit.
    token = _current_callback_manager.set(manager)
    try:
        yield handler
    finally:
        _current_callback_manager.reset(token)
def configure_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> CallbackManager:
    """
    Build a callback manager from the given settings.

    Every argument is forwarded unchanged, by keyword, to
    ``CallbackManager.configure`` (modelled on LangChain's
    CallbackManager.configure()).

    Returns:
        Whatever ``CallbackManager.configure`` returns for these settings.

    Example:
        manager = configure_callbacks(
            handlers=[StdoutCallbackHandler()],
            tags=["production"],
            metadata={"user_id": "123"},
        )
    """
    # Gather the settings first, then hand them over in one keyword call.
    settings = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return CallbackManager.configure(**settings)
def configure_async_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> AsyncCallbackManager:
    """
    Build an async callback manager from the given settings.

    Every argument is forwarded unchanged, by keyword, to the
    ``AsyncCallbackManager`` constructor. Unlike ``configure_callbacks``,
    this calls the constructor directly rather than a ``configure``
    classmethod.

    Returns:
        The newly constructed ``AsyncCallbackManager``.

    Example:
        manager = configure_async_callbacks(
            handlers=[MyAsyncHandler()],
        )
    """
    # Gather the settings first, then hand them over in one keyword call.
    settings = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return AsyncCallbackManager(**settings)