# fireaction-a2a/tests/test_planner_retry.py
# History: zequn-fireworks — "Upgrade to hierarchical schema and build_request() API" (commit 4d52c8a)
"""Tests for planner error recovery and retry limits."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
from fireaction import load_provider
from fireaction_a2a.planner import MAX_RETRIES_PER_CONTRACT, Planner, PlanStep
from fireaction_a2a.runner import ContractError
from fireaction_a2a.search import ContractSearch
def _make_planner() -> Planner:
    """Build a Planner for the resend provider with a fully mocked HTTP client.

    The AsyncMock client never performs network I/O; context and base URL are
    stubbed so contract compilation has what it needs.
    """
    resend_provider = load_provider("resend")
    client = AsyncMock()
    client.get_context = MagicMock(return_value={})
    client.get_base_url = MagicMock(return_value="https://api.resend.com")
    return Planner(
        provider=resend_provider,
        provider_client=client,
        contract_search=ContractSearch(
            resend_provider, embedding_model="text-embedding-3-small"
        ),
        llm_model="gpt-4o",
        system_prompt="Test",
        provider_name="resend",
    )
def test_dispatch_search():
    """search_contracts returns a non-empty list whose entries carry a name."""
    planner = _make_planner()
    results = asyncio.run(
        planner._dispatch_tool("search_contracts", {"query": "email"}, {})
    )
    assert isinstance(results, list)
    assert results  # non-empty
    assert "name" in results[0]
def test_dispatch_get_details():
    """get_contract_details for a known contract yields a populated schema dict."""
    planner = _make_planner()
    details = asyncio.run(
        planner._dispatch_tool("get_contract_details", {"contract_name": "send_email"}, {})
    )
    assert isinstance(details, dict)
    for key in ("schema_nodes", "rules", "properties"):
        assert key in details
    assert details["schema_nodes"]  # at least one schema node
def test_dispatch_get_details_unknown():
    """An unknown contract name produces an error dict rather than raising."""
    planner = _make_planner()
    details = asyncio.run(
        planner._dispatch_tool("get_contract_details", {"contract_name": "nonexistent"}, {})
    )
    assert isinstance(details, dict)
    assert details["error"] is True
def test_dispatch_execute_contract_error_returns_dict():
    """Contract validation error should return error dict (not raise) for LLM retry."""
    planner = _make_planner()
    retries: dict[str, int] = {}
    # Node referencing a non-existent element so instantiation must fail.
    invalid_nodes = [
        {
            "instance_key": "root",
            "element_key": "nonexistent",
            "variant_key": "send_email_skeleton",
            "data_type": "object",
            "compile_key": "root",
            "description": "bad",
            "index": 0,
            "parent_instance_key": None,
        },
    ]
    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": invalid_nodes},
            retries,
        )
    )
    assert isinstance(outcome, dict)
    assert outcome["error"] is True
    assert outcome["stage"] == "instantiate"
    # The failure must also bump the per-contract retry counter.
    assert retries["send_email"] == 1
def test_dispatch_execute_exceeds_retry_limit():
    """After MAX_RETRIES_PER_CONTRACT failures, return a PlanStep error."""
    planner = _make_planner()
    # Start at the cap so the very next failure trips the retry limit.
    retry_counts = {"send_email": MAX_RETRIES_PER_CONTRACT}
    bad_nodes = [
        {"instance_key": "root", "element_key": "nonexistent",
         "variant_key": "send_email_skeleton", "data_type": "object",
         "compile_key": "root", "description": "bad", "index": 0,
         "parent_instance_key": None},
    ]
    result = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": bad_nodes},
            retry_counts,
        )
    )
    assert isinstance(result, PlanStep)
    assert result.type == "error"
    # Fixed: the original `"failed validation" in msg or "failed" in msg` was
    # redundant — the first clause implies the second, so the whole check is
    # logically equivalent to the single substring test below.
    assert "failed" in result.message.lower()
def test_dispatch_execute_contract_with_trace():
    """Successful execute_contract with trace_enabled populates trace_contracts."""
    from tests.test_runner import _minimal_email_nodes

    planner = _make_planner()
    planner.provider_client.call.return_value = {"id": "email_traced"}
    retries: dict[str, int] = {}
    traced: list[dict] = []
    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": _minimal_email_nodes()},
            retries,
            trace_enabled=True,
            trace_contracts=traced,
        )
    )
    assert isinstance(outcome, dict)
    assert outcome == {"id": "email_traced"}
    # Exactly one trace entry, describing the successful call end-to-end.
    assert len(traced) == 1
    record = traced[0]
    assert record["provider"] == "resend"
    assert record["action"] == "send_email"
    assert record["validation"]["verify_passed"] is True
    assert record["api_call"] is not None
    assert record["api_call"]["response_body"] == {"id": "email_traced"}
def test_dispatch_execute_contract_error_with_trace():
    """Failed execute_contract with trace_enabled records partial trace."""
    planner = _make_planner()
    retries: dict[str, int] = {}
    traced: list[dict] = []
    # Node referencing a non-existent element so instantiation must fail.
    invalid_nodes = [
        {
            "instance_key": "root",
            "element_key": "nonexistent",
            "variant_key": "send_email_skeleton",
            "data_type": "object",
            "compile_key": "root",
            "description": "bad",
            "index": 0,
            "parent_instance_key": None,
        },
    ]
    outcome = asyncio.run(
        planner._dispatch_tool(
            "execute_contract",
            {"contract_name": "send_email", "instance_nodes": invalid_nodes},
            retries,
            trace_enabled=True,
            trace_contracts=traced,
        )
    )
    assert isinstance(outcome, dict)
    assert outcome["error"] is True
    # Partial trace entry exists, but no API call was ever made.
    assert len(traced) == 1
    record = traced[0]
    assert record["provider"] == "resend"
    assert record["action"] == "send_email"
    assert record["api_call"] is None
def test_classify_response():
    """Final LLM text maps onto the expected terminal status."""
    expectations = [
        ("Done! Email sent successfully.", "completed"),
        ("Please specify the recipient.", "input_required"),
        ("Could you provide the subject?", "input_required"),
        ("Error: API returned 500", "error"),
    ]
    for text, expected in expectations:
        assert Planner._classify_response(text) == expected
# ---- Planner run() integration tests with trace ----
def _mock_llm_response(*, tool_calls=None, content=None, finish_reason="stop",
prompt_tokens=100, completion_tokens=50):
"""Build a mock litellm response."""
resp = MagicMock()
msg = MagicMock()
msg.content = content
msg.tool_calls = tool_calls
msg.model_dump.return_value = {
"role": "assistant",
"content": content,
**({"tool_calls": [
{"id": tc.id, "type": "function",
"function": {"name": tc.function.name, "arguments": tc.function.arguments}}
for tc in tool_calls
]} if tool_calls else {}),
}
choice = MagicMock()
choice.message = msg
choice.finish_reason = finish_reason
resp.choices = [choice]
resp.usage = MagicMock()
resp.usage.prompt_tokens = prompt_tokens
resp.usage.completion_tokens = completion_tokens
resp.usage.total_tokens = prompt_tokens + completion_tokens
return resp
def _mock_tool_call(name, arguments, call_id="call_1"):
tc = MagicMock()
tc.id = call_id
tc.function.name = name
tc.function.arguments = json.dumps(arguments)
return tc
def test_planner_run_trace_accumulates_steps_and_cost():
    """Full run() with trace collects planner_steps and cost_metrics."""
    planner = _make_planner()
    # First turn: the LLM asks for a tool call; second turn: final answer.
    first = _mock_llm_response(
        tool_calls=[_mock_tool_call("search_contracts", {"query": "email"}, "call_1")],
        finish_reason="tool_calls",
        prompt_tokens=100,
        completion_tokens=50,
    )
    second = _mock_llm_response(
        content="Done! Email sent successfully.",
        prompt_tokens=200,
        completion_tokens=30,
    )
    collected: list[PlanStep] = []

    async def _collect():
        with patch("fireaction_a2a.planner.litellm") as fake_litellm:
            fake_litellm.acompletion = AsyncMock(side_effect=[first, second])
            async for step in planner.run("send an email", [], trace_enabled=True):
                collected.append(step)

    asyncio.run(_collect())
    last = collected[-1]
    assert last.type == "completed"
    assert last.trace_data is not None
    trace = last.trace_data
    assert "execution_trace" in trace
    assert "cost_metrics" in trace
    exec_trace = trace["execution_trace"]
    assert len(exec_trace["planner_steps"]) == 1
    assert exec_trace["planner_steps"][0]["tool"] == "search_contracts"
    assert exec_trace["total_planner_steps"] == 1
    # Token accounting aggregates across both LLM calls.
    cost = trace["cost_metrics"]
    assert cost["num_llm_calls"] == 2
    assert cost["prompt_tokens"] == 300
    assert cost["completion_tokens"] == 80
    assert cost["total_tokens"] == 380
    assert cost["total_duration_ms"] >= 0
def test_planner_run_without_trace_returns_none():
    """run() with trace_enabled=False yields PlanStep with trace_data=None."""
    planner = _make_planner()
    final_resp = _mock_llm_response(content="Done! Email sent.")
    collected: list[PlanStep] = []

    async def _collect():
        with patch("fireaction_a2a.planner.litellm") as fake_litellm:
            fake_litellm.acompletion = AsyncMock(return_value=final_resp)
            async for step in planner.run("send email", [], trace_enabled=False):
                collected.append(step)

    asyncio.run(_collect())
    assert collected[-1].type == "completed"
    assert collected[-1].trace_data is None