gpu-goblin / tests /test_loop.py
bharathtelu's picture
Deploy auto-tune UI + scripts (work-from-91d0cf0)
a9aa4ae verified
Raw
History Blame Contribute Delete
11.1 kB
"""Tests for the agent loop driver.
We never hit a real LLM. The loop talks to a `Backend` (see
`agent/backends/`); each test injects a `FakeBackend` whose `next_turn`
returns a queued sequence of scripted `AgentTurn` objects. Tools are
stubbed so we can drive specific control-flow paths.
"""
from __future__ import annotations
from typing import Any
import pytest
from agent import loop as loop_module
from agent.backends.base import AgentTurn, Backend, ToolCall
from agent.schemas import SSEEvent, ToolResult
# ---------------------------------------------------------------------------
# Fake backend
# ---------------------------------------------------------------------------
class FakeBackend(Backend):
    """Scripted stand-in for a real `Backend`, used to drive the loop offline.

    Tests hand it a queue of `AgentTurn`s up front; each `next_turn` call
    dequeues the oldest one. Everything the loop feeds back (user messages,
    tool results) is captured on the instance so assertions can check that
    ids, content and error flags were threaded through correctly.
    """

    name = "fake"

    def __init__(
        self,
        scripted_turns: list[AgentTurn] | None = None,
        next_turn_raises: BaseException | None = None,
    ) -> None:
        # Recording surface inspected by the tests.
        self.turn_count = 0
        self.user_messages: list[str] = []
        self.tool_results: list[dict[str, Any]] = []
        # Private script: a copy of the queued turns (so the caller's list is
        # never mutated) plus an optional exception raised by the next call.
        self._scripted = [] if scripted_turns is None else list(scripted_turns)
        self._raise_on_next = next_turn_raises

    def add_user_message(self, content: str) -> None:
        """Record a user message the loop handed to the backend."""
        self.user_messages.append(content)

    def add_tool_result(
        self,
        tool_call_id: str,
        name: str,
        content: str,
        is_error: bool,
    ) -> None:
        """Record one tool result exactly as the loop passed it in."""
        self.tool_results.append(
            {"id": tool_call_id, "name": name, "content": content, "is_error": is_error}
        )

    async def next_turn(self, tool_schemas: list[dict[str, Any]]) -> AgentTurn:
        """Return the next scripted turn.

        Every call increments `turn_count`, including one that raises. A
        pending `next_turn_raises` exception is raised (and cleared) before
        any scripted turn is consumed. `tool_schemas` is accepted for
        interface compatibility and ignored.

        Raises:
            AssertionError: if the script has run out of turns.
        """
        self.turn_count += 1
        # Take the one-shot exception and clear it in a single step.
        pending, self._raise_on_next = self._raise_on_next, None
        if pending is not None:
            raise pending
        if not self._scripted:
            raise AssertionError(
                "FakeBackend exhausted — loop made more turns than expected"
            )
        return self._scripted.pop(0)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _install_backend(monkeypatch, backend: Backend) -> Backend:
    """Replace `loop_module.make_backend` so the loop sees our fake.

    The stub factory ignores all positional and keyword arguments. That keeps
    the tests independent of how the loop calls the factory; a keyword-only
    stub would raise TypeError on any positional argument.

    Returns:
        The same `backend` object, for call-site convenience.
    """

    def fake_make_backend(*_args: Any, **_kwargs: Any) -> Backend:
        return backend

    monkeypatch.setattr(loop_module, "make_backend", fake_make_backend)
    return backend
def _install_make_backend_raises(monkeypatch, exc: BaseException) -> None:
    """Replace `loop_module.make_backend` with a factory that always raises.

    Simulates backend construction failing (e.g. a missing credential) before
    the loop can start. The stub accepts any positional or keyword arguments.
    As a result, the test always sees `exc`, never an unrelated TypeError
    caused by a signature mismatch with how the loop calls the factory.
    """

    def boom(*_args: Any, **_kwargs: Any) -> Backend:
        raise exc

    monkeypatch.setattr(loop_module, "make_backend", boom)
