| """Tests for the agent loop driver. |
| |
| We never hit a real LLM. The loop talks to a `Backend` (see |
| `agent/backends/`); each test injects a `FakeBackend` whose `next_turn` |
| returns a queued sequence of scripted `AgentTurn` objects. Tools are |
| stubbed so we can drive specific control-flow paths. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import pytest |
|
|
| from agent import loop as loop_module |
| from agent.backends.base import AgentTurn, Backend, ToolCall |
| from agent.schemas import SSEEvent, ToolResult |
|
|
|
|
| |
| |
| |
|
|
|
|
class FakeBackend(Backend):
    """Scripted stand-in for a real LLM `Backend`.

    A test supplies the `AgentTurn`s the "model" should produce, in order.
    Each `next_turn` call hands back the next one. Everything the loop feeds
    back to the backend (user messages and tool results) is recorded on the
    instance. Tests can then check that ids, content and error flags were
    threaded through correctly.

    Args:
        scripted_turns: Turns returned by successive `next_turn` calls.
        next_turn_raises: If given, the first `next_turn` call raises this
            exception instead of returning a turn. It fires only once.
    """

    name = "fake"

    def __init__(
        self,
        scripted_turns: list[AgentTurn] | None = None,
        next_turn_raises: BaseException | None = None,
    ) -> None:
        # Copy the caller's list so popping turns never mutates it.
        self._scripted = [] if scripted_turns is None else list(scripted_turns)
        self._raise_on_next = next_turn_raises
        # Recorded traffic from the loop, inspected by tests.
        self.user_messages: list[str] = []
        self.tool_results: list[dict[str, Any]] = []
        # Number of `next_turn` calls, including one that raised.
        self.turn_count = 0

    def add_user_message(self, content: str) -> None:
        """Record a user message the loop sent to the backend."""
        self.user_messages.append(content)

    def add_tool_result(
        self,
        tool_call_id: str,
        name: str,
        content: str,
        is_error: bool,
    ) -> None:
        """Record one tool result the loop handed back, as a plain dict."""
        entry = {
            "id": tool_call_id,
            "name": name,
            "content": content,
            "is_error": is_error,
        }
        self.tool_results.append(entry)

    async def next_turn(self, tool_schemas: list[dict[str, Any]]) -> AgentTurn:
        """Return the next scripted turn, or raise the pending one-shot exception.

        `tool_schemas` is accepted for interface compatibility and ignored.

        Raises:
            AssertionError: The loop asked for more turns than were scripted.
        """
        self.turn_count += 1
        # Take the pending exception (if any) and clear it so it fires once.
        pending, self._raise_on_next = self._raise_on_next, None
        if pending is not None:
            raise pending
        if self._scripted:
            return self._scripted.pop(0)
        raise AssertionError(
            "FakeBackend exhausted — loop made more turns than expected"
        )
|
|
|
|
| |
| |
| |
|
|
|
|
def _install_backend(monkeypatch, backend: Backend) -> Backend:
    """Patch `loop_module.make_backend` so the loop receives `backend`.

    The replacement ignores every argument the loop passes, positional or
    keyword. The tests therefore do not depend on `make_backend`'s exact call
    signature; a keyword-only stub would raise `TypeError` on a positional
    call and surface as a misleading backend-construction failure.

    Returns:
        `backend`, so callers can install and bind in one expression.
    """

    def fake_make_backend(*_args: Any, **_kwargs: Any) -> Backend:
        return backend

    monkeypatch.setattr(loop_module, "make_backend", fake_make_backend)
    return backend
|
|
|
|
def _install_make_backend_raises(monkeypatch, exc: BaseException) -> None:
    """Patch `loop_module.make_backend` so constructing a backend raises `exc`.

    The stub accepts any positional or keyword arguments. The error the loop
    reports is therefore always `exc`, never a signature-mismatch `TypeError`
    that would hide the message the test asserts on.
    """

    def boom(*_args: Any, **_kwargs: Any) -> Backend:
        raise exc

    monkeypatch.setattr(loop_module, "make_backend", boom)
|
|
|
|
def _install_fake_tools(
    monkeypatch, tool_responses: dict[str, ToolResult]
) -> list[str]:
    """Stub out the tool layer the loop dispatches to.

    Two functions are replaced:

    - `tools_module.call` becomes a fake that returns `tool_responses[name]`.
      Unregistered names get a failed `ToolResult`.
    - `tools_module.tool_schemas` becomes a function that returns no schemas.

    Returns:
        A list that receives each invoked tool name, in call order.
    """
    call_log: list[str] = []

    def fake_call(name: str, **_kwargs: Any) -> ToolResult:
        call_log.append(name)
        fallback = ToolResult(ok=False, error=f"no fake registered for {name}")
        return tool_responses.get(name, fallback)

    tools = loop_module.tools_module
    monkeypatch.setattr(tools, "call", fake_call)
    monkeypatch.setattr(tools, "tool_schemas", lambda: [])
    return call_log
|
|
|
|
| async def _collect(stream) -> list[SSEEvent]: |
| out: list[SSEEvent] = [] |
| async for event in stream: |
| out.append(event) |
| return out |
|
|
|
|
| |
| |
| |
|
|
|
|
@pytest.mark.asyncio
async def test_emits_thought_then_tool_call_then_tool_result(monkeypatch) -> None:
    """Check the event stream and backend traffic for one tool round-trip.

    A tool-using turn should stream thought, then tool_call, then tool_result.
    A following end_turn that never produced a report should end the run
    with an error.
    """
    parse_call = ToolCall(
        id="tu_1", name="parse_config", input={"file_path": "/x.py"}
    )
    tool_turn = AgentTurn(
        text_blocks=["I'll start by parsing the config."],
        tool_calls=[parse_call],
        stop_reason="tool_use",
    )
    closing_turn = AgentTurn(
        text_blocks=["Done."], tool_calls=[], stop_reason="end_turn"
    )
    backend = FakeBackend(scripted_turns=[tool_turn, closing_turn])
    _install_backend(monkeypatch, backend)
    invoked = _install_fake_tools(
        monkeypatch,
        {"parse_config": ToolResult(ok=True, result={"model_name": "x"})},
    )

    events = await _collect(loop_module.run_audit("/x.py"))
    kinds = [event.type for event in events]

    # The stream opens with the model's thought, then the tool round-trip.
    assert kinds[:3] == ["thought", "tool_call", "tool_result"]

    # No final report was produced, so the run must close with an error
    # rather than silently succeeding.
    assert kinds[-1] == "error"
    assert "without producing a final report" in events[-1].data["message"]
    assert invoked == ["parse_config"]

    # The tool_call event mirrors the scripted ToolCall.
    assert events[1].data == {
        "id": "tu_1",
        "name": "parse_config",
        "input": {"file_path": "/x.py"},
    }
    # The tool_result event mirrors the fake ToolResult.
    result_payload = events[2].data
    assert result_payload["ok"] is True
    assert result_payload["result"] == {"model_name": "x"}
    assert result_payload["error"] is None

    # The backend received the kickoff prompt and the serialized success
    # result, tagged with the originating call id.
    assert backend.user_messages == ["Audit this fine-tuning workload: /x.py"]
    expected_result = {
        "id": "tu_1",
        "name": "parse_config",
        "content": '{"model_name": "x"}',
        "is_error": False,
    }
    assert backend.tool_results == [expected_result]
|
|
|
|
@pytest.mark.asyncio
async def test_final_report_extracted_from_compare_runs(monkeypatch) -> None:
    """Check that a successful compare_runs result becomes the final_report event."""
    fake_report = {"workload_name": "test", "speedup_actual": 2.0}
    compare_call = ToolCall(
        id="tu_compare",
        name="compare_runs",
        input={"workload_name": "t", "before": {}, "after": {}, "patch": {}},
    )
    # Only one turn is scripted. If the loop asked for another, FakeBackend
    # would raise instead of returning a turn.
    only_turn = AgentTurn(
        text_blocks=["Wrapping up."],
        tool_calls=[compare_call],
        stop_reason="end_turn",
    )
    backend = FakeBackend(scripted_turns=[only_turn])
    _install_backend(monkeypatch, backend)
    _install_fake_tools(
        monkeypatch, {"compare_runs": ToolResult(ok=True, result=fake_report)}
    )

    events = await _collect(loop_module.run_audit("/x.py"))

    last_event = events[-1]
    assert last_event.type == "final_report"
    assert last_event.data["report"] == fake_report
|
|
|
|
@pytest.mark.asyncio
async def test_tool_error_passes_through_does_not_crash(monkeypatch) -> None:
    """Check that a failing tool is reported to the stream and backend, not raised."""
    failing_turn = AgentTurn(
        text_blocks=["Trying parse."],
        tool_calls=[
            ToolCall(id="tu_1", name="parse_config", input={"file_path": "/bogus"})
        ],
        stop_reason="tool_use",
    )
    give_up_turn = AgentTurn(
        text_blocks=["Giving up."], tool_calls=[], stop_reason="end_turn"
    )
    backend = FakeBackend(scripted_turns=[failing_turn, give_up_turn])
    _install_backend(monkeypatch, backend)
    _install_fake_tools(
        monkeypatch,
        {"parse_config": ToolResult(ok=False, error="file not found")},
    )

    events = await _collect(loop_module.run_audit("/bogus"))

    # Exactly one tool_result event, carrying the tool's failure verbatim.
    result_events = [event for event in events if event.type == "tool_result"]
    assert len(result_events) == 1
    failure = result_events[0].data
    assert failure["ok"] is False
    assert failure["error"] == "file not found"

    # The model gave up without a report, so the run still ends in error.
    assert events[-1].type == "error"

    # The backend was told the call failed, with the error text as content.
    last_result = backend.tool_results[-1]
    assert last_result["is_error"] is True
    assert last_result["content"] == "file not found"
|
|
|
|
@pytest.mark.asyncio
async def test_backend_construction_failure_yields_error_event(monkeypatch) -> None:
    """Check that a failure in `make_backend` yields a single error event with the cause."""
    setup_error = RuntimeError("HF_TOKEN is not set; Qwen backend cannot run.")
    _install_make_backend_raises(monkeypatch, setup_error)

    events = await _collect(loop_module.run_audit("/x.py"))

    assert len(events) == 1
    only_event = events[0]
    assert only_event.type == "error"
    assert "HF_TOKEN" in only_event.data["message"]
|
|
|
|
@pytest.mark.asyncio
async def test_mid_loop_exception_is_caught(monkeypatch) -> None:
    """Check that an exception raised by the backend mid-run becomes an error event."""
    _install_backend(monkeypatch, FakeBackend(next_turn_raises=RuntimeError("boom")))
    # Only the schema lookup is stubbed; the backend fails before any tool runs.
    monkeypatch.setattr(loop_module.tools_module, "tool_schemas", lambda: [])

    events = await _collect(loop_module.run_audit("/x.py"))

    final_event = events[-1]
    assert final_event.type == "error"
    assert "boom" in final_event.data["message"]
|
|
|
|
@pytest.mark.asyncio
async def test_loop_caps_at_max_steps(monkeypatch) -> None:
    """Even if the model never says end_turn, we bail after MAX_STEPS."""

    def tool_use_turn(step: int) -> AgentTurn:
        return AgentTurn(
            text_blocks=[f"step {step}"],
            tool_calls=[
                ToolCall(
                    id=f"tu_{step}",
                    name="parse_config",
                    input={"file_path": "/x.py"},
                )
            ],
            stop_reason="tool_use",
        )

    # Script more turns than the cap so the backend never runs dry first.
    scripted = [tool_use_turn(step) for step in range(loop_module.MAX_STEPS + 2)]
    backend = FakeBackend(scripted_turns=scripted)
    _install_backend(monkeypatch, backend)
    _install_fake_tools(monkeypatch, {"parse_config": ToolResult(ok=True, result={})})

    events = await _collect(loop_module.run_audit("/x.py"))

    # The loop stops requesting turns exactly at the cap, even with turns left.
    assert backend.turn_count == loop_module.MAX_STEPS
    # The aborted run is reported as an error.
    assert events[-1].type == "error"
|
|
|
|
@pytest.mark.asyncio
async def test_tool_call_id_is_threaded_to_backend(monkeypatch) -> None:
    """Check that the loop passes the tool_call id back to the backend.

    The next turn's request needs that id to correlate the tool_result with
    the call that produced it.
    """
    call_turn = AgentTurn(
        text_blocks=["parse"],
        tool_calls=[
            ToolCall(id="tu_abc", name="parse_config", input={"file_path": "/x"})
        ],
        stop_reason="tool_use",
    )
    done_turn = AgentTurn(text_blocks=["done"], tool_calls=[], stop_reason="end_turn")
    backend = FakeBackend(scripted_turns=[call_turn, done_turn])
    _install_backend(monkeypatch, backend)
    _install_fake_tools(
        monkeypatch, {"parse_config": ToolResult(ok=True, result={"a": 1})}
    )

    await _collect(loop_module.run_audit("/x"))

    # Exactly one result came back, tagged with the originating call's id.
    assert len(backend.tool_results) == 1
    threaded = backend.tool_results[0]
    assert threaded["id"] == "tu_abc"
    assert threaded["name"] == "parse_config"
    assert threaded["is_error"] is False
|
|