File size: 5,764 Bytes
715cbbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cfed54
715cbbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f912cfc
 
 
 
 
 
 
715cbbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Agent + provider tests.

Offline by default: a FakeProvider returns scripted tool calls so the
briefing build, tool-schema filtering, tool-call parsing, and the full
ModelAgent -> eval_core -> live-Rust loop are all asserted without a
network. The OpenRouter live test is opt-in (skips without an API key).
"""

from __future__ import annotations

import os
from pathlib import Path

import pytest

ot = pytest.importorskip("openra_train", reason="Rust env wheel not installed")

pytest.importorskip("openra_rl_training", reason="Rust env wheel not installed")
from openra_bench.agent import ModelAgent, build_briefing, _to_commands, _tool_schemas
from openra_bench.providers import ChatProvider, ChatReply, ProviderConfig

# NOTE(review): absolute path on one developer's machine; nothing in the
# visible part of this module reads TRAIN — confirm it is still needed.
TRAIN = Path("/Users/berta/Projects/OpenRA-RL-Training")
# Scenario pack directory, resolved relative to this test file: two levels
# up from this file, then openra_bench/scenarios/packs.
PACKS = Path(__file__).parent.parent / "openra_bench" / "scenarios" / "packs"


class FakeProvider(ChatProvider):
    """Offline provider that replays a scripted list of tool-call sets.

    Every ``complete`` call hands back the next entry of ``script``; once the
    script runs out, each further call yields a single ``observe`` tool call
    with empty arguments.
    """

    def __init__(self, script: list[list[dict]]):
        self.script = script
        # Index of the next scripted turn to replay.
        self.i = 0
        # One entry per complete() call: the messages object passed in.
        self.seen_messages: list[list[dict]] = []

    def complete(self, messages, tools):
        """Record ``messages`` and return the reply for the current turn.

        ``tools`` is accepted but not inspected.
        """
        self.seen_messages.append(messages)
        if self.i < len(self.script):
            tool_calls = self.script[self.i]
        else:
            # Script exhausted: fall back to the no-op.
            tool_calls = [{"name": "observe", "arguments": {}}]
        self.i += 1
        return ChatReply(text="", tool_calls=tool_calls)


def test_tool_schema_filtering():
    """Check which tool names _tool_schemas offers for each allow-list form."""

    def offered(allowed):
        # Set of function names exposed for the given allow-list.
        return {schema["function"]["name"] for schema in _tool_schemas(allowed)}

    # Explicit allow-list: only the listed tool, plus the no-op.
    restricted = offered(["move_units"])
    assert "move_units" in restricted
    assert "attack_unit" not in restricted
    assert "observe" in restricted, "a no-op must always be offered"

    # Unset → DEFAULT_CORE (movement/combat), not every tool.
    default_names = offered(None)
    assert default_names.issuperset({"move_units", "attack_unit", "observe"})
    assert "build" not in default_names
    assert "surrender" not in default_names

    # Wildcard exposes the full set (economy/structure/etc).
    everything = offered(["*"])
    assert everything.issuperset(
        {"build", "harvest", "place_building", "stop", "deploy"}
    )


def test_build_briefing_format():
    """Check that build_briefing's text contains the expected fragments."""
    jeep = {"id": "1001", "type": "jeep", "cell_x": 5, "cell_y": 6, "activity": "idle"}
    state = dict(
        game_tick=120,
        explored_percent=12.5,
        units_summary=[jeep],
        enemy_summary=[],
    )

    briefing = build_briefing(state, objective="find the base")

    # Objective line, unit line, tick/explored header, and the hint shown
    # when the enemy summary is empty.
    expected_fragments = (
        "OBJECTIVE: find the base",
        "1001 jeep @(5,6)",
        "tick=120",
        "explored=12.5%",
        "none (scout the fog)",
    )
    for fragment in expected_fragments:
        assert fragment in briefing


def test_tool_call_parsing_and_aliases():
    """Check that _to_commands keeps three of five calls and drops two."""

    def call(name, **arguments):
        # Build one tool-call dict in the shape passed to _to_commands.
        return {"name": name, "arguments": arguments}

    well_formed = [
        call("move_units", unit_ids=[1, 2], target_x=9, target_y=4),
        call("attack_target", unit_ids=[3], target_id=77),  # alias
        call("observe"),
    ]
    dropped = [
        call("garbage"),  # unknown tool name
        call("move_units", unit_ids=[1]),  # malformed: no target
    ]

    commands = _to_commands(well_formed + dropped, ot.Command)

    # 3 valid (move, attack-alias, observe); 2 dropped
    assert len(commands) == len(well_formed)


def test_model_agent_drives_live_rust_with_fake_provider():
    """Run the 'easy' perception-frontier-reading level with a scripted provider.

    Checks the run's outcome/turn bookkeeping and that the provider was
    handed a system message followed somewhere by a user message.
    """
    from openra_bench.eval_core import run_level
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import compile_level

    compiled = compile_level(
        load_pack(PACKS / "perception-frontier-reading.yaml"), "easy"
    )

    # Scripted "scout east" behaviour, then observe forever.
    scout_east = [
        {
            "name": "move_units",
            "arguments": {"unit_ids": [1001], "target_x": 100, "target_y": 20},
        }
    ]
    fake = FakeProvider([scout_east] * 8)

    agent = ModelAgent(
        ProviderConfig(vision=False),
        allowed_tools=compiled.scenario.tools,
        objective=compiled.scenario.description,
        provider=fake,
    )
    result = run_level(compiled, agent.agent_fn, seed=1)

    assert result.outcome in {"win", "draw", "loss"}
    assert result.turns >= 1
    assert len(result.trace) == result.turns
    assert agent.stats["turns"] == result.turns

    # Provider actually saw a system prompt + a user briefing.
    first_request = fake.seen_messages[0]
    assert first_request[0]["role"] == "system"
    assert any(message["role"] == "user" for message in first_request)


def test_history_strips_stale_images():
    """Check _strip_old_images flattens the older turn and keeps the newest."""

    def user_turn(text):
        # A user message carrying a text part and an image part.
        return {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {}},
            ],
        }

    history = [user_turn("t1"), user_turn("t2")]
    ModelAgent._strip_old_images(history)

    # Older turn: content collapsed to its plain text.
    assert isinstance(history[0]["content"], str)
    assert history[0]["content"] == "t1"
    # newest image kept
    assert isinstance(history[1]["content"], list)


@pytest.mark.skipif(
    not os.environ.get("OPENROUTER_API_KEY"),
    reason="set OPENROUTER_API_KEY to run the live provider test",
)
def test_openrouter_live_smoke():
    """Opt-in smoke test: run one 'easy' level against the OpenRouter provider.

    Skipped unless OPENROUTER_API_KEY is set to a non-empty value. The model
    comes from OPENROUTER_MODEL, falling back to the default below.
    """
    from openra_bench.eval_core import run_level
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import compile_level

    compiled = compile_level(
        load_pack(PACKS / "perception-frontier-reading.yaml"), "easy"
    )

    model_name = os.environ.get("OPENROUTER_MODEL", "anthropic/claude-3.5-sonnet")
    config = ProviderConfig(
        provider="openrouter",
        model=model_name,
        vision=False,
        max_tokens=512,
    )
    agent = ModelAgent(
        config,
        allowed_tools=compiled.scenario.tools,
        objective=compiled.scenario.description,
    )

    result = run_level(compiled, agent.agent_fn, seed=1)

    assert result.outcome in {"win", "draw", "loss"}
    assert agent.stats["tool_calls"] >= 1, "model issued no usable tool calls"