def _install_fake_tools(
    monkeypatch, tool_responses: dict[str, ToolResult]
) -> list[str]:
    """Stub out the tool layer the loop dispatches through.

    `tools_module.tool_schemas` is replaced to return an empty list, and
    `tools_module.call` is replaced by a recorder. The recorder returns the
    canned `ToolResult` registered for the tool name, or an ok=False
    "no fake registered" result for unknown names.

    Returns:
        The list the recorder appends each invoked tool name to, in call order.
    """
    call_log: list[str] = []

    def recording_call(name: str, **_kwargs: Any) -> ToolResult:
        call_log.append(name)
        if name in tool_responses:
            return tool_responses[name]
        return ToolResult(ok=False, error=f"no fake registered for {name}")

    monkeypatch.setattr(loop_module.tools_module, "tool_schemas", lambda: [])
    monkeypatch.setattr(loop_module.tools_module, "call", recording_call)
    return call_log
async def _collect(stream) -> list[SSEEvent]:
out: list[SSEEvent] = []
async for event in stream:
out.append(event)
return out
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_emits_thought_then_tool_call_then_tool_result(monkeypatch) -> None:
    """One tool round-trip streams thought -> tool_call -> tool_result.

    The run ends on the "no final report" error because compare_runs is
    never called.
    """
    first_turn = AgentTurn(
        text_blocks=["I'll start by parsing the config."],
        tool_calls=[
            ToolCall(id="tu_1", name="parse_config", input={"file_path": "/x.py"})
        ],
        stop_reason="tool_use",
    )
    closing_turn = AgentTurn(text_blocks=["Done."], tool_calls=[], stop_reason="end_turn")
    backend = FakeBackend(scripted_turns=[first_turn, closing_turn])
    _install_backend(monkeypatch, backend)
    invoked = _install_fake_tools(
        monkeypatch,
        {"parse_config": ToolResult(ok=True, result={"model_name": "x"})},
    )

    events = await _collect(loop_module.run_audit("/x.py"))
    kinds = [event.type for event in events]

    # Opening sequence of the stream.
    assert kinds[:3] == ["thought", "tool_call", "tool_result"]

    call_event, result_event, last_event = events[1], events[2], events[-1]

    # compare_runs never ran, so the stream closes with an error.
    assert last_event.type == "error"
    assert "without producing a final report" in last_event.data["message"]
    assert invoked == ["parse_config"]

    # tool_call echoes id/name/input; tool_result adds ok/result/error.
    assert call_event.data == {
        "id": "tu_1",
        "name": "parse_config",
        "input": {"file_path": "/x.py"},
    }
    assert result_event.data["ok"] is True
    assert result_event.data["result"] == {"model_name": "x"}
    assert result_event.data["error"] is None

    # The audit prompt and the serialized tool result reached the backend.
    assert backend.user_messages == ["Audit this fine-tuning workload: /x.py"]
    expected_result = {
        "id": "tu_1",
        "name": "parse_config",
        "content": '{"model_name": "x"}',
        "is_error": False,
    }
    assert backend.tool_results == [expected_result]
@pytest.mark.asyncio
async def test_final_report_extracted_from_compare_runs(monkeypatch) -> None:
    """A successful compare_runs result is emitted as the final_report."""
    fake_report = {"workload_name": "test", "speedup_actual": 2.0}
    compare_call = ToolCall(
        id="tu_compare",
        name="compare_runs",
        input={"workload_name": "t", "before": {}, "after": {}, "patch": {}},
    )
    wrap_up = AgentTurn(
        text_blocks=["Wrapping up."],
        tool_calls=[compare_call],
        stop_reason="end_turn",
    )
    _install_backend(monkeypatch, FakeBackend(scripted_turns=[wrap_up]))
    _install_fake_tools(
        monkeypatch, {"compare_runs": ToolResult(ok=True, result=fake_report)}
    )

    *_, final = await _collect(loop_module.run_audit("/x.py"))

    assert final.type == "final_report"
    assert final.data["report"] == fake_report
