File size: 4,753 Bytes
b05b6f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import asyncio
import json

import pytest

from agent.config import Config
from agent.core import agent_loop
from agent.core.agent_loop import Handlers, LLMResult
from agent.core.session import Session
from agent.tools.plan_tool import PlanTool


class FakeToolRouter:
    """Minimal tool-router stand-in that only knows about ``plan_tool``.

    Every ``call_tool`` invocation is appended to ``calls`` as a
    ``(name, arguments, tool_call_id)`` tuple so tests can inspect it.
    """

    def __init__(self):
        # Ordered log of (name, arguments, tool_call_id) tuples.
        self.calls = []

    def get_tool_specs_for_llm(self):
        """Return a freshly built one-element list describing ``plan_tool``."""
        plan_function = {
            "name": "plan_tool",
            "description": "Update plan",
            "parameters": {"type": "object"},
        }
        return [{"type": "function", "function": plan_function}]

    async def call_tool(self, name, arguments, session=None, tool_call_id=None):
        """Record the call; for ``plan_tool`` with a session, store a copy of the todos.

        Always returns ``("plan updated", True)``, whatever ``name`` is.
        """
        self.calls.append((name, arguments, tool_call_id))
        is_plan_update = name == "plan_tool"
        if is_plan_update and session is not None:
            # Shallow-copy each todo dict so the stored plan does not alias
            # the caller's argument dicts.
            session.current_plan = list(map(dict, arguments["todos"]))
        return "plan updated", True


@pytest.mark.asyncio
async def test_plan_tool_stores_session_scoped_plan():
    """PlanTool.execute should store the todos on the session and emit a plan_update event."""
    events = []

    class FakeSession:
        """Session double providing ``current_plan`` and an event sink."""

        def __init__(self):
            # Instance attribute rather than a mutable class attribute, so each
            # FakeSession owns its own list.
            self.current_plan = []

        async def send_event(self, event):
            events.append(event)

    session = FakeSession()
    todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]

    result = await PlanTool(session=session).execute({"todos": todos})

    assert result["isError"] is False
    assert session.current_plan == todos
    # Fail with a readable message instead of an IndexError when nothing was sent.
    assert events, "PlanTool did not emit any event"
    assert events[0].event_type == "plan_update"
    assert events[0].data == {"plan": todos}


@pytest.mark.asyncio
async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
    """A text-only LLM reply while the plan has unfinished todos should trigger a retry.

    The fake LLM is scripted for three turns:
      1. text only, no tool calls -> the loop is expected to append a
         "CONTINUATION GUARD" message and call the LLM again;
      2. a ``plan_tool`` call that marks every todo as completed;
      3. a final text answer ("Done."), which ``run_agent`` should return.
    """
    config = Config.model_validate(
        {"model_name": "openai/test", "save_sessions": False}
    )
    event_queue = asyncio.Queue()
    router = FakeToolRouter()
    session = Session(
        event_queue,
        config,
        tool_router=router,
        stream=False,
    )
    session.current_plan = [
        {
            "id": "1",
            "content": "Write and smoke-test training script",
            "status": "in_progress",
        },
        {"id": "2", "content": "Launch full training job", "status": "pending"},
    ]
    # One snapshot of the message list per LLM call.
    calls = []

    async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
        # Copy the list: the caller may keep appending to the same list object
        # after this returns, which would make later inspection misleading.
        calls.append(list(messages))
        if len(calls) == 1:
            return LLMResult(
                content="I should keep going, but I forgot to call a tool.",
                tool_calls_acc={},
                token_count=10,
                finish_reason="stop",
            )
        if len(calls) == 2:
            return LLMResult(
                content=None,
                tool_calls_acc={
                    0: {
                        "id": "call_1",
                        "function": {
                            "name": "plan_tool",
                            "arguments": json.dumps(
                                {
                                    "todos": [
                                        {
                                            "id": "1",
                                            "content": "Write and smoke-test training script",
                                            "status": "completed",
                                        },
                                        {
                                            "id": "2",
                                            "content": "Launch full training job",
                                            "status": "completed",
                                        },
                                    ]
                                }
                            ),
                        },
                    }
                },
                token_count=20,
                finish_reason="tool_calls",
            )
        return LLMResult(
            content="Done.",
            tool_calls_acc={},
            token_count=30,
            finish_reason="stop",
        )

    monkeypatch.setattr(
        agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
    )
    monkeypatch.setattr(
        agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
    )

    final = await Handlers.run_agent(session, "continue")

    assert final == "Done."
    assert len(calls) == 3
    # Checked here rather than inside the fake. An AssertionError raised inside
    # the loop's LLM call might be handled by the loop itself and show up as a
    # confusing downstream failure instead.
    # NOTE(review): confirm whether run_agent catches Exception around LLM calls.
    assert "CONTINUATION GUARD" in (calls[1][-1].content or "")
    assert router.calls, "no tool was dispatched"
    assert router.calls[0][0] == "plan_tool"
    # Guard against a vacuous pass of all() on an empty plan.
    assert session.current_plan
    assert all(todo["status"] == "completed" for todo in session.current_plan)
    events = []
    while not event_queue.empty():
        # Items are already queued; no need to await.
        events.append(event_queue.get_nowait())
    assert any(
        event.event_type == "tool_log"
        and "text-only response" in ((event.data or {}).get("log") or "")
        for event in events
    )