# gMAS / src/callbacks/context.py
# Артём Боярских
# chore: initial commit
# 3193174
"""
Context managers for callback handling.
Provides thread-safe context management for callbacks.
"""
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any
from uuid import UUID
from .base import AsyncCallbackHandler, BaseCallbackHandler
from .handlers.metrics import MetricsCallbackHandler
from .manager import AsyncCallbackManager, CallbackManager
# Type alias for any handler accepted by the callback managers (sync or async).
Handler = BaseCallbackHandler | AsyncCallbackHandler

# Public API of this module. ``configure_callbacks`` and
# ``configure_async_callbacks`` are documented public helpers, so they are
# exported alongside the context managers.
__all__ = [
    "collect_metrics",
    "configure_async_callbacks",
    "configure_callbacks",
    "get_callback_manager",
    "set_callback_manager",
    "trace_as_callback",
]

# Callback manager active in the current execution context; ``None`` when no
# manager is active. ContextVar values are per-context, so separate threads and
# separate asyncio tasks each see their own value.
_current_callback_manager: ContextVar[CallbackManager | None] = ContextVar(
    "current_callback_manager", default=None
)
def get_callback_manager() -> CallbackManager | None:
    """
    Return the callback manager active in the current context.

    Returns ``None`` if no manager is active (the context variable's default,
    or an explicit ``set_callback_manager(None)``).
    """
    active = _current_callback_manager.get()
    return active
def set_callback_manager(manager: CallbackManager | None) -> None:
    """
    Install ``manager`` as the active callback manager for the current context.

    Passing ``None`` clears the active manager. The change is not undone
    automatically; use ``trace_as_callback`` for a scoped override that is
    restored on exit.
    """
    context_var = _current_callback_manager
    # The reset token returned by ``set`` is intentionally discarded here.
    context_var.set(manager)
@contextmanager
def trace_as_callback(
    handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    parent_run_id: UUID | None = None,
) -> Generator[CallbackManager]:
    """
    Activate a freshly configured callback manager for the duration of a block.

    A manager is built with ``CallbackManager.configure`` from ``handlers``,
    ``tags`` and ``metadata``. Its ``parent_run_id`` is then set, and it is
    installed as the current context's manager. The previously active manager
    is restored on exit, even if the body raises.

    Example:
        from callbacks import trace_as_callback, StdoutCallbackHandler
        with trace_as_callback(handlers=[StdoutCallbackHandler()]) as manager:
            runner.run_round(graph)

    Yields:
        The newly configured ``CallbackManager``.
    """
    scoped = CallbackManager.configure(handlers=handlers, tags=tags, metadata=metadata)
    # Always assigned, even when ``parent_run_id`` is None.
    scoped.parent_run_id = parent_run_id
    restore_token = _current_callback_manager.set(scoped)
    try:
        yield scoped
    finally:
        # Put back whatever manager (or None) was active before entry.
        _current_callback_manager.reset(restore_token)
@contextmanager
def collect_metrics() -> Generator[MetricsCallbackHandler]:
    """
    Collect run metrics for the duration of a block.

    A new ``MetricsCallbackHandler`` is wrapped in a manager built by
    ``CallbackManager.configure(handlers=[handler])``. That manager is
    installed as the current context's manager. The prior manager is
    restored on exit, even if the body raises.

    Example:
        from callbacks import collect_metrics
        with collect_metrics() as metrics:
            runner.run_round(graph)
        print(f"Total tokens: {metrics.total_tokens}")
        print(metrics.get_metrics())

    Yields:
        The ``MetricsCallbackHandler`` that accumulates the metrics.
    """
    metrics_handler = MetricsCallbackHandler()
    # NOTE(review): this installs a new manager configured with only the
    # metrics handler. Whether handlers from an enclosing manager keep
    # receiving events depends on CallbackManager.configure — confirm.
    restore_token = _current_callback_manager.set(
        CallbackManager.configure(handlers=[metrics_handler])
    )
    try:
        yield metrics_handler
    finally:
        _current_callback_manager.reset(restore_token)
def configure_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> CallbackManager:
    """
    Build a callback manager via ``CallbackManager.configure``.

    Every argument is forwarded unchanged by keyword; this mirrors
    LangChain's CallbackManager.configure(). The manager is only returned,
    not installed in the current context.

    Example:
        manager = configure_callbacks(
            handlers=[StdoutCallbackHandler()],
            tags=["production"],
            metadata={"user_id": "123"},
        )
    """
    options: dict[str, Any] = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return CallbackManager.configure(**options)
def configure_async_callbacks(
    handlers: list[Handler] | None = None,
    inheritable_handlers: list[Handler] | None = None,
    tags: list[str] | None = None,
    inheritable_tags: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    inheritable_metadata: dict[str, Any] | None = None,
) -> AsyncCallbackManager:
    """
    Build an async callback manager from the given handlers, tags and metadata.

    Every argument is forwarded unchanged by keyword to the
    ``AsyncCallbackManager`` constructor. The manager is only returned, not
    installed in the current context.

    Example:
        manager = configure_async_callbacks(
            handlers=[MyAsyncHandler()],
        )
    """
    # NOTE(review): unlike configure_callbacks, this calls the constructor
    # directly rather than a ``configure`` classmethod. Confirm the two paths
    # are meant to differ.
    options: dict[str, Any] = {
        "handlers": handlers,
        "inheritable_handlers": inheritable_handlers,
        "tags": tags,
        "inheritable_tags": inheritable_tags,
        "metadata": metadata,
        "inheritable_metadata": inheritable_metadata,
    }
    return AsyncCallbackManager(**options)