| """Tests for ProviderAgentExecutor: trace flag resolution and artifact emission.""" |
|
|
| import asyncio |
| from unittest.mock import AsyncMock, MagicMock, patch |
|
|
| from fireaction_a2a.agent import ProviderAgentExecutor |
| from fireaction_a2a.planner import PlanStep |
|
|
|
|
| |
|
|
|
|
def _make_executor(*, trace_default: bool = False) -> ProviderAgentExecutor:
    """Return a ProviderAgentExecutor wired to a stand-in planner.

    Args:
        trace_default: Value forwarded as ``trace_enabled_default``.
    """
    planner_stub = MagicMock()
    executor = ProviderAgentExecutor(
        planner_stub,
        trace_enabled_default=trace_default,
    )
    return executor
|
|
|
|
| def _make_context(*, task_metadata=None, message_metadata=None, has_task=True): |
| ctx = MagicMock() |
| if has_task: |
| ctx.current_task = MagicMock() |
| ctx.current_task.metadata = task_metadata |
| else: |
| ctx.current_task = None |
| ctx.message = MagicMock() |
| ctx.message.metadata = message_metadata |
| return ctx |
|
|
|
|
def test_resolve_trace_flag_default_false():
    """Empty task metadata with a False default resolves to False."""
    context = _make_context(task_metadata={})
    agent = _make_executor(trace_default=False)
    resolved = agent._resolve_trace_flag(context)
    assert resolved is False
|
|
|
|
def test_resolve_trace_flag_default_true():
    """Empty task metadata with a True default resolves to True."""
    context = _make_context(task_metadata={})
    agent = _make_executor(trace_default=True)
    resolved = agent._resolve_trace_flag(context)
    assert resolved is True
|
|
|
|
def test_resolve_trace_flag_task_metadata_overrides_default():
    """``include_trace: True`` in task metadata wins over a False default."""
    context = _make_context(task_metadata={"include_trace": True})
    agent = _make_executor(trace_default=False)
    resolved = agent._resolve_trace_flag(context)
    assert resolved is True
|
|
|
|
def test_resolve_trace_flag_task_metadata_disables():
    """``include_trace: False`` in task metadata wins over a True default."""
    context = _make_context(task_metadata={"include_trace": False})
    agent = _make_executor(trace_default=True)
    resolved = agent._resolve_trace_flag(context)
    assert resolved is False
|
|
|
|
def test_resolve_trace_flag_message_metadata_fallback():
    """With no current task, ``include_trace`` is read from message metadata."""
    context = _make_context(
        has_task=False,
        message_metadata={"include_trace": True},
    )
    agent = _make_executor(trace_default=False)
    resolved = agent._resolve_trace_flag(context)
    assert resolved is True
|
|
|
|
def test_resolve_trace_flag_no_metadata():
    """Task metadata of ``None`` with a False default resolves to False."""
    context = _make_context(task_metadata=None)
    agent = _make_executor(trace_default=False)
    resolved = agent._resolve_trace_flag(context)
    assert resolved is False
|
|
|
|
| |
|
|
|
|
def test_execute_emits_trace_artifact_when_trace_data_present():
    """When planner yields a completed step with trace_data, agent emits DataPart.

    The planner is replaced by a stub whose ``run`` yields a single
    ``completed`` step carrying ``trace_payload``. The test then inspects
    what the executor pushed onto the (mocked) event queue.
    """
    trace_payload = {
        "execution_trace": {
            "contracts_called": [],
            "planner_steps": [],
            "total_api_calls": 0,
            "total_planner_steps": 0,
        },
        "cost_metrics": {
            "total_tokens": 100,
            "prompt_tokens": 60,
            "completion_tokens": 40,
            "num_llm_calls": 1,
            "total_duration_ms": 500,
        },
    }

    async def mock_run(user_msg, history, *, trace_enabled=False):
        yield PlanStep(type="completed", message="Done!", trace_data=trace_payload)

    mock_planner = MagicMock()
    mock_planner.run = mock_run
    mock_planner.get_messages.return_value = []

    executor = ProviderAgentExecutor(mock_planner, trace_enabled_default=True)

    context = MagicMock()
    context.get_user_input.return_value = "send email"
    context.current_task = MagicMock()
    context.current_task.metadata = {}
    context.current_task.id = "task_1"
    context.current_task.context_id = "ctx_1"
    context.message = MagicMock()

    event_queue = MagicMock()
    event_queue.enqueue_event = AsyncMock()

    async def _run():
        await executor.execute(context, event_queue)

    asyncio.run(_run())

    calls = event_queue.enqueue_event.call_args_list
    # An event counts as artifact-like if it exposes ``parts`` directly or
    # its type name mentions "artifact".
    artifacts = [
        call
        for call in calls
        if hasattr(call.args[0], "parts")
        or "artifact" in str(type(call.args[0])).lower()
    ]
    assert len(calls) >= 3
    # Previously ``artifacts`` was computed but never checked, so the test
    # passed on any three events even if none of them was an artifact.
    assert artifacts, "expected at least one artifact-like event to be enqueued"
|
|
|
|
def test_execute_skips_trace_artifact_when_trace_disabled():
    """When trace_enabled is False, no DataPart artifact is emitted.

    The stub planner yields one ``completed`` step with ``trace_data=None``
    and the executor is built with ``trace_enabled_default=False``; no
    enqueued event may then carry a part whose ``root.kind`` is ``"data"``.
    """

    async def mock_run(user_msg, history, *, trace_enabled=False):
        yield PlanStep(type="completed", message="Done!", trace_data=None)

    mock_planner = MagicMock()
    mock_planner.run = mock_run
    mock_planner.get_messages.return_value = []

    executor = ProviderAgentExecutor(mock_planner, trace_enabled_default=False)

    context = MagicMock()
    context.get_user_input.return_value = "send email"
    context.current_task = MagicMock()
    context.current_task.metadata = {}
    context.current_task.id = "task_2"
    context.current_task.context_id = "ctx_2"
    context.message = MagicMock()

    event_queue = MagicMock()
    event_queue.enqueue_event = AsyncMock()

    async def _run():
        await executor.execute(context, event_queue)

    asyncio.run(_run())

    def _iter_parts(event):
        # Parts may sit directly on the event or, for artifact-update events,
        # under ``event.artifact`` (assumed A2A SDK shape — TODO confirm
        # against the executor). Checking only the top level would let a
        # nested data part slip through and make this test pass vacuously.
        yield from getattr(event, "parts", None) or []
        nested = getattr(event, "artifact", None)
        yield from getattr(nested, "parts", None) or []

    all_events = [call.args[0] for call in event_queue.enqueue_event.call_args_list]
    data_part_events = [
        event
        for event in all_events
        if any(
            getattr(getattr(part, "root", None), "kind", None) == "data"
            for part in _iter_parts(event)
        )
    ]
    assert len(data_part_events) == 0
|
|