Spaces:
Running
Running
| """Agent + provider tests. | |
| Offline by default: a FakeProvider returns scripted tool calls so the | |
| briefing build, tool-schema filtering, tool-call parsing, and the full | |
| ModelAgent -> eval_core -> live-Rust loop are all asserted without a | |
| network. The OpenRouter live test is opt-in (skips without an API key). | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| import pytest | |
| ot = pytest.importorskip("openra_train", reason="Rust env wheel not installed") | |
| pytest.importorskip("openra_rl_training", reason="Rust env wheel not installed") | |
| from openra_bench.agent import ModelAgent, build_briefing, _to_commands, _tool_schemas | |
| from openra_bench.providers import ChatProvider, ChatReply, ProviderConfig | |
# Presumably a local checkout of the OpenRA-RL-Training repository.
# NOTE(review): absolute, machine-specific path that no test visible in this
# file references — confirm it is still needed.
TRAIN = Path("/Users/berta/Projects/OpenRA-RL-Training")

# Scenario pack directory, resolved two levels up from this file and then down
# into openra_bench/scenarios/packs.
PACKS = Path(__file__).parent.parent.joinpath("openra_bench", "scenarios", "packs")
class FakeProvider(ChatProvider):
    """Scripted provider: replays fixed tool-call sets, then a lone observe().

    Each ``complete`` call consumes the next entry of ``script``. Once the
    script is exhausted, every further call answers with a single ``observe``
    tool call with empty arguments.
    """

    def __init__(self, script: list[list[dict]]):
        # One entry per expected turn; each entry is that turn's tool calls.
        self.script = script
        # Number of ``complete`` calls served so far (index of the next entry).
        self.i = 0
        # ``messages`` argument of every ``complete`` call, in call order.
        self.seen_messages: list[list[dict]] = []

    def complete(self, messages, tools):
        """Record ``messages`` and return the next scripted reply.

        ``tools`` is accepted but ignored. The reply always has empty text.
        """
        # NOTE(review): the list is stored by reference, not copied; if the
        # caller later mutates it in place the recorded entry changes too.
        self.seen_messages.append(messages)

        turn = self.i
        self.i += 1

        if turn < len(self.script):
            calls = self.script[turn]
        else:
            # Script exhausted: fall back to the no-op tool call.
            calls = [{"name": "observe", "arguments": {}}]
        return ChatReply(text="", tool_calls=calls)
def test_tool_schema_filtering():
    """Check which tool names ``_tool_schemas`` exposes for each allow-list."""

    def offered(allowed):
        # Collect the function names of the schemas returned for ``allowed``.
        return {schema["function"]["name"] for schema in _tool_schemas(allowed)}

    restricted = offered(["move_units"])
    assert "move_units" in restricted
    assert "attack_unit" not in restricted
    assert "observe" in restricted, "a no-op must always be offered"

    # Unset → DEFAULT_CORE (movement/combat), not every tool.
    default_names = offered(None)
    assert default_names.issuperset({"move_units", "attack_unit", "observe"})
    assert "build" not in default_names and "surrender" not in default_names

    # Wildcard exposes the full set (economy/structure/etc).
    everything = offered(["*"])
    assert everything.issuperset(
        {"build", "harvest", "place_building", "stop", "deploy"}
    )
def test_build_briefing_format():
    """Check that ``build_briefing`` renders objective, units, tick and fog hint."""
    jeep = {
        "id": "1001",
        "type": "jeep",
        "cell_x": 5,
        "cell_y": 6,
        "activity": "idle",
    }
    state = {
        "game_tick": 120,
        "explored_percent": 12.5,
        "units_summary": [jeep],
        "enemy_summary": [],
    }

    briefing = build_briefing(state, objective="find the base")

    # Every fragment must appear verbatim in the rendered briefing text.
    expected_fragments = (
        "OBJECTIVE: find the base",
        "1001 jeep @(5,6)",
        "tick=120",
        "explored=12.5%",
        "none (scout the fog)",
    )
    for fragment in expected_fragments:
        assert fragment in briefing
def test_tool_call_parsing_and_aliases():
    """Check that ``_to_commands`` keeps valid/aliased calls and drops the rest."""
    valid_move = {
        "name": "move_units",
        "arguments": {"unit_ids": [1, 2], "target_x": 9, "target_y": 4},
    }
    # "attack_target" is an alias name that is expected to be accepted.
    aliased_attack = {
        "name": "attack_target",
        "arguments": {"unit_ids": [3], "target_id": 77},
    }
    noop = {"name": "observe", "arguments": {}}
    # Unknown tool name -> dropped.
    unknown = {"name": "garbage", "arguments": {}}
    # Move without a target -> malformed -> dropped.
    malformed_move = {"name": "move_units", "arguments": {"unit_ids": [1]}}

    tool_calls = [valid_move, aliased_attack, noop, unknown, malformed_move]
    commands = _to_commands(tool_calls, ot.Command)

    # 3 valid (move, attack-alias, observe); 2 dropped
    assert len(commands) == 3
def test_model_agent_drives_live_rust_with_fake_provider():
    """Run one easy level end to end with a scripted provider, no network."""
    from openra_bench.eval_core import run_level
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import compile_level

    level = compile_level(
        load_pack(PACKS / "perception-frontier-reading.yaml"), "easy"
    )

    # Scripted "scout east" behaviour, then observe forever.
    scout_east = {
        "name": "move_units",
        "arguments": {"unit_ids": [1001], "target_x": 100, "target_y": 20},
    }
    provider = FakeProvider([[scout_east]] * 8)

    agent = ModelAgent(
        ProviderConfig(vision=False),
        allowed_tools=level.scenario.tools,
        objective=level.scenario.description,
        provider=provider,
    )
    result = run_level(level, agent.agent_fn, seed=1)

    assert result.outcome in {"win", "draw", "loss"}
    # At least one turn was played and the trace has one entry per turn.
    assert result.turns >= 1
    assert len(result.trace) == result.turns
    assert agent.stats["turns"] == result.turns

    # Provider actually saw a system prompt + a user briefing.
    first_request = provider.seen_messages[0]
    assert first_request[0]["role"] == "system"
    user_roles = [msg for msg in first_request if msg["role"] == "user"]
    assert user_roles
def test_history_strips_stale_images():
    """Check ``_strip_old_images`` flattens older image turns, keeps the newest."""

    def user_turn_with_image(text):
        # A user message whose content is a text part followed by an image part.
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {}},
            ],
        }

    history = [user_turn_with_image("t1"), user_turn_with_image("t2")]

    ModelAgent._strip_old_images(history)

    # Read the entries only after the call, in case they were replaced.
    older, newest = history
    assert isinstance(older["content"], str)
    assert older["content"] == "t1"
    assert isinstance(newest["content"], list)  # newest image kept
def test_openrouter_live_smoke():
    """Opt-in smoke test: play one easy level against a real OpenRouter model.

    Skipped when ``OPENROUTER_API_KEY`` is not set, so the default test run
    stays offline. The model defaults to ``anthropic/claude-3.5-sonnet`` and
    can be overridden with the ``OPENROUTER_MODEL`` environment variable.
    """
    # Guard first, before any project import or network use.
    # NOTE(review): assumes the provider reads its credential from
    # OPENROUTER_API_KEY — confirm against openra_bench.providers.
    if not os.environ.get("OPENROUTER_API_KEY"):
        pytest.skip("OPENROUTER_API_KEY not set; live OpenRouter test is opt-in")

    from openra_bench.eval_core import run_level
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import compile_level

    pack = load_pack(PACKS / "perception-frontier-reading.yaml")
    compiled = compile_level(pack, "easy")

    agent = ModelAgent(
        ProviderConfig(
            provider="openrouter",
            model=os.environ.get("OPENROUTER_MODEL", "anthropic/claude-3.5-sonnet"),
            vision=False,
            max_tokens=512,
        ),
        allowed_tools=compiled.scenario.tools,
        objective=compiled.scenario.description,
    )
    res = run_level(compiled, agent.agent_fn, seed=1)

    assert res.outcome in {"win", "draw", "loss"}
    # The live model must have produced at least one tool call the agent used.
    assert agent.stats["tool_calls"] >= 1, "model issued no usable tool calls"