from dataclasses import dataclass, field from typing import Any from uuid import UUID, uuid4 import pytest from callbacks import ( AsyncCallbackHandler, AsyncCallbackManager, BaseCallbackHandler, CallbackManager, get_callback_manager, set_callback_manager, trace_as_callback, ) from callbacks.context import ( collect_metrics, configure_async_callbacks, configure_callbacks, ) @dataclass class Call: name: str kwargs: dict[str, Any] = field(default_factory=dict) class RecordingHandler(BaseCallbackHandler): def __init__(self) -> None: self.calls: list[Call] = [] def on_run_start( self, *, run_id: UUID, query: str, num_agents: int = 0, execution_order: list[str] | None = None, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.calls.append( Call( "on_run_start", { "run_id": run_id, "query": query, "num_agents": num_agents, "execution_order": execution_order, "parent_run_id": parent_run_id, "tags": tags, "metadata": metadata, **kwargs, }, ) ) def on_retry( self, *, run_id: UUID, agent_id: str, attempt: int, max_attempts: int = 0, delay_ms: float = 0.0, error: str = "", parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: self.calls.append( Call( "on_retry", { "run_id": run_id, "agent_id": agent_id, "attempt": attempt, "max_attempts": max_attempts, "delay_ms": delay_ms, "error": error, "parent_run_id": parent_run_id, **kwargs, }, ) ) def on_llm_new_token( self, token: str, *, run_id: UUID, agent_id: str, token_index: int = 0, is_first: bool = False, is_last: bool = False, parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: self.calls.append( Call( "on_llm_new_token", { "token": token, "run_id": run_id, "agent_id": agent_id, "token_index": token_index, "is_first": is_first, "is_last": is_last, "parent_run_id": parent_run_id, **kwargs, }, ) ) def on_memory_read( self, *, run_id: UUID, agent_id: str, entries_count: int = 0, keys: list[str] | None = None, parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: self.calls.append( Call( "on_memory_read", { "run_id": run_id, "agent_id": agent_id, "entries_count": entries_count, "keys": keys, "parent_run_id": parent_run_id, **kwargs, }, ) ) def on_budget_warning( self, *, run_id: UUID, budget_type: str, current: float, limit: float, ratio: float = 0.0, parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: self.calls.append( Call( "on_budget_warning", { "run_id": run_id, "budget_type": budget_type, "current": current, "limit": limit, "ratio": ratio, "parent_run_id": parent_run_id, **kwargs, }, ) ) class RaisingHandler(BaseCallbackHandler): def __init__(self, *, raise_error: bool) -> None: self.raise_error = raise_error def on_run_start(self, **_: Any) -> None: msg = "boom" raise RuntimeError(msg) class AsyncRecordingHandler(AsyncCallbackHandler): def __init__(self) -> None: self.calls: list[Call] = [] async def on_retry( self, *, run_id: UUID, agent_id: str, attempt: int, max_attempts: int = 0, delay_ms: float = 0.0, error: str = "", parent_run_id: UUID | None = None, **kwargs: Any, ) -> None: self.calls.append( Call( "on_retry", { "run_id": run_id, "agent_id": agent_id, "attempt": attempt, "max_attempts": max_attempts, "delay_ms": delay_ms, "error": error, "parent_run_id": parent_run_id, **kwargs, }, ) ) def test_trace_as_callback_sets_and_resets_context_manager() -> None: set_callback_manager(None) assert get_callback_manager() is None with trace_as_callback(tags=["outer"], metadata={"x": 1}) as outer: assert get_callback_manager() is outer assert outer.tags == ["outer"] assert outer.metadata == {"x": 1} with trace_as_callback(tags=["inner"]) as inner: assert get_callback_manager() is inner assert inner.tags == ["inner"] assert get_callback_manager() is outer assert get_callback_manager() is None def test_callback_manager_on_run_start_passes_parent_tags_metadata() -> None: h = RecordingHandler() parent = uuid4() m = CallbackManager.configure(handlers=[h], tags=["t1"], metadata={"k": "v"}) m.parent_run_id = parent run_id = m.on_run_start(query="hi", num_agents=2) assert len(h.calls) == 1 call = h.calls[0] assert call.name == "on_run_start" assert call.kwargs["run_id"] == run_id assert call.kwargs["query"] == "hi" assert call.kwargs["num_agents"] == 2 assert call.kwargs["execution_order"] == [] assert call.kwargs["parent_run_id"] == parent assert call.kwargs["tags"] == ["t1"] assert call.kwargs["metadata"] == {"k": "v"} def test_callback_manager_ignore_flags_filter_events() -> None: run_id = uuid4() h_ok = RecordingHandler() h_ignore_retry = RecordingHandler() h_ignore_retry.ignore_retry = True h_ignore_llm = RecordingHandler() h_ignore_llm.ignore_llm = True h_ignore_memory = RecordingHandler() h_ignore_memory.ignore_memory = True h_ignore_budget = RecordingHandler() h_ignore_budget.ignore_budget = True m = CallbackManager(handlers=[h_ok, h_ignore_retry, h_ignore_llm, h_ignore_memory, h_ignore_budget]) m.on_retry(run_id, agent_id="a1", attempt=1, max_attempts=3, delay_ms=10.0, error="e") m.on_llm_new_token(run_id, "tok", agent_id="a1", token_index=0, is_first=True, is_last=False) m.on_memory_read(run_id, agent_id="a1", entries_count=2, keys=["k1", "k2"]) m.on_budget_warning(run_id, budget_type="tokens", current=90.0, limit=100.0, ratio=0.9) assert [c.name for c in h_ok.calls] == [ "on_retry", "on_llm_new_token", "on_memory_read", "on_budget_warning", ] assert [c.name for c in h_ignore_retry.calls] == [ "on_llm_new_token", "on_memory_read", "on_budget_warning", ] assert [c.name for c in h_ignore_llm.calls] == [ "on_retry", "on_memory_read", "on_budget_warning", ] assert [c.name for c in h_ignore_memory.calls] == [ "on_retry", "on_llm_new_token", "on_budget_warning", ] assert [c.name for c in h_ignore_budget.calls] == [ "on_retry", "on_llm_new_token", "on_memory_read", ] def test_callback_manager_error_in_one_handler_does_not_block_others_when_raise_error_false() -> None: bad = RaisingHandler(raise_error=False) good = RecordingHandler() m = CallbackManager(handlers=[bad, good]) run_id = m.on_run_start(query="q") assert len(good.calls) == 1 assert good.calls[0].name == "on_run_start" assert good.calls[0].kwargs["run_id"] == run_id def test_callback_manager_error_propagates_when_raise_error_true() -> None: bad = RaisingHandler(raise_error=True) good = RecordingHandler() m = CallbackManager(handlers=[bad, good]) with pytest.raises(RuntimeError, match="boom"): m.on_run_start(query="q") assert good.calls == [] def test_callback_manager_get_child_inherits_only_inheritable_settings() -> None: inheritable = RecordingHandler() non_inheritable = RecordingHandler() m = CallbackManager() m.add_handler(inheritable, inherit=True) m.add_handler(non_inheritable, inherit=False) m.add_tags(["tag_inherit"], inherit=True) m.add_tags(["tag_local"], inherit=False) m.add_metadata({"a": 1}, inherit=True) m.add_metadata({"b": 2}, inherit=False) parent_run_id = uuid4() child = m.get_child(parent_run_id) assert child.parent_run_id == parent_run_id assert child.handlers == [inheritable] assert child.tags == ["tag_inherit"] assert child.metadata == {"a": 1} run_id = child.on_run_start(query="x") assert inheritable.calls[-1].kwargs["parent_run_id"] == parent_run_id assert inheritable.calls[-1].kwargs["run_id"] == run_id @pytest.mark.asyncio async def test_async_callback_manager_runs_sync_and_async_handlers() -> None: sync_h = RecordingHandler() async_h = AsyncRecordingHandler() m = AsyncCallbackManager(handlers=[sync_h, async_h]) run_id = uuid4() await m.on_retry(run_id, agent_id="a1", attempt=2, max_attempts=5, delay_ms=1.0, error="x") assert [c.name for c in sync_h.calls] == ["on_retry"] assert [c.name for c in async_h.calls] == ["on_retry"] # ═══════════════════════════════════════════════════════════════ # AsyncCallbackHandler — direct method calls (coverage) # ═══════════════════════════════════════════════════════════════ @pytest.mark.asyncio async def test_async_callback_handler_all_pass_methods() -> None: """Call all async pass methods on AsyncCallbackHandler directly.""" h = AsyncCallbackHandler() run_id = uuid4() agent_id = "agent_a" assert h.is_async is True # Run lifecycle await h.on_run_start(run_id=run_id, query="q", num_agents=1) await h.on_run_end(run_id=run_id, output="out", success=True) # Agent lifecycle await h.on_agent_start(run_id=run_id, agent_id=agent_id) await h.on_agent_end(run_id=run_id, agent_id=agent_id, output="resp") await h.on_agent_error(ValueError("err"), run_id=run_id, agent_id=agent_id) # Retry await h.on_retry(run_id=run_id, agent_id=agent_id, attempt=1) # Streaming await h.on_llm_new_token("tok", run_id=run_id, agent_id=agent_id) # Planning await h.on_plan_created(run_id=run_id, num_steps=2, execution_order=["a", "b"]) await h.on_topology_changed( run_id=run_id, reason="test", old_remaining=["a"], new_remaining=["b"], ) # Pruning/Fallback await h.on_prune(run_id=run_id, agent_id=agent_id, reason="pruned") await h.on_fallback( run_id=run_id, failed_agent_id="a", fallback_agent_id="b", ) # Parallel await h.on_parallel_start(run_id=run_id, agent_ids=["a", "b"]) await h.on_parallel_end(run_id=run_id, agent_ids=["a", "b"]) # Memory await h.on_memory_read(run_id=run_id, agent_id=agent_id) await h.on_memory_write(run_id=run_id, agent_id=agent_id, key="key") # Budget await h.on_budget_warning(run_id=run_id, budget_type="tokens", current=90.0, limit=100.0) await h.on_budget_exceeded(run_id=run_id, budget_type="tokens", current=105.0, limit=100.0) # Tools await h.on_tool_start(run_id=run_id, tool_name="test_tool") await h.on_tool_end(run_id=run_id, tool_name="test_tool") await h.on_tool_error(run_id=run_id, tool_name="test_tool") def test_base_callback_handler_is_async_false() -> None: """BaseCallbackHandler.is_async should be False.""" from callbacks.base import CallbackHandlerMixin handler = BaseCallbackHandler() assert handler.is_async is False # Also test mixin directly mixin = CallbackHandlerMixin() assert mixin.is_async is False # ═══════════════════════════════════════════════════════════════ # callbacks/context.py — collect_metrics, configure_callbacks # ═══════════════════════════════════════════════════════════════ def test_collect_metrics_context_manager() -> None: """collect_metrics should provide a MetricsCallbackHandler.""" from callbacks.handlers.metrics import MetricsCallbackHandler with collect_metrics() as metrics: assert isinstance(metrics, MetricsCallbackHandler) # The context should have the callback manager set cm = get_callback_manager() assert cm is not None def test_collect_metrics_resets_after_exit() -> None: """collect_metrics should reset callback manager after exit.""" set_callback_manager(None) with collect_metrics(): assert get_callback_manager() is not None assert get_callback_manager() is None def test_configure_callbacks_returns_manager() -> None: """configure_callbacks should return a CallbackManager.""" manager = configure_callbacks(tags=["test"]) assert isinstance(manager, CallbackManager) assert "test" in manager.tags def test_configure_callbacks_with_handlers() -> None: """configure_callbacks with handlers.""" h = BaseCallbackHandler() manager = configure_callbacks(handlers=[h]) assert h in manager.handlers def test_configure_async_callbacks_returns_manager() -> None: """configure_async_callbacks should return an AsyncCallbackManager.""" manager = configure_async_callbacks(tags=["async-test"]) assert isinstance(manager, AsyncCallbackManager) assert "async-test" in manager.tags