@pytest.mark.asyncio
async def test_tool_error_passes_through_does_not_crash(monkeypatch) -> None:
    """A failing tool surfaces as an ok=False tool_result; the loop keeps going."""
    failing_parse = AgentTurn(
        text_blocks=["Trying parse."],
        tool_calls=[
            ToolCall(id="tu_1", name="parse_config", input={"file_path": "/bogus"})
        ],
        stop_reason="tool_use",
    )
    give_up = AgentTurn(text_blocks=["Giving up."], tool_calls=[], stop_reason="end_turn")
    backend = FakeBackend(scripted_turns=[failing_parse, give_up])
    _install_backend(monkeypatch, backend)
    _install_fake_tools(
        monkeypatch,
        {"parse_config": ToolResult(ok=False, error="file not found")},
    )

    events = await _collect(loop_module.run_audit("/bogus"))

    # Exactly one tool_result, carrying the tool's failure.
    (only_result,) = [event for event in events if event.type == "tool_result"]
    assert only_result.data["ok"] is False
    assert only_result.data["error"] == "file not found"

    # The loop moved on instead of bailing. With no compare_runs it ends on
    # the "no final report" error.
    assert events[-1].type == "error"

    # The backend saw the failure as an is_error=True tool result.
    last_fed_back = backend.tool_results[-1]
    assert last_fed_back["is_error"] is True
    assert last_fed_back["content"] == "file not found"
@pytest.mark.asyncio
async def test_backend_construction_failure_yields_error_event(monkeypatch) -> None:
    """If make_backend raises, run_audit yields a single error event."""
    _install_make_backend_raises(
        monkeypatch, RuntimeError("HF_TOKEN is not set; Qwen backend cannot run.")
    )

    events = await _collect(loop_module.run_audit("/x.py"))

    assert [event.type for event in events] == ["error"]
    (error_event,) = events
    assert "HF_TOKEN" in error_event.data["message"]
@pytest.mark.asyncio
async def test_mid_loop_exception_is_caught(monkeypatch) -> None:
    """An exception from next_turn becomes a trailing error event, not a crash."""
    _install_backend(monkeypatch, FakeBackend(next_turn_raises=RuntimeError("boom")))
    monkeypatch.setattr(loop_module.tools_module, "tool_schemas", lambda: [])

    last = (await _collect(loop_module.run_audit("/x.py")))[-1]

    assert last.type == "error"
    assert "boom" in last.data["message"]
@pytest.mark.asyncio
async def test_loop_caps_at_max_steps(monkeypatch) -> None:
    """Even if the model never says end_turn, we bail after MAX_STEPS."""
    cap = loop_module.MAX_STEPS

    def tool_turn(i: int) -> AgentTurn:
        # Every turn requests another tool call, so only the cap can stop the loop.
        return AgentTurn(
            text_blocks=[f"step {i}"],
            tool_calls=[
                ToolCall(id=f"tu_{i}", name="parse_config", input={"file_path": "/x.py"})
            ],
            stop_reason="tool_use",
        )

    # Script two turns past the cap, so overrunning it would be observable.
    backend = FakeBackend(scripted_turns=[tool_turn(i) for i in range(cap + 2)])
    _install_backend(monkeypatch, backend)
    _install_fake_tools(monkeypatch, {"parse_config": ToolResult(ok=True, result={})})

    events = await _collect(loop_module.run_audit("/x.py"))

    # next_turn was invoked exactly `cap` times.
    assert backend.turn_count == cap
    # The run terminates with an error event rather than raising.
    assert events[-1].type == "error"
@pytest.mark.asyncio
async def test_tool_call_id_is_threaded_to_backend(monkeypatch) -> None:
    """The loop must hand the tool_call id back to the backend so the next
    turn's request can correlate the tool_result with the originating call.
    """
    parse_turn = AgentTurn(
        text_blocks=["parse"],
        tool_calls=[
            ToolCall(id="tu_abc", name="parse_config", input={"file_path": "/x"})
        ],
        stop_reason="tool_use",
    )
    done_turn = AgentTurn(text_blocks=["done"], tool_calls=[], stop_reason="end_turn")
    backend = FakeBackend(scripted_turns=[parse_turn, done_turn])
    _install_backend(monkeypatch, backend)
    _install_fake_tools(
        monkeypatch, {"parse_config": ToolResult(ok=True, result={"a": 1})}
    )

    await _collect(loop_module.run_audit("/x"))

    # Exactly one result went back, tagged with the originating call's id.
    (fed_back,) = backend.tool_results
    assert fed_back["id"] == "tu_abc"
    assert fed_back["name"] == "parse_config"
    assert fed_back["is_error"] is False