"""Tests for src/callbacks/manager.py""" import pytest from uuid import uuid4, UUID from typing import Any from unittest.mock import MagicMock, patch from callbacks.base import BaseCallbackHandler from callbacks.manager import AsyncCallbackManager, CallbackManager # ─────────────────────────── Concrete handler for testing ────────────────────── class RecordingHandler(BaseCallbackHandler): """Records all calls for assertion.""" def __init__(self, raise_error: bool = False): self.raise_error = raise_error self.calls: list[tuple[str, dict]] = [] def _record(self, method: str, **kwargs): self.calls.append((method, kwargs)) def on_run_start(self, *, run_id, query, **kwargs): self._record("on_run_start", run_id=run_id, query=query) def on_run_end(self, *, run_id, output, **kwargs): self._record("on_run_end", run_id=run_id, output=output) def on_agent_start(self, *, run_id, agent_id, **kwargs): self._record("on_agent_start", run_id=run_id, agent_id=agent_id) def on_agent_end(self, *, run_id, agent_id, output, **kwargs): self._record("on_agent_end", run_id=run_id, agent_id=agent_id) def on_agent_error(self, error, *, run_id, agent_id, **kwargs): self._record("on_agent_error", run_id=run_id, agent_id=agent_id) def on_retry(self, *, run_id, agent_id, attempt, **kwargs): self._record("on_retry", run_id=run_id, agent_id=agent_id) def on_llm_new_token(self, token, *, run_id, agent_id, **kwargs): self._record("on_llm_new_token", run_id=run_id, agent_id=agent_id, token=token) def on_plan_created(self, *, run_id, num_steps, execution_order, **kwargs): self._record("on_plan_created", run_id=run_id, num_steps=num_steps) def on_topology_changed(self, *, run_id, reason, **kwargs): self._record("on_topology_changed", run_id=run_id, reason=reason) def on_prune(self, *, run_id, agent_id, reason, **kwargs): self._record("on_prune", run_id=run_id, agent_id=agent_id) def on_fallback(self, *, run_id, failed_agent_id, fallback_agent_id, **kwargs): self._record("on_fallback", run_id=run_id, failed_agent_id=failed_agent_id) def on_parallel_start(self, *, run_id, agent_ids, **kwargs): self._record("on_parallel_start", run_id=run_id, agent_ids=agent_ids) def on_parallel_end(self, *, run_id, agent_ids, **kwargs): self._record("on_parallel_end", run_id=run_id, agent_ids=agent_ids) def on_tool_start(self, *, run_id, tool_name, **kwargs): self._record("on_tool_start", run_id=run_id, tool_name=tool_name) def on_tool_end(self, *, run_id, tool_name, **kwargs): self._record("on_tool_end", run_id=run_id, tool_name=tool_name) def on_memory_read(self, *, run_id, agent_id, **kwargs): self._record("on_memory_read", run_id=run_id) def on_memory_write(self, *, run_id, agent_id, key, **kwargs): self._record("on_memory_write", run_id=run_id, key=key) def on_budget_warning(self, *, run_id, budget_type, **kwargs): self._record("on_budget_warning", run_id=run_id, budget_type=budget_type) def on_budget_exceeded(self, *, run_id, budget_type, **kwargs): self._record("on_budget_exceeded", run_id=run_id, budget_type=budget_type) def on_tool_error(self, *, run_id, tool_name, **kwargs): self._record("on_tool_error", run_id=run_id, tool_name=tool_name) class ErrorHandler(RecordingHandler): """Handler that raises on all calls.""" raise_error = True def on_run_start(self, *, run_id, query, **kwargs): raise RuntimeError("intentional error") # ═══════════════════════════════════════════════════════════════ # CallbackManager initialization & configuration # ═══════════════════════════════════════════════════════════════ class TestCallbackManagerInit: def test_empty_init(self): cm = CallbackManager() assert cm.handlers == [] assert cm.tags == [] assert cm.metadata == {} assert cm.is_async is False def test_with_handlers(self): h = RecordingHandler() cm = CallbackManager(handlers=[h]) assert h in cm.handlers def test_with_tags_and_metadata(self): cm = CallbackManager(tags=["t1", "t2"], metadata={"key": "value"}) assert "t1" in cm.tags assert cm.metadata["key"] == "value" def test_configure_classmethod(self): h = RecordingHandler() cm = CallbackManager.configure( handlers=[h], tags=["tag"], metadata={"k": "v"}, ) assert h in cm.handlers assert "tag" in cm.tags def test_copy(self): h = RecordingHandler() cm = CallbackManager(handlers=[h], tags=["t1"]) copy = cm.copy() assert h in copy.handlers assert copy is not cm def test_merge(self): h1 = RecordingHandler() h2 = RecordingHandler() cm1 = CallbackManager(handlers=[h1], tags=["t1"]) cm2 = CallbackManager(handlers=[h2], tags=["t2"]) merged = cm1.merge(cm2) assert h1 in merged.handlers assert h2 in merged.handlers assert "t1" in merged.tags assert "t2" in merged.tags def test_add_handler(self): cm = CallbackManager() h = RecordingHandler() cm.add_handler(h) assert h in cm.handlers def test_add_handler_inheritable(self): cm = CallbackManager() h = RecordingHandler() cm.add_handler(h, inherit=True) assert h in cm.handlers assert h in cm.inheritable_handlers def test_remove_handler(self): h = RecordingHandler() cm = CallbackManager(handlers=[h]) cm.remove_handler(h) assert h not in cm.handlers def test_set_handlers(self): h1 = RecordingHandler() h2 = RecordingHandler() cm = CallbackManager(handlers=[h1]) cm.set_handlers([h2]) assert cm.handlers == [h2] def test_add_tags(self): cm = CallbackManager() cm.add_tags(["t1", "t2"]) assert "t1" in cm.tags def test_add_tags_inheritable(self): cm = CallbackManager() cm.add_tags(["t1"], inherit=True) assert "t1" in cm.inheritable_tags def test_remove_tags(self): cm = CallbackManager(tags=["t1", "t2"]) cm.remove_tags(["t1"]) assert "t1" not in cm.tags assert "t2" in cm.tags def test_add_metadata(self): cm = CallbackManager() cm.add_metadata({"key": "val"}) assert cm.metadata["key"] == "val" def test_add_metadata_inheritable(self): cm = CallbackManager() cm.add_metadata({"key": "val"}, inherit=True) assert cm.inheritable_metadata["key"] == "val" def test_get_child(self): h = RecordingHandler() cm = CallbackManager(inheritable_handlers=[h], inheritable_tags=["t"]) parent_run_id = uuid4() child = cm.get_child(parent_run_id) assert h in child.handlers assert "t" in child.tags assert child.parent_run_id == parent_run_id # ═══════════════════════════════════════════════════════════════ # CallbackManager event dispatching # ═══════════════════════════════════════════════════════════════ class TestCallbackManagerEvents: def setup_method(self): self.handler = RecordingHandler() self.cm = CallbackManager(handlers=[self.handler]) self.run_id = uuid4() def test_on_run_start_returns_run_id(self): rid = self.cm.on_run_start(query="test", num_agents=3) assert isinstance(rid, UUID) def test_on_run_start_with_provided_id(self): rid = self.cm.on_run_start(run_id=self.run_id, query="test") assert rid == self.run_id def test_on_run_start_dispatched(self): self.cm.on_run_start(query="hello") assert any(m == "on_run_start" for m, _ in self.handler.calls) def test_on_run_end_dispatched(self): self.cm.on_run_end(self.run_id, output="done", success=True) assert any(m == "on_run_end" for m, _ in self.handler.calls) def test_on_agent_start_dispatched(self): self.cm.on_agent_start(self.run_id, agent_id="solver") assert any(m == "on_agent_start" for m, _ in self.handler.calls) def test_on_agent_end_dispatched(self): self.cm.on_agent_end(self.run_id, agent_id="solver", output="result") assert any(m == "on_agent_end" for m, _ in self.handler.calls) def test_on_agent_error_dispatched(self): self.cm.on_agent_error(self.run_id, ValueError("err"), agent_id="solver") assert any(m == "on_agent_error" for m, _ in self.handler.calls) def test_on_retry_dispatched(self): self.cm.on_retry(self.run_id, agent_id="solver", attempt=1) assert any(m == "on_retry" for m, _ in self.handler.calls) def test_on_llm_new_token_dispatched(self): self.cm.on_llm_new_token(self.run_id, "tok", agent_id="solver") assert any(m == "on_llm_new_token" for m, _ in self.handler.calls) def test_on_plan_created_dispatched(self): self.cm.on_plan_created( self.run_id, num_steps=3, execution_order=["a", "b", "c"], ) assert any(m == "on_plan_created" for m, _ in self.handler.calls) def test_on_topology_changed_dispatched(self): self.cm.on_topology_changed( self.run_id, reason="pruned", old_remaining=["a", "b"], new_remaining=["b"], ) assert any(m == "on_topology_changed" for m, _ in self.handler.calls) def test_on_prune_dispatched(self): self.cm.on_prune(self.run_id, agent_id="solver", reason="low quality") assert any(m == "on_prune" for m, _ in self.handler.calls) def test_on_fallback_dispatched(self): self.cm.on_fallback( self.run_id, failed_agent_id="solver", fallback_agent_id="backup", ) assert any(m == "on_fallback" for m, _ in self.handler.calls) def test_on_parallel_start_dispatched(self): self.cm.on_parallel_start(self.run_id, agent_ids=["a", "b"]) assert any(m == "on_parallel_start" for m, _ in self.handler.calls) def test_on_parallel_end_dispatched(self): self.cm.on_parallel_end(self.run_id, agent_ids=["a", "b"]) assert any(m == "on_parallel_end" for m, _ in self.handler.calls) def test_on_tool_start_dispatched(self): self.cm.on_tool_start(self.run_id, agent_id="solver", tool_name="search", action="search") assert any(m == "on_tool_start" for m, _ in self.handler.calls) def test_on_tool_end_dispatched(self): self.cm.on_tool_end(self.run_id, agent_id="solver", tool_name="search", success=True) assert any(m == "on_tool_end" for m, _ in self.handler.calls) def test_on_memory_read_dispatched(self): self.cm.on_memory_read(self.run_id, agent_id="solver", keys=["context"]) assert any(m == "on_memory_read" for m, _ in self.handler.calls) def test_on_memory_write_dispatched(self): self.cm.on_memory_write(self.run_id, agent_id="solver", key="context", value_size=256) assert any(m == "on_memory_write" for m, _ in self.handler.calls) def test_on_budget_warning_dispatched(self): self.cm.on_budget_warning( self.run_id, budget_type="tokens", current=800, limit=1000, ratio=0.8, ) assert any(m == "on_budget_warning" for m, _ in self.handler.calls) def test_on_budget_exceeded_dispatched(self): self.cm.on_budget_exceeded(self.run_id, budget_type="requests", current=10, limit=10) assert any(m == "on_budget_exceeded" for m, _ in self.handler.calls) def test_on_tool_error_dispatched(self): self.cm.on_tool_error( self.run_id, agent_id="solver", tool_name="search", error_type="timeout", error_message="timed out", ) assert any(m == "on_tool_error" for m, _ in self.handler.calls) def test_ignore_memory_skips_handler(self): class MemIgnoringHandler(RecordingHandler): ignore_memory = True handler = MemIgnoringHandler() cm = CallbackManager(handlers=[handler]) run_id = uuid4() cm.on_memory_read(run_id, agent_id="solver") cm.on_memory_write(run_id, agent_id="solver", key="k") assert not any(m in ("on_memory_read", "on_memory_write") for m, _ in handler.calls) def test_ignore_budget_skips_handler(self): class BudgetIgnoringHandler(RecordingHandler): ignore_budget = True handler = BudgetIgnoringHandler() cm = CallbackManager(handlers=[handler]) run_id = uuid4() cm.on_budget_warning(run_id, budget_type="tokens", current=800, limit=1000) assert not any(m == "on_budget_warning" for m, _ in handler.calls) def test_ignore_tool_skips_handler(self): class ToolIgnoringHandler(RecordingHandler): ignore_tool = True handler = ToolIgnoringHandler() cm = CallbackManager(handlers=[handler]) run_id = uuid4() cm.on_tool_start(run_id, tool_name="search") assert not any(m == "on_tool_start" for m, _ in handler.calls) def test_ignore_retry_skips_handler(self): class RetryIgnoringHandler(RecordingHandler): ignore_retry = True handler = RetryIgnoringHandler() cm = CallbackManager(handlers=[handler]) run_id = uuid4() cm.on_retry(run_id, agent_id="solver", attempt=1) assert not any(m == "on_retry" for m, _ in handler.calls) def test_ignore_llm_skips_handler(self): class LLMIgnoringHandler(RecordingHandler): ignore_llm = True handler = LLMIgnoringHandler() cm = CallbackManager(handlers=[handler]) run_id = uuid4() cm.on_llm_new_token(run_id, "tok", agent_id="solver") assert not any(m == "on_llm_new_token" for m, _ in handler.calls) # ═══════════════════════════════════════════════════════════════ # Error handling # ═══════════════════════════════════════════════════════════════ class TestCallbackManagerErrorHandling: def test_handler_error_propagated_when_raise_error(self): handler = ErrorHandler(raise_error=True) cm = CallbackManager(handlers=[handler]) with pytest.raises(RuntimeError, match="intentional error"): cm.on_run_start(query="test") def test_handler_error_suppressed_when_not_raise_error(self): class SilentErrorHandler(RecordingHandler): raise_error = False def on_run_start(self, *, run_id, query, **kwargs): raise ValueError("silent error") handler = SilentErrorHandler() cm = CallbackManager(handlers=[handler]) # Should not raise cm.on_run_start(query="test") def test_ignore_agent_skips_handlers(self): class IgnoringHandler(RecordingHandler): ignore_agent = True handler = IgnoringHandler() cm = CallbackManager(handlers=[handler]) run_id = uuid4() cm.on_run_start(query="test") cm.on_agent_start(run_id, agent_id="solver") # Both on_run_start and on_agent_start should be skipped assert len(handler.calls) == 0 def test_multiple_handlers(self): h1 = RecordingHandler() h2 = RecordingHandler() cm = CallbackManager(handlers=[h1, h2]) cm.on_run_start(query="test") assert any(m == "on_run_start" for m, _ in h1.calls) assert any(m == "on_run_start" for m, _ in h2.calls) # ═══════════════════════════════════════════════════════════════ # AsyncCallbackManager # ═══════════════════════════════════════════════════════════════ class TestAsyncCallbackManager: def test_is_async(self): acm = AsyncCallbackManager() assert acm.is_async is True def test_init(self): h = RecordingHandler() acm = AsyncCallbackManager(handlers=[h]) assert h in acm.handlers def test_configure(self): h = RecordingHandler() acm = AsyncCallbackManager.configure(handlers=[h]) assert h in acm.handlers @pytest.mark.asyncio async def test_on_run_start_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) rid = await acm.on_run_start(query="test async") assert isinstance(rid, UUID) @pytest.mark.asyncio async def test_on_run_end_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_run_end(run_id, output="done") assert any(m == "on_run_end" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_agent_start_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_agent_start(run_id, agent_id="solver") assert any(m == "on_agent_start" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_agent_end_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_agent_end(run_id, agent_id="solver", output="result") assert any(m == "on_agent_end" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_plan_created_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_plan_created(run_id, num_steps=2, execution_order=["a", "b"]) assert any(m == "on_plan_created" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_agent_error_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_agent_error(run_id, ValueError("err"), agent_id="solver") assert any(m == "on_agent_error" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_retry_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_retry(run_id, agent_id="solver", attempt=1, max_attempts=3) assert any(m == "on_retry" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_llm_new_token_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_llm_new_token(run_id, "tok", agent_id="solver") assert any(m == "on_llm_new_token" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_topology_changed_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_topology_changed(run_id, reason="pruned", old_remaining=["a"], new_remaining=[]) assert any(m == "on_topology_changed" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_prune_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_prune(run_id, agent_id="solver", reason="low quality") assert any(m == "on_prune" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_fallback_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_fallback(run_id, failed_agent_id="solver", fallback_agent_id="backup") assert any(m == "on_fallback" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_parallel_start_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_parallel_start(run_id, agent_ids=["a", "b"]) assert any(m == "on_parallel_start" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_parallel_end_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_parallel_end(run_id, agent_ids=["a", "b"], successful=["a"], failed=["b"]) assert any(m == "on_parallel_end" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_memory_read_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_memory_read(run_id, agent_id="solver", keys=["ctx"]) assert any(m == "on_memory_read" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_memory_write_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_memory_write(run_id, agent_id="solver", key="result", value_size=128) assert any(m == "on_memory_write" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_budget_warning_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_budget_warning(run_id, budget_type="tokens", current=800, limit=1000) assert any(m == "on_budget_warning" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_budget_exceeded_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_budget_exceeded(run_id, budget_type="requests", current=10, limit=10) assert any(m == "on_budget_exceeded" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_tool_start_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_tool_start(run_id, agent_id="solver", tool_name="search") assert any(m == "on_tool_start" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_tool_end_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_tool_end(run_id, agent_id="solver", tool_name="search", success=True) assert any(m == "on_tool_end" for m, _ in handler.calls) @pytest.mark.asyncio async def test_on_tool_error_async(self): handler = RecordingHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_tool_error(run_id, agent_id="solver", tool_name="search", error_type="timeout") assert any(m == "on_tool_error" for m, _ in handler.calls) def test_copy(self): h = RecordingHandler() acm = AsyncCallbackManager(handlers=[h], tags=["t1"]) copy = acm.copy() assert h in copy.handlers assert copy is not acm def test_merge(self): h1 = RecordingHandler() h2 = RecordingHandler() acm1 = AsyncCallbackManager(handlers=[h1], tags=["t1"]) acm2 = AsyncCallbackManager(handlers=[h2], tags=["t2"]) merged = acm1.merge(acm2) assert h1 in merged.handlers assert h2 in merged.handlers def test_add_handler(self): acm = AsyncCallbackManager() h = RecordingHandler() acm.add_handler(h) assert h in acm.handlers def test_add_handler_inheritable(self): acm = AsyncCallbackManager() h = RecordingHandler() acm.add_handler(h, inherit=True) assert h in acm.inheritable_handlers def test_remove_handler(self): h = RecordingHandler() acm = AsyncCallbackManager(handlers=[h]) acm.remove_handler(h) assert h not in acm.handlers def test_set_handlers(self): h1 = RecordingHandler() h2 = RecordingHandler() acm = AsyncCallbackManager(handlers=[h1]) acm.set_handlers([h2]) assert acm.handlers == [h2] def test_add_tags(self): acm = AsyncCallbackManager() acm.add_tags(["t1", "t2"]) assert "t1" in acm.tags def test_add_tags_inheritable(self): acm = AsyncCallbackManager() acm.add_tags(["t1"], inherit=True) assert "t1" in acm.inheritable_tags def test_remove_tags(self): acm = AsyncCallbackManager(tags=["t1", "t2"]) acm.remove_tags(["t1"]) assert "t1" not in acm.tags def test_add_metadata(self): acm = AsyncCallbackManager() acm.add_metadata({"key": "val"}) assert acm.metadata["key"] == "val" def test_add_metadata_inheritable(self): acm = AsyncCallbackManager() acm.add_metadata({"key": "val"}, inherit=True) assert acm.inheritable_metadata["key"] == "val" def test_get_child(self): h = RecordingHandler() acm = AsyncCallbackManager(inheritable_handlers=[h], inheritable_tags=["t"]) parent_run_id = uuid4() child = acm.get_child(parent_run_id) assert h in child.handlers assert child.parent_run_id == parent_run_id @pytest.mark.asyncio async def test_ignore_memory_skips_handler(self): class MemoryIgnoringHandler(RecordingHandler): ignore_memory = True handler = MemoryIgnoringHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_memory_read(run_id, agent_id="solver") await acm.on_memory_write(run_id, agent_id="solver", key="k") # Memory events should be skipped assert not any(m in ("on_memory_read", "on_memory_write") for m, _ in handler.calls) @pytest.mark.asyncio async def test_ignore_budget_skips_handler(self): class BudgetIgnoringHandler(RecordingHandler): ignore_budget = True handler = BudgetIgnoringHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_budget_warning(run_id, budget_type="tokens", current=800, limit=1000) assert not any(m == "on_budget_warning" for m, _ in handler.calls) @pytest.mark.asyncio async def test_ignore_tool_skips_handler(self): class ToolIgnoringHandler(RecordingHandler): ignore_tool = True handler = ToolIgnoringHandler() acm = AsyncCallbackManager(handlers=[handler]) run_id = uuid4() await acm.on_tool_start(run_id, tool_name="search") assert not any(m == "on_tool_start" for m, _ in handler.calls) @pytest.mark.asyncio async def test_error_in_async_handler_is_handled(self): class ErrorAsyncHandler(RecordingHandler): raise_error = False def on_run_start(self, *, run_id, query, **kwargs): raise RuntimeError("async error") handler = ErrorAsyncHandler() acm = AsyncCallbackManager(handlers=[handler]) # Should not raise due to raise_error=False await acm.on_run_start(query="test") # ═══════════════════════════════════════════════════════════════ # Additional ignore_* branch coverage tests # ═══════════════════════════════════════════════════════════════ class TestCallbackManagerIgnoreBranchCoverage: """Tests that exercise the ignore_* continue branches in CallbackManager.""" def test_remove_handler_from_inheritable_handlers(self): """Line 113: remove from inheritable_handlers.""" h = RecordingHandler() cm = CallbackManager() cm.add_handler(h, inherit=True) assert h in cm.inheritable_handlers cm.remove_handler(h) assert h not in cm.inheritable_handlers def test_ignore_agent_skips_on_run_end(self): """Line 206: ignore_agent continue in on_run_end.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() cm = CallbackManager(handlers=[h]) cm.on_run_end(uuid4(), output="done", success=True) assert not any(m == "on_run_end" for m, _ in h.calls) def test_ignore_agent_skips_on_agent_end(self): """Line 267: ignore_agent continue in on_agent_end.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() cm = CallbackManager(handlers=[h]) cm.on_agent_end(uuid4(), agent_id="solver", output="result") assert not any(m == "on_agent_end" for m, _ in h.calls) def test_ignore_agent_skips_on_agent_error(self): """Line 298: ignore_agent continue in on_agent_error.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() cm = CallbackManager(handlers=[h]) cm.on_agent_error(uuid4(), ValueError("err"), agent_id="solver") assert not any(m == "on_agent_error" for m, _ in h.calls) def test_ignore_budget_skips_on_budget_exceeded(self): """Line 600: ignore_budget continue in on_budget_exceeded.""" class BudgetIgnoringHandler(RecordingHandler): ignore_budget = True h = BudgetIgnoringHandler() cm = CallbackManager(handlers=[h]) cm.on_budget_exceeded(uuid4(), budget_type="requests", current=10, limit=10) assert not any(m == "on_budget_exceeded" for m, _ in h.calls) def test_ignore_tool_skips_on_tool_end(self): """Line 657: ignore_tool continue in on_tool_end.""" class ToolIgnoringHandler(RecordingHandler): ignore_tool = True h = ToolIgnoringHandler() cm = CallbackManager(handlers=[h]) cm.on_tool_end(uuid4(), agent_id="solver", tool_name="search", success=True) assert not any(m == "on_tool_end" for m, _ in h.calls) def test_ignore_tool_skips_on_tool_error(self): """Line 687: ignore_tool continue in on_tool_error.""" class ToolIgnoringHandler(RecordingHandler): ignore_tool = True h = ToolIgnoringHandler() cm = CallbackManager(handlers=[h]) cm.on_tool_error(uuid4(), agent_id="solver", tool_name="search", error_type="timeout", error_message="err") assert not any(m == "on_tool_error" for m, _ in h.calls) class TestAsyncCallbackManagerIgnoreBranchCoverage: """Tests that exercise the ignore_* continue branches in AsyncCallbackManager.""" def test_remove_handler_from_inheritable_handlers(self): """Line 782: remove from inheritable_handlers in AsyncCallbackManager.""" h = RecordingHandler() acm = AsyncCallbackManager() acm.add_handler(h, inherit=True) assert h in acm.inheritable_handlers acm.remove_handler(h) assert h not in acm.inheritable_handlers def test_handle_error_raises_when_raise_error(self): """Line 820: AsyncCallbackManager._handle_error raises if raise_error=True.""" h = RecordingHandler() h.raise_error = True acm = AsyncCallbackManager() err = RuntimeError("intentional raise") with pytest.raises(RuntimeError, match="intentional raise"): acm._handle_error(h, "on_run_start", err) @pytest.mark.asyncio async def test_ignore_agent_skips_on_run_start_async(self): """Line 856: ignore_agent continue in AsyncCallbackManager.on_run_start.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_run_start(query="test") assert not any(m == "on_run_start" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_agent_skips_on_run_end_async(self): """Line 890: ignore_agent continue in AsyncCallbackManager.on_run_end.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_run_end(uuid4(), output="done") assert not any(m == "on_run_end" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_agent_skips_on_agent_start_async(self): """Line 925: ignore_agent continue in AsyncCallbackManager.on_agent_start.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_agent_start(uuid4(), agent_id="solver") assert not any(m == "on_agent_start" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_agent_skips_on_agent_end_async(self): """Line 959: ignore_agent continue in AsyncCallbackManager.on_agent_end.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_agent_end(uuid4(), agent_id="solver", output="result") assert not any(m == "on_agent_end" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_agent_skips_on_agent_error_async(self): """Line 994: ignore_agent continue in AsyncCallbackManager.on_agent_error.""" class AgentIgnoringHandler(RecordingHandler): ignore_agent = True h = AgentIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_agent_error(uuid4(), ValueError("err"), agent_id="solver") assert not any(m == "on_agent_error" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_retry_skips_on_retry_async(self): """Line 1029: ignore_retry continue in AsyncCallbackManager.on_retry.""" class RetryIgnoringHandler(RecordingHandler): ignore_retry = True h = RetryIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_retry(uuid4(), agent_id="solver", attempt=1, max_attempts=3) assert not any(m == "on_retry" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_llm_skips_on_llm_new_token_async(self): """Line 1063: ignore_llm continue in AsyncCallbackManager.on_llm_new_token.""" class LLMIgnoringHandler(RecordingHandler): ignore_llm = True h = LLMIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_llm_new_token(uuid4(), "tok", agent_id="solver") assert not any(m == "on_llm_new_token" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_budget_skips_on_budget_exceeded_async(self): """Line 1338: ignore_budget continue in AsyncCallbackManager.on_budget_exceeded.""" class BudgetIgnoringHandler(RecordingHandler): ignore_budget = True h = BudgetIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_budget_exceeded(uuid4(), budget_type="requests", current=10, limit=10) assert not any(m == "on_budget_exceeded" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_tool_skips_on_tool_end_async(self): """Line 1403: ignore_tool continue in AsyncCallbackManager.on_tool_end.""" class ToolIgnoringHandler(RecordingHandler): ignore_tool = True h = ToolIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_tool_end(uuid4(), agent_id="solver", tool_name="search", success=True) assert not any(m == "on_tool_end" for m, _ in h.calls) @pytest.mark.asyncio async def test_ignore_tool_skips_on_tool_error_async(self): """Line 1437: ignore_tool continue in AsyncCallbackManager.on_tool_error.""" class ToolIgnoringHandler(RecordingHandler): ignore_tool = True h = ToolIgnoringHandler() acm = AsyncCallbackManager(handlers=[h]) await acm.on_tool_error(uuid4(), agent_id="solver", tool_name="search", error_type="timeout", error_message="err") assert not any(m == "on_tool_error" for m, _ in h.calls)