import asyncio
import json

import pytest

from agent.config import Config
from agent.core import agent_loop
from agent.core.agent_loop import Handlers, LLMResult
from agent.core.session import Session
from agent.tools.plan_tool import PlanTool


class FakeToolRouter:
    """Test double for the tool router that records every tool invocation."""

    def __init__(self):
        # Each entry is a (name, arguments, tool_call_id) tuple.
        self.calls = []

    def get_tool_specs_for_llm(self):
        """Return a single function spec describing ``plan_tool``."""
        plan_function = {
            "name": "plan_tool",
            "description": "Update plan",
            "parameters": {"type": "object"},
        }
        return [{"type": "function", "function": plan_function}]

    async def call_tool(self, name, arguments, session=None, tool_call_id=None):
        """Record the call and, for ``plan_tool``, copy the todos onto the session.

        Always returns the tuple ``("plan updated", True)``.
        """
        self.calls.append((name, arguments, tool_call_id))
        updates_plan = name == "plan_tool" and session is not None
        if updates_plan:
            # Shallow-copy each todo so the session does not alias the arguments.
            session.current_plan = [dict(todo) for todo in arguments["todos"]]
        return "plan updated", True


@pytest.mark.asyncio
async def test_plan_tool_stores_session_scoped_plan():
    """PlanTool stores the todos on its session and sends a plan_update event."""
    emitted = []

    class FakeSession:
        current_plan = []

        async def send_event(self, event):
            emitted.append(event)

    fake_session = FakeSession()
    todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]

    outcome = await PlanTool(session=fake_session).execute({"todos": todos})

    assert outcome["isError"] is False
    assert fake_session.current_plan == todos
    first_event = emitted[0]
    assert first_event.event_type == "plan_update"
    assert first_event.data == {"plan": todos}


@pytest.mark.asyncio
async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
    """A text-only reply with an unfinished plan triggers another LLM request.

    The fake LLM answers three times: plain text with no tool call, then a
    ``plan_tool`` call completing every todo (after checking that the last
    message it received contains "CONTINUATION GUARD"), then a final "Done.".
    """
    script_task = "Write and smoke-test training script"
    launch_task = "Launch full training job"

    config = Config.model_validate(
        {"model_name": "openai/test", "save_sessions": False}
    )
    event_queue = asyncio.Queue()
    router = FakeToolRouter()
    session = Session(
        event_queue,
        config,
        tool_router=router,
        stream=False,
    )
    session.current_plan = [
        {"id": "1", "content": script_task, "status": "in_progress"},
        {"id": "2", "content": launch_task, "status": "pending"},
    ]

    # Payload the fake LLM submits on its second turn: every todo completed.
    completed_plan_args = json.dumps(
        {
            "todos": [
                {"id": "1", "content": script_task, "status": "completed"},
                {"id": "2", "content": launch_task, "status": "completed"},
            ]
        }
    )

    calls = []

    async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
        calls.append(messages)
        turn = len(calls)

        if turn == 1:
            # Text-only answer even though the plan still has open items.
            return LLMResult(
                content="I should keep going, but I forgot to call a tool.",
                tool_calls_acc={},
                token_count=10,
                finish_reason="stop",
            )

        if turn == 2:
            assert "CONTINUATION GUARD" in messages[-1].content
            plan_call = {
                "id": "call_1",
                "function": {
                    "name": "plan_tool",
                    "arguments": completed_plan_args,
                },
            }
            return LLMResult(
                content=None,
                tool_calls_acc={0: plan_call},
                token_count=20,
                finish_reason="tool_calls",
            )

        return LLMResult(
            content="Done.",
            tool_calls_acc={},
            token_count=30,
            finish_reason="stop",
        )

    monkeypatch.setattr(
        agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
    )
    monkeypatch.setattr(
        agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
    )

    final = await Handlers.run_agent(session, "continue")

    assert final == "Done."
    assert len(calls) == 3
    first_tool_name = router.calls[0][0]
    assert first_tool_name == "plan_tool"
    assert all(todo["status"] == "completed" for todo in session.current_plan)

    # Drain everything the run queued so the log events can be inspected.
    queued_events = []
    while not event_queue.empty():
        queued_events.append(await event_queue.get())

    assert any(
        event.event_type == "tool_log"
        and "text-only response" in (event.data or {}).get("log", "")
        for event in queued_events
    )