| """Tests for planner error recovery and retry limits.""" |
|
|
| import asyncio |
| import json |
| from unittest.mock import AsyncMock, MagicMock, patch |
|
|
| from fireaction import load_provider |
|
|
| from fireaction_a2a.planner import MAX_RETRIES_PER_CONTRACT, Planner, PlanStep |
| from fireaction_a2a.runner import ContractError |
| from fireaction_a2a.search import ContractSearch |
|
|
|
|
def _make_planner() -> Planner:
    """Build a Planner wired to the resend provider with a mocked API client."""
    resend_provider = load_provider("resend")
    client = AsyncMock()
    client.get_context = MagicMock(return_value={})
    client.get_base_url = MagicMock(return_value="https://api.resend.com")
    contract_search = ContractSearch(
        resend_provider, embedding_model="text-embedding-3-small"
    )
    return Planner(
        provider=resend_provider,
        provider_client=client,
        contract_search=contract_search,
        llm_model="gpt-4o",
        system_prompt="Test",
        provider_name="resend",
    )
|
|
|
|
def test_dispatch_search():
    """search_contracts should return a non-empty list of named contract hits."""
    planner = _make_planner()
    hits = asyncio.run(
        planner._dispatch_tool("search_contracts", {"query": "email"}, {})
    )
    assert isinstance(hits, list)
    assert hits
    assert "name" in hits[0]
|
|
|
|
def test_dispatch_get_details():
    """get_contract_details for a known contract exposes schema, rules, properties."""
    planner = _make_planner()
    details = asyncio.run(
        planner._dispatch_tool("get_contract_details", {"contract_name": "send_email"}, {})
    )
    assert isinstance(details, dict)
    for key in ("schema_nodes", "rules", "properties"):
        assert key in details
    assert len(details["schema_nodes"]) > 0
|
|
|
|
def test_dispatch_get_details_unknown():
    """An unknown contract name yields an error dict rather than raising."""
    planner = _make_planner()
    outcome = asyncio.run(
        planner._dispatch_tool(
            "get_contract_details", {"contract_name": "nonexistent"}, {}
        )
    )
    assert isinstance(outcome, dict)
    assert outcome["error"] is True
|
|
|
|
def test_dispatch_execute_contract_error_returns_dict():
    """Contract validation error should return error dict (not raise) for LLM retry."""
    planner = _make_planner()
    retries: dict[str, int] = {}

    # Node referencing a nonexistent element so instantiation fails.
    invalid_node = {
        "instance_key": "root",
        "element_key": "nonexistent",
        "variant_key": "send_email_skeleton",
        "data_type": "object",
        "compile_key": "root",
        "description": "bad",
        "index": 0,
        "parent_instance_key": None,
    }
    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": [invalid_node]},
            retries,
        )
    )
    assert isinstance(outcome, dict)
    assert outcome["error"] is True
    assert outcome["stage"] == "instantiate"
    # The failure is counted against this contract's retry budget.
    assert retries["send_email"] == 1
|
|
|
|
def test_dispatch_execute_exceeds_retry_limit():
    """After MAX_RETRIES_PER_CONTRACT failures, return a PlanStep error.

    Seeds retry_counts at the limit so the very next failure trips the
    give-up path instead of returning another retryable error dict.
    """
    planner = _make_planner()
    retry_counts = {"send_email": MAX_RETRIES_PER_CONTRACT}

    # Node referencing a nonexistent element so instantiation fails.
    bad_nodes = [
        {"instance_key": "root", "element_key": "nonexistent",
         "variant_key": "send_email_skeleton", "data_type": "object",
         "compile_key": "root", "description": "bad", "index": 0,
         "parent_instance_key": None},
    ]
    result = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": bad_nodes},
            retry_counts,
        )
    )
    assert isinstance(result, PlanStep)
    assert result.type == "error"
    # NOTE: the original `"failed validation" in msg or "failed" in msg` was
    # redundant — the first clause implies the second — so this single check
    # is logically equivalent.
    assert "failed" in result.message.lower()
|
|
|
|
def test_dispatch_execute_contract_with_trace():
    """Successful execute_contract with trace_enabled populates trace_contracts."""
    planner = _make_planner()
    planner.provider_client.call.return_value = {"id": "email_traced"}
    retries: dict[str, int] = {}
    collected_traces: list[dict] = []

    # Imported lazily to avoid import-order issues between test modules.
    from tests.test_runner import _minimal_email_nodes

    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": _minimal_email_nodes()},
            retries,
            trace_enabled=True,
            trace_contracts=collected_traces,
        )
    )
    assert isinstance(outcome, dict)
    assert outcome == {"id": "email_traced"}
    assert len(collected_traces) == 1

    trace_entry = collected_traces[0]
    assert trace_entry["provider"] == "resend"
    assert trace_entry["action"] == "send_email"
    assert trace_entry["validation"]["verify_passed"] is True
    assert trace_entry["api_call"] is not None
    assert trace_entry["api_call"]["response_body"] == {"id": "email_traced"}
