# fireaction-a2a / tests/test_agent.py
# Author: zequn-fireworks
# Commit d5f3acc: "Add structured execution trace to A2A task responses"
"""Tests for ProviderAgentExecutor: trace flag resolution and artifact emission."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from fireaction_a2a.agent import ProviderAgentExecutor
from fireaction_a2a.planner import PlanStep
# ---- _resolve_trace_flag tests ----
def _make_executor(*, trace_default: bool = False) -> ProviderAgentExecutor:
    """Build an executor backed by a bare planner mock.

    The flag-resolution tests never touch the planner, so an unconfigured
    ``MagicMock`` is enough; only ``trace_default`` is forwarded.
    """
    return ProviderAgentExecutor(MagicMock(), trace_enabled_default=trace_default)
def _make_context(*, task_metadata=None, message_metadata=None, has_task=True):
ctx = MagicMock()
if has_task:
ctx.current_task = MagicMock()
ctx.current_task.metadata = task_metadata
else:
ctx.current_task = None
ctx.message = MagicMock()
ctx.message.metadata = message_metadata
return ctx
def test_resolve_trace_flag_default_false():
    """Empty task metadata leaves the executor's default (off) in effect."""
    flag = _make_executor(trace_default=False)._resolve_trace_flag(
        _make_context(task_metadata={})
    )
    assert flag is False
def test_resolve_trace_flag_default_true():
    """Empty task metadata leaves the executor's default (on) in effect."""
    flag = _make_executor(trace_default=True)._resolve_trace_flag(
        _make_context(task_metadata={})
    )
    assert flag is True
def test_resolve_trace_flag_task_metadata_overrides_default():
    """An explicit ``include_trace=True`` on the task beats a False default."""
    subject = _make_executor(trace_default=False)
    request = _make_context(task_metadata={"include_trace": True})
    resolved = subject._resolve_trace_flag(request)
    assert resolved is True
def test_resolve_trace_flag_task_metadata_disables():
    """An explicit ``include_trace=False`` on the task beats a True default."""
    subject = _make_executor(trace_default=True)
    request = _make_context(task_metadata={"include_trace": False})
    resolved = subject._resolve_trace_flag(request)
    assert resolved is False
def test_resolve_trace_flag_message_metadata_fallback():
    """Without a current task, ``include_trace`` is read from the message."""
    subject = _make_executor(trace_default=False)
    request = _make_context(
        has_task=False,
        message_metadata={"include_trace": True},
    )
    assert subject._resolve_trace_flag(request) is True
def test_resolve_trace_flag_no_metadata():
    """Task metadata of ``None`` must not crash; the default (off) applies."""
    subject = _make_executor(trace_default=False)
    request = _make_context(task_metadata=None)
    assert subject._resolve_trace_flag(request) is False
# ---- execute() artifact emission tests ----
def test_execute_emits_trace_artifact_when_trace_data_present():
    """When planner yields a completed step with trace_data, agent emits DataPart."""
    mock_planner = MagicMock()
    trace_payload = {
        "execution_trace": {"contracts_called": [], "planner_steps": [],
                            "total_api_calls": 0, "total_planner_steps": 0},
        "cost_metrics": {"total_tokens": 100, "prompt_tokens": 60,
                         "completion_tokens": 40, "num_llm_calls": 1,
                         "total_duration_ms": 500},
    }

    async def mock_run(user_msg, history, *, trace_enabled=False):
        yield PlanStep(type="completed", message="Done!", trace_data=trace_payload)

    mock_planner.run = mock_run
    mock_planner.get_messages.return_value = []
    executor = ProviderAgentExecutor(mock_planner, trace_enabled_default=True)

    context = MagicMock()
    context.get_user_input.return_value = "send email"
    context.current_task = MagicMock()
    context.current_task.metadata = {}
    context.current_task.id = "task_1"
    context.current_task.context_id = "ctx_1"
    context.message = MagicMock()

    event_queue = MagicMock()
    event_queue.enqueue_event = AsyncMock()

    asyncio.run(executor.execute(context, event_queue))

    def _data_parts(event):
        """Collect parts of kind "data" from every place an event can carry parts.

        Messages expose ``parts`` directly; artifact-update events nest them
        under ``artifact``; tasks under ``artifacts``; status updates under
        ``status.message``. Checking only ``event.parts`` would miss artifacts.
        """
        containers = [event, getattr(event, "artifact", None)]
        containers.extend(getattr(event, "artifacts", None) or ())
        containers.append(getattr(getattr(event, "status", None), "message", None))
        found = []
        for container in containers:
            for part in getattr(container, "parts", None) or ():
                # Parts may be wrapped in a RootModel (``.root``) or bare.
                inner = getattr(part, "root", part)
                if getattr(inner, "kind", None) == "data":
                    found.append(inner)
        return found

    events = [call.args[0] for call in event_queue.enqueue_event.call_args_list]
    data_parts = [part for event in events for part in _data_parts(event)]

    assert len(events) >= 3
    # The point of this test: a trace DataPart must actually be emitted.
    assert data_parts, "expected at least one DataPart carrying the execution trace"
def test_execute_skips_trace_artifact_when_trace_disabled():
    """When trace_enabled is False, no DataPart artifact is emitted."""
    mock_planner = MagicMock()

    async def mock_run(user_msg, history, *, trace_enabled=False):
        yield PlanStep(type="completed", message="Done!", trace_data=None)

    mock_planner.run = mock_run
    mock_planner.get_messages.return_value = []
    executor = ProviderAgentExecutor(mock_planner, trace_enabled_default=False)

    context = MagicMock()
    context.get_user_input.return_value = "send email"
    context.current_task = MagicMock()
    context.current_task.metadata = {}
    context.current_task.id = "task_2"
    context.current_task.context_id = "ctx_2"
    context.message = MagicMock()

    event_queue = MagicMock()
    event_queue.enqueue_event = AsyncMock()

    asyncio.run(executor.execute(context, event_queue))

    def _data_parts(event):
        """Collect parts of kind "data" from every place an event can carry parts.

        Messages expose ``parts`` directly; artifact-update events nest them
        under ``artifact``; tasks under ``artifacts``; status updates under
        ``status.message``. Checking only ``event.parts`` (as before) would
        let a DataPart artifact slip through and make this test vacuous.
        """
        containers = [event, getattr(event, "artifact", None)]
        containers.extend(getattr(event, "artifacts", None) or ())
        containers.append(getattr(getattr(event, "status", None), "message", None))
        found = []
        for container in containers:
            for part in getattr(container, "parts", None) or ():
                # Parts may be wrapped in a RootModel (``.root``) or bare.
                inner = getattr(part, "root", part)
                if getattr(inner, "kind", None) == "data":
                    found.append(inner)
        return found

    all_events = [call.args[0] for call in event_queue.enqueue_event.call_args_list]
    data_part_events = [event for event in all_events if _data_parts(event)]
    assert len(data_part_events) == 0