Spaces:
Running
Running
File size: 5,764 Bytes
"""Agent + provider tests.
Offline by default: a FakeProvider returns scripted tool calls so the
briefing build, tool-schema filtering, tool-call parsing, and the full
ModelAgent -> eval_core -> live-Rust loop are all asserted without a
network. The OpenRouter live test is opt-in (skips without an API key).
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
ot = pytest.importorskip("openra_train", reason="Rust env wheel not installed")
pytest.importorskip("openra_rl_training", reason="Rust env wheel not installed")
from openra_bench.agent import ModelAgent, build_briefing, _to_commands, _tool_schemas
from openra_bench.providers import ChatProvider, ChatReply, ProviderConfig
# Absolute path to a local OpenRA-RL-Training checkout.
# NOTE(review): this is a developer-machine path and nothing in the visible
# part of this module reads it -- confirm it is still needed.
TRAIN = Path("/Users/berta/Projects/OpenRA-RL-Training")

# Scenario pack directory: two levels above this file, then
# openra_bench/scenarios/packs.
PACKS = Path(__file__).parent.parent.joinpath("openra_bench", "scenarios", "packs")
class FakeProvider(ChatProvider):
    """Offline chat provider that replays a fixed script of tool calls.

    Each call to ``complete`` returns the next entry of ``script``. Once the
    script is exhausted, every further call returns a single ``observe`` tool
    call. The messages passed to each call are recorded in ``seen_messages``
    so tests can inspect what the provider was sent.
    """

    def __init__(self, script: list[list[dict]]):
        """Store the script and reset the replay cursor.

        Args:
            script: one list of tool-call dicts per expected ``complete`` call.
        """
        # NOTE(review): the base-class initializer is not invoked here --
        # confirm ChatProvider needs no setup of its own.
        self.script = script
        self.i = 0  # index of the next script entry to replay
        self.seen_messages: list[list[dict]] = []

    def complete(self, messages, tools):
        """Record ``messages`` and return the next scripted reply.

        Args:
            messages: the chat messages for this turn.
            tools: the tool schemas offered this turn (ignored by the fake).

        Returns:
            A ``ChatReply`` with empty text and the scripted tool calls, or a
            lone ``observe`` call once the script has run out.
        """
        import copy

        # Snapshot the messages: a caller that keeps extending or rewriting
        # one history list in place would otherwise retroactively change what
        # earlier turns appear to have "seen".
        self.seen_messages.append(copy.deepcopy(messages))

        if self.i < len(self.script):
            # Hand out a copy so that script entries sharing one list object
            # (e.g. built with ``[[call]] * n``) cannot be corrupted by a
            # consumer that mutates the tool calls it receives.
            calls = copy.deepcopy(self.script[self.i])
        else:
            calls = [{"name": "observe", "arguments": {}}]
        self.i += 1
        return ChatReply(text="", tool_calls=calls)
def test_tool_schema_filtering():
    """Check which tool names ``_tool_schemas`` exposes per allow-list."""

    def exposed(allowed):
        # Pull the function name out of every schema entry returned.
        return {schema["function"]["name"] for schema in _tool_schemas(allowed)}

    restricted = exposed(["move_units"])
    assert "move_units" in restricted
    assert "attack_unit" not in restricted
    assert "observe" in restricted, "a no-op must always be offered"

    # No allow-list -> DEFAULT_CORE (movement/combat), not every tool.
    default_names = exposed(None)
    assert {"move_units", "attack_unit", "observe"}.issubset(default_names)
    assert "build" not in default_names
    assert "surrender" not in default_names

    # The "*" wildcard exposes the full set (economy/structure/etc).
    everything = exposed(["*"])
    assert everything.issuperset(
        {"build", "harvest", "place_building", "stop", "deploy"}
    )
def test_build_briefing_format():
    """Check that ``build_briefing`` renders the expected text fragments."""
    jeep = {"id": "1001", "type": "jeep", "cell_x": 5, "cell_y": 6, "activity": "idle"}
    state = {
        "game_tick": 120,
        "explored_percent": 12.5,
        "units_summary": [jeep],
        "enemy_summary": [],
    }

    briefing = build_briefing(state, objective="find the base")

    # Fragments are checked in the same order as the original assertions.
    expected_fragments = (
        "OBJECTIVE: find the base",
        "1001 jeep @(5,6)",
        "tick=120",
        "explored=12.5%",
        # presumably the placeholder emitted for an empty enemy list
        "none (scout the fog)",
    )
    for fragment in expected_fragments:
        assert fragment in briefing
def test_tool_call_parsing_and_aliases():
    """Check that ``_to_commands`` keeps valid calls and drops the rest."""
    # Calls expected to survive: a move, an attack via its alias, an observe.
    accepted = [
        {"name": "move_units", "arguments": {"unit_ids": [1, 2], "target_x": 9, "target_y": 4}},
        {"name": "attack_target", "arguments": {"unit_ids": [3], "target_id": 77}},  # alias
        {"name": "observe", "arguments": {}},
    ]
    # Calls expected to be dropped: an unknown tool and a move with no target.
    rejected = [
        {"name": "garbage", "arguments": {}},
        {"name": "move_units", "arguments": {"unit_ids": [1]}},  # malformed
    ]

    commands = _to_commands(accepted + rejected, ot.Command)

    # 3 valid (move, attack-alias, observe); 2 dropped
    assert len(commands) == len(accepted)
def test_model_agent_drives_live_rust_with_fake_provider():
    """Run a ModelAgent backed by FakeProvider through ``run_level``."""
    from openra_bench.eval_core import run_level
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import compile_level

    level = compile_level(load_pack(PACKS / "perception-frontier-reading.yaml"), "easy")

    # Scripted "scout east" turn, replayed 8 times; the provider falls back
    # to observe once the script runs out.
    scout_east = [
        {"name": "move_units", "arguments": {"unit_ids": [1001], "target_x": 100, "target_y": 20}}
    ]
    provider = FakeProvider([scout_east] * 8)

    agent = ModelAgent(
        ProviderConfig(vision=False),
        allowed_tools=level.scenario.tools,
        objective=level.scenario.description,
        provider=provider,
    )
    result = run_level(level, agent.agent_fn, seed=1)

    assert result.outcome in {"win", "draw", "loss"}
    assert result.turns >= 1
    assert len(result.trace) == result.turns
    assert agent.stats["turns"] == result.turns

    # Provider actually saw a system prompt + a user briefing.
    first_request = provider.seen_messages[0]
    assert first_request[0]["role"] == "system"
    assert any(message["role"] == "user" for message in first_request)
def test_history_strips_stale_images():
    """Check that ``_strip_old_images`` keeps only the newest image."""

    def user_turn(text):
        # A user message made of one text part and one (empty) image part.
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {}},
            ],
        }

    history = [user_turn("t1"), user_turn("t2")]

    # The helper edits the list in place; its return value is not used.
    ModelAgent._strip_old_images(history)

    oldest_content = history[0]["content"]
    assert isinstance(oldest_content, str)
    assert oldest_content == "t1"
    assert isinstance(history[1]["content"], list)  # newest image kept
@pytest.mark.skipif(
    not os.getenv("OPENROUTER_API_KEY"),
    reason="set OPENROUTER_API_KEY to run the live provider test",
)
def test_openrouter_live_smoke():
    """Opt-in smoke test against the real OpenRouter provider.

    Skipped unless OPENROUTER_API_KEY is set; OPENROUTER_MODEL overrides the
    default model name.
    """
    from openra_bench.eval_core import run_level
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import compile_level

    level = compile_level(load_pack(PACKS / "perception-frontier-reading.yaml"), "easy")

    config = ProviderConfig(
        provider="openrouter",
        model=os.getenv("OPENROUTER_MODEL", "anthropic/claude-3.5-sonnet"),
        vision=False,
        max_tokens=512,
    )
    agent = ModelAgent(
        config,
        allowed_tools=level.scenario.tools,
        objective=level.scenario.description,
    )

    result = run_level(level, agent.agent_fn, seed=1)

    assert result.outcome in {"win", "draw", "loss"}
    assert agent.stats["tool_calls"] >= 1, "model issued no usable tool calls"
|