|
|
|
|
def test_dispatch_execute_contract_error_with_trace():
    """Failed execute_contract with trace_enabled records partial trace."""
    planner = _make_planner()
    retries: dict[str, int] = {}
    collected_traces: list[dict] = []

    # Node referencing a nonexistent element so instantiation fails.
    invalid_node = {
        "instance_key": "root",
        "element_key": "nonexistent",
        "variant_key": "send_email_skeleton",
        "data_type": "object",
        "compile_key": "root",
        "description": "bad",
        "index": 0,
        "parent_instance_key": None,
    }
    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": [invalid_node]},
            retries,
            trace_enabled=True,
            trace_contracts=collected_traces,
        )
    )
    assert isinstance(outcome, dict)
    assert outcome["error"] is True
    assert len(collected_traces) == 1

    trace_entry = collected_traces[0]
    assert trace_entry["provider"] == "resend"
    assert trace_entry["action"] == "send_email"
    # Validation failed before any HTTP call, so no api_call was recorded.
    assert trace_entry["api_call"] is None
|
|
|
|
def test_classify_response():
    """_classify_response maps assistant text to the expected terminal status."""
    expectations = {
        "Done! Email sent successfully.": "completed",
        "Please specify the recipient.": "input_required",
        "Could you provide the subject?": "input_required",
        "Error: API returned 500": "error",
    }
    for text, status in expectations.items():
        assert Planner._classify_response(text) == status
|
|
|
|
| |
|
|
|
|
def _mock_llm_response(*, tool_calls=None, content=None, finish_reason="stop",
                       prompt_tokens=100, completion_tokens=50):
    """Build a mock litellm response.

    Mirrors the attributes planner code reads: ``choices[0].message``
    (content, tool_calls, model_dump) plus ``usage`` token counts.
    """
    message = MagicMock()
    message.content = content
    message.tool_calls = tool_calls

    dumped = {"role": "assistant", "content": content}
    if tool_calls:
        # Re-serialize tool calls the way a real assistant message would.
        dumped["tool_calls"] = [
            {
                "id": tc.id,
                "type": "function",
                "function": {"name": tc.function.name,
                             "arguments": tc.function.arguments},
            }
            for tc in tool_calls
        ]
    message.model_dump.return_value = dumped

    choice = MagicMock()
    choice.message = message
    choice.finish_reason = finish_reason

    response = MagicMock()
    response.choices = [choice]
    response.usage = MagicMock()
    response.usage.prompt_tokens = prompt_tokens
    response.usage.completion_tokens = completion_tokens
    response.usage.total_tokens = prompt_tokens + completion_tokens
    return response
|
|
|
|
def _mock_tool_call(name, arguments, call_id="call_1"):
    """Build a mock litellm tool call carrying JSON-serialized *arguments*."""
    call = MagicMock()
    call.function.name = name
    call.function.arguments = json.dumps(arguments)
    call.id = call_id
    return call
|
|
|
|
def test_planner_run_trace_accumulates_steps_and_cost():
    """Full run() with trace collects planner_steps and cost_metrics."""
    planner = _make_planner()

    # Turn 1: the LLM asks to call search_contracts; turn 2: it finishes.
    tool_turn = _mock_llm_response(
        tool_calls=[_mock_tool_call("search_contracts", {"query": "email"}, "call_1")],
        finish_reason="tool_calls",
        prompt_tokens=100,
        completion_tokens=50,
    )
    final_turn = _mock_llm_response(
        content="Done! Email sent successfully.",
        prompt_tokens=200,
        completion_tokens=30,
    )

    collected: list[PlanStep] = []

    async def _drive():
        with patch("fireaction_a2a.planner.litellm") as mock_litellm:
            mock_litellm.acompletion = AsyncMock(side_effect=[tool_turn, final_turn])
            async for step in planner.run("send an email", [], trace_enabled=True):
                collected.append(step)

    asyncio.run(_drive())

    final = collected[-1]
    assert final.type == "completed"
    assert final.trace_data is not None

    trace = final.trace_data
    assert "execution_trace" in trace
    assert "cost_metrics" in trace

    planner_steps = trace["execution_trace"]["planner_steps"]
    assert len(planner_steps) == 1
    assert planner_steps[0]["tool"] == "search_contracts"
    assert trace["execution_trace"]["total_planner_steps"] == 1

    # Token counts are summed across both LLM calls.
    cost = trace["cost_metrics"]
    assert cost["num_llm_calls"] == 2
    assert cost["prompt_tokens"] == 300
    assert cost["completion_tokens"] == 80
    assert cost["total_tokens"] == 380
    assert cost["total_duration_ms"] >= 0
|
|
|
|
def test_planner_run_without_trace_returns_none():
    """run() with trace_enabled=False yields PlanStep with trace_data=None."""
    planner = _make_planner()
    final_turn = _mock_llm_response(content="Done! Email sent.")
    collected: list[PlanStep] = []

    async def _drive():
        with patch("fireaction_a2a.planner.litellm") as mock_litellm:
            mock_litellm.acompletion = AsyncMock(return_value=final_turn)
            async for step in planner.run("send email", [], trace_enabled=False):
                collected.append(step)

    asyncio.run(_drive())

    last = collected[-1]
    assert last.type == "completed"
    assert last.trace_data is None
|