File size: 4,753 Bytes
b05b6f5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | import asyncio
import json
import pytest
from agent.config import Config
from agent.core import agent_loop
from agent.core.agent_loop import Handlers, LLMResult
from agent.core.session import Session
from agent.tools.plan_tool import PlanTool
class FakeToolRouter:
    """In-memory stand-in for the tool router used by the agent loop tests.

    It advertises a single ``plan_tool`` function spec and records every
    tool invocation so tests can inspect what the agent loop called.
    """

    def __init__(self):
        # One (name, arguments, tool_call_id) tuple per call, in call order.
        self.calls = []

    def get_tool_specs_for_llm(self):
        """Return the list of tool specs offered to the LLM (only ``plan_tool``)."""
        plan_function = {
            "name": "plan_tool",
            "description": "Update plan",
            "parameters": {"type": "object"},
        }
        return [{"type": "function", "function": plan_function}]

    async def call_tool(self, name, arguments, session=None, tool_call_id=None):
        """Record the call and, for ``plan_tool``, store the todos on the session.

        Each todo is shallow-copied so the session's plan does not alias the
        dicts passed in ``arguments``. Always returns ``("plan updated", True)``.
        """
        record = (name, arguments, tool_call_id)
        self.calls.append(record)

        updates_plan = name == "plan_tool" and session is not None
        if updates_plan:
            copied_todos = []
            for todo in arguments["todos"]:
                copied_todos.append(dict(todo))
            session.current_plan = copied_todos

        return "plan updated", True
@pytest.mark.asyncio
async def test_plan_tool_stores_session_scoped_plan():
    """PlanTool.execute should store the todos on its session and emit a plan_update event."""
    captured_events = []

    class FakeSession:
        # Starts empty; the test expects execute() to replace it with the todos.
        current_plan = []

        async def send_event(self, event):
            # Collect events so the test can inspect what was emitted.
            captured_events.append(event)

    fake_session = FakeSession()
    todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]
    tool = PlanTool(session=fake_session)

    outcome = await tool.execute({"todos": todos})

    assert outcome["isError"] is False
    assert fake_session.current_plan == todos

    first_event = captured_events[0]
    assert first_event.event_type == "plan_update"
    assert first_event.data == {"plan": todos}
@pytest.mark.asyncio
async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
    """A text-only LLM reply must not end the run while the plan has open todos.

    The fake LLM is scripted for three turns:
      1. a plain-text answer with no tool call,
      2. a ``plan_tool`` call marking every todo completed (this turn also
         checks that the last message it receives contains "CONTINUATION GUARD"),
      3. the final "Done." answer.
    The test then checks the final answer, the number of LLM calls, the routed
    tool call, the resulting plan, and that a ``tool_log`` event mentioning
    "text-only response" was queued.
    """
    settings = {"model_name": "openai/test", "save_sessions": False}
    config = Config.model_validate(settings)
    event_queue = asyncio.Queue()
    router = FakeToolRouter()
    session = Session(
        event_queue,
        config,
        tool_router=router,
        stream=False,
    )

    first_todo = "Write and smoke-test training script"
    second_todo = "Launch full training job"
    # Seed an unfinished plan so a text-only reply should trigger a retry.
    session.current_plan = [
        {"id": "1", "content": first_todo, "status": "in_progress"},
        {"id": "2", "content": second_todo, "status": "pending"},
    ]

    # Arguments the scripted second turn passes to plan_tool: everything done.
    completed_plan_json = json.dumps(
        {
            "todos": [
                {"id": "1", "content": first_todo, "status": "completed"},
                {"id": "2", "content": second_todo, "status": "completed"},
            ]
        }
    )

    calls = []

    async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
        # Record the messages of every turn; the turn number selects the reply.
        calls.append(messages)
        turn = len(calls)

        if turn == 1:
            return LLMResult(
                content="I should keep going, but I forgot to call a tool.",
                tool_calls_acc={},
                token_count=10,
                finish_reason="stop",
            )

        if turn == 2:
            assert "CONTINUATION GUARD" in messages[-1].content
            plan_call = {
                "id": "call_1",
                "function": {
                    "name": "plan_tool",
                    "arguments": completed_plan_json,
                },
            }
            return LLMResult(
                content=None,
                tool_calls_acc={0: plan_call},
                token_count=20,
                finish_reason="tool_calls",
            )

        return LLMResult(
            content="Done.",
            tool_calls_acc={},
            token_count=30,
            finish_reason="stop",
        )

    def fake_resolve_llm_params(*_, **__):
        return {"model": "openai/test"}

    monkeypatch.setattr(agent_loop, "_resolve_llm_params", fake_resolve_llm_params)
    monkeypatch.setattr(
        agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
    )

    final = await Handlers.run_agent(session, "continue")

    assert final == "Done."
    assert len(calls) == 3
    first_tool_call = router.calls[0]
    assert first_tool_call[0] == "plan_tool"
    assert all(item["status"] == "completed" for item in session.current_plan)

    # Drain everything the session queued so the log events can be inspected.
    queued_events = []
    while not event_queue.empty():
        queued_events.append(await event_queue.get())

    assert any(
        queued.event_type == "tool_log"
        and "text-only response" in (queued.data or {}).get("log", "")
        for queued in queued_events
    )
|