gMAS / tests /test_callbacks.py
Артём Боярских
chore: initial commit
3193174
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID, uuid4
import pytest
from callbacks import (
AsyncCallbackHandler,
AsyncCallbackManager,
BaseCallbackHandler,
CallbackManager,
get_callback_manager,
set_callback_manager,
trace_as_callback,
)
from callbacks.context import (
collect_metrics,
configure_async_callbacks,
configure_callbacks,
)
@dataclass
class Call:
    """One recorded callback invocation.

    Instances are appended by the recording handlers in this module so tests
    can assert on which hook fired and with which keyword payload.
    """

    # Name of the callback hook that was invoked, e.g. "on_retry".
    name: str
    # Keyword payload captured for the invocation; a fresh dict per instance.
    kwargs: dict[str, Any] = field(default_factory=dict)
class RecordingHandler(BaseCallbackHandler):
    """Synchronous test double that records the callbacks it receives.

    Every overridden hook appends a ``Call`` holding the hook name and the
    keyword payload it was invoked with (named arguments first, then any
    extra ``**kwargs``), so tests can assert on event order and arguments.
    """

    def __init__(self) -> None:
        # Recorded invocations, in the order they were received.
        self.calls: list[Call] = []

    def _record_call(self, event: str, /, **payload: Any) -> None:
        """Append a ``Call`` named *event* whose kwargs are *payload*.

        ``event`` is positional-only so that a payload key called ``event``
        cannot collide with it.
        """
        self.calls.append(Call(event, payload))

    def on_run_start(
        self,
        *,
        run_id: UUID,
        query: str,
        num_agents: int = 0,
        execution_order: list[str] | None = None,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record an ``on_run_start`` event with all received arguments."""
        self._record_call(
            "on_run_start",
            run_id=run_id,
            query=query,
            num_agents=num_agents,
            execution_order=execution_order,
            parent_run_id=parent_run_id,
            tags=tags,
            metadata=metadata,
            **kwargs,
        )

    def on_retry(
        self,
        *,
        run_id: UUID,
        agent_id: str,
        attempt: int,
        max_attempts: int = 0,
        delay_ms: float = 0.0,
        error: str = "",
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Record an ``on_retry`` event with all received arguments."""
        self._record_call(
            "on_retry",
            run_id=run_id,
            agent_id=agent_id,
            attempt=attempt,
            max_attempts=max_attempts,
            delay_ms=delay_ms,
            error=error,
            parent_run_id=parent_run_id,
            **kwargs,
        )

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        agent_id: str,
        token_index: int = 0,
        is_first: bool = False,
        is_last: bool = False,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Record an ``on_llm_new_token`` event; ``token`` is stored first."""
        self._record_call(
            "on_llm_new_token",
            token=token,
            run_id=run_id,
            agent_id=agent_id,
            token_index=token_index,
            is_first=is_first,
            is_last=is_last,
            parent_run_id=parent_run_id,
            **kwargs,
        )

    def on_memory_read(
        self,
        *,
        run_id: UUID,
        agent_id: str,
        entries_count: int = 0,
        keys: list[str] | None = None,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Record an ``on_memory_read`` event with all received arguments."""
        self._record_call(
            "on_memory_read",
            run_id=run_id,
            agent_id=agent_id,
            entries_count=entries_count,
            keys=keys,
            parent_run_id=parent_run_id,
            **kwargs,
        )

    def on_budget_warning(
        self,
        *,
        run_id: UUID,
        budget_type: str,
        current: float,
        limit: float,
        ratio: float = 0.0,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Record an ``on_budget_warning`` event with all received arguments."""
        self._record_call(
            "on_budget_warning",
            run_id=run_id,
            budget_type=budget_type,
            current=current,
            limit=limit,
            ratio=ratio,
            parent_run_id=parent_run_id,
            **kwargs,
        )
class RaisingHandler(BaseCallbackHandler):
    """Handler whose ``on_run_start`` always raises ``RuntimeError("boom")``.

    ``raise_error`` is only stored on the instance here; nothing in this
    class reads it. The tests in this file expect the manager to propagate
    the failure when it is True and to carry on when it is False.
    """

    def __init__(self, *, raise_error: bool) -> None:
        self.raise_error = raise_error

    def on_run_start(self, **_: Any) -> None:
        """Fail unconditionally, ignoring every keyword argument."""
        message = "boom"
        raise RuntimeError(message)
class AsyncRecordingHandler(AsyncCallbackHandler):
    """Async test double that records the ``on_retry`` events it receives."""

    def __init__(self) -> None:
        # Recorded invocations, in the order they were received.
        self.calls: list[Call] = []

    async def on_retry(
        self,
        *,
        run_id: UUID,
        agent_id: str,
        attempt: int,
        max_attempts: int = 0,
        delay_ms: float = 0.0,
        error: str = "",
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Append a ``Call`` holding the named arguments followed by extras."""
        payload: dict[str, Any] = {
            "run_id": run_id,
            "agent_id": agent_id,
            "attempt": attempt,
            "max_attempts": max_attempts,
            "delay_ms": delay_ms,
            "error": error,
            "parent_run_id": parent_run_id,
        }
        # Extra keyword arguments go after the named ones, as before.
        payload.update(kwargs)
        self.calls.append(Call("on_retry", payload))
def test_trace_as_callback_sets_and_resets_context_manager() -> None:
    """Nested ``trace_as_callback`` blocks install and restore the manager."""
    set_callback_manager(None)
    assert get_callback_manager() is None

    with trace_as_callback(tags=["outer"], metadata={"x": 1}) as outer_manager:
        assert get_callback_manager() is outer_manager
        assert outer_manager.tags == ["outer"]
        assert outer_manager.metadata == {"x": 1}

        with trace_as_callback(tags=["inner"]) as inner_manager:
            assert get_callback_manager() is inner_manager
            assert inner_manager.tags == ["inner"]

        # Leaving the inner block must restore the outer manager.
        assert get_callback_manager() is outer_manager

    # Leaving the outer block must restore the initial state (no manager).
    assert get_callback_manager() is None
def test_callback_manager_on_run_start_passes_parent_tags_metadata() -> None:
    """``on_run_start`` forwards run id, parent id, tags and metadata."""
    recorder = RecordingHandler()
    parent_id = uuid4()
    manager = CallbackManager.configure(
        handlers=[recorder],
        tags=["t1"],
        metadata={"k": "v"},
    )
    manager.parent_run_id = parent_id

    returned_run_id = manager.on_run_start(query="hi", num_agents=2)

    assert len(recorder.calls) == 1
    recorded = recorder.calls[0]
    assert recorded.name == "on_run_start"

    # Only these keys are checked; the payload may carry additional ones.
    expected_payload = {
        "run_id": returned_run_id,
        "query": "hi",
        "num_agents": 2,
        "execution_order": [],
        "parent_run_id": parent_id,
        "tags": ["t1"],
        "metadata": {"k": "v"},
    }
    for key, expected in expected_payload.items():
        assert recorded.kwargs[key] == expected
def test_callback_manager_ignore_flags_filter_events() -> None:
    """Each ``ignore_*`` flag hides exactly one event from its handler."""
    run_id = uuid4()

    # Events in the order they are emitted below.
    emitted = ["on_retry", "on_llm_new_token", "on_memory_read", "on_budget_warning"]
    # Flag attribute -> the single event it is expected to suppress.
    suppressed_by = {
        "ignore_retry": "on_retry",
        "ignore_llm": "on_llm_new_token",
        "ignore_memory": "on_memory_read",
        "ignore_budget": "on_budget_warning",
    }

    unfiltered = RecordingHandler()
    flagged: dict[str, RecordingHandler] = {}
    for flag in suppressed_by:
        handler = RecordingHandler()
        setattr(handler, flag, True)
        flagged[flag] = handler

    manager = CallbackManager(handlers=[unfiltered, *flagged.values()])
    manager.on_retry(
        run_id, agent_id="a1", attempt=1, max_attempts=3, delay_ms=10.0, error="e"
    )
    manager.on_llm_new_token(
        run_id, "tok", agent_id="a1", token_index=0, is_first=True, is_last=False
    )
    manager.on_memory_read(run_id, agent_id="a1", entries_count=2, keys=["k1", "k2"])
    manager.on_budget_warning(
        run_id, budget_type="tokens", current=90.0, limit=100.0, ratio=0.9
    )

    assert [call.name for call in unfiltered.calls] == emitted
    for flag, handler in flagged.items():
        hidden = suppressed_by[flag]
        expected = [event for event in emitted if event != hidden]
        assert [call.name for call in handler.calls] == expected
def test_callback_manager_error_in_one_handler_does_not_block_others_when_raise_error_false() -> None:
    """A failing handler with ``raise_error=False`` does not stop later ones."""
    failing = RaisingHandler(raise_error=False)
    recorder = RecordingHandler()
    manager = CallbackManager(handlers=[failing, recorder])

    run_id = manager.on_run_start(query="q")

    # The handler registered after the failing one still got the event.
    assert len(recorder.calls) == 1
    only_call = recorder.calls[0]
    assert only_call.name == "on_run_start"
    assert only_call.kwargs["run_id"] == run_id
def test_callback_manager_error_propagates_when_raise_error_true() -> None:
    """A failing handler with ``raise_error=True`` aborts the dispatch."""
    failing = RaisingHandler(raise_error=True)
    recorder = RecordingHandler()
    manager = CallbackManager(handlers=[failing, recorder])

    with pytest.raises(RuntimeError, match="boom"):
        manager.on_run_start(query="q")

    # The handler registered after the failing one was never reached.
    assert recorder.calls == []
def test_callback_manager_get_child_inherits_only_inheritable_settings() -> None:
    """``get_child`` copies only handlers, tags and metadata marked inherit."""
    inherited_handler = RecordingHandler()
    local_handler = RecordingHandler()

    manager = CallbackManager()
    manager.add_handler(inherited_handler, inherit=True)
    manager.add_handler(local_handler, inherit=False)
    manager.add_tags(["tag_inherit"], inherit=True)
    manager.add_tags(["tag_local"], inherit=False)
    manager.add_metadata({"a": 1}, inherit=True)
    manager.add_metadata({"b": 2}, inherit=False)

    parent_run_id = uuid4()
    child = manager.get_child(parent_run_id)

    assert child.parent_run_id == parent_run_id
    assert child.handlers == [inherited_handler]
    assert child.tags == ["tag_inherit"]
    assert child.metadata == {"a": 1}

    # Events emitted through the child carry the parent run id.
    run_id = child.on_run_start(query="x")
    last_call = inherited_handler.calls[-1]
    assert last_call.kwargs["parent_run_id"] == parent_run_id
    assert last_call.kwargs["run_id"] == run_id
@pytest.mark.asyncio
async def test_async_callback_manager_runs_sync_and_async_handlers() -> None:
    """The async manager dispatches one event to sync and async handlers."""
    sync_recorder = RecordingHandler()
    async_recorder = AsyncRecordingHandler()
    manager = AsyncCallbackManager(handlers=[sync_recorder, async_recorder])

    await manager.on_retry(
        uuid4(),
        agent_id="a1",
        attempt=2,
        max_attempts=5,
        delay_ms=1.0,
        error="x",
    )

    for recorder in (sync_recorder, async_recorder):
        assert [call.name for call in recorder.calls] == ["on_retry"]
# ═══════════════════════════════════════════════════════════════
# AsyncCallbackHandler — direct method calls (coverage)
# ═══════════════════════════════════════════════════════════════
@pytest.mark.asyncio
async def test_async_callback_handler_all_pass_methods() -> None:
    """Await every hook on a bare AsyncCallbackHandler; none may raise."""
    handler = AsyncCallbackHandler()
    rid = uuid4()
    agent = "agent_a"

    assert handler.is_async is True

    # Run lifecycle
    await handler.on_run_start(run_id=rid, query="q", num_agents=1)
    await handler.on_run_end(run_id=rid, output="out", success=True)

    # Agent lifecycle
    await handler.on_agent_start(run_id=rid, agent_id=agent)
    await handler.on_agent_end(run_id=rid, agent_id=agent, output="resp")
    await handler.on_agent_error(ValueError("err"), run_id=rid, agent_id=agent)

    # Retry
    await handler.on_retry(run_id=rid, agent_id=agent, attempt=1)

    # Streaming
    await handler.on_llm_new_token("tok", run_id=rid, agent_id=agent)

    # Planning
    await handler.on_plan_created(run_id=rid, num_steps=2, execution_order=["a", "b"])
    await handler.on_topology_changed(
        run_id=rid, reason="test", old_remaining=["a"], new_remaining=["b"]
    )

    # Pruning / fallback
    await handler.on_prune(run_id=rid, agent_id=agent, reason="pruned")
    await handler.on_fallback(run_id=rid, failed_agent_id="a", fallback_agent_id="b")

    # Parallel
    await handler.on_parallel_start(run_id=rid, agent_ids=["a", "b"])
    await handler.on_parallel_end(run_id=rid, agent_ids=["a", "b"])

    # Memory
    await handler.on_memory_read(run_id=rid, agent_id=agent)
    await handler.on_memory_write(run_id=rid, agent_id=agent, key="key")

    # Budget
    await handler.on_budget_warning(
        run_id=rid, budget_type="tokens", current=90.0, limit=100.0
    )
    await handler.on_budget_exceeded(
        run_id=rid, budget_type="tokens", current=105.0, limit=100.0
    )

    # Tools
    await handler.on_tool_start(run_id=rid, tool_name="test_tool")
    await handler.on_tool_end(run_id=rid, tool_name="test_tool")
    await handler.on_tool_error(run_id=rid, tool_name="test_tool")
def test_base_callback_handler_is_async_false() -> None:
    """``is_async`` is False on BaseCallbackHandler and on the bare mixin."""
    from callbacks.base import CallbackHandlerMixin

    # Check the concrete base first, then the mixin on its own.
    for sync_type in (BaseCallbackHandler, CallbackHandlerMixin):
        assert sync_type().is_async is False
# ═══════════════════════════════════════════════════════════════
# callbacks/context.py — collect_metrics, configure_callbacks
# ═══════════════════════════════════════════════════════════════
def test_collect_metrics_context_manager() -> None:
    """``collect_metrics`` yields a MetricsCallbackHandler and sets a manager."""
    from callbacks.handlers.metrics import MetricsCallbackHandler

    with collect_metrics() as metrics_handler:
        assert isinstance(metrics_handler, MetricsCallbackHandler)
        # A callback manager must be installed while the context is active.
        assert get_callback_manager() is not None
def test_collect_metrics_resets_after_exit() -> None:
    """Leaving ``collect_metrics`` restores the previous (absent) manager."""
    # Start from a known state: no manager installed.
    set_callback_manager(None)

    with collect_metrics():
        inside = get_callback_manager()
        assert inside is not None

    after = get_callback_manager()
    assert after is None
def test_configure_callbacks_returns_manager() -> None:
    """``configure_callbacks`` returns a CallbackManager carrying the tags."""
    configured = configure_callbacks(tags=["test"])

    assert isinstance(configured, CallbackManager)
    assert "test" in configured.tags
def test_configure_callbacks_with_handlers() -> None:
    """Handlers passed to ``configure_callbacks`` end up on the manager."""
    handler = BaseCallbackHandler()

    configured = configure_callbacks(handlers=[handler])

    assert handler in configured.handlers
def test_configure_async_callbacks_returns_manager() -> None:
    """``configure_async_callbacks`` returns a tagged AsyncCallbackManager."""
    configured = configure_async_callbacks(tags=["async-test"])

    assert isinstance(configured, AsyncCallbackManager)
    assert "async-test" in configured.tags