File size: 7,533 Bytes
6c20e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""Phase 2 verification tests for SentinelOpsArena environment."""

from sentinelops_arena.environment import SentinelOpsArena
from sentinelops_arena.models import SentinelAction, AgentRole


# -------------------------------------------------------------------
# Basic environment tests
# -------------------------------------------------------------------

def test_reset():
    """Check the observation and state produced by a seeded reset.

    The first observation must not be terminal, must hand the turn to the
    attacker, and must report tick 0; the environment state must report
    zero steps taken.
    """
    arena = SentinelOpsArena()
    initial = arena.reset(seed=42)

    # Observation-level expectations.
    assert initial.done is False
    assert initial.current_agent == AgentRole.ATTACKER
    assert initial.tick == 0

    # State-level expectation: nothing has been stepped yet.
    assert arena.state.step_count == 0
    print("PASS: test_reset")


def test_turn_order():
    """Check that turns rotate attacker -> worker -> oversight -> attacker.

    After the oversight agent acts, the turn must return to the attacker
    and the environment tick must have advanced to 1.
    """
    arena = SentinelOpsArena()
    start = arena.reset(seed=42)
    assert start.current_agent == AgentRole.ATTACKER

    # Attacker passes; the worker is expected next.
    attacker_move = SentinelAction(agent=AgentRole.ATTACKER, action_type="pass")
    after_attacker = arena.step(attacker_move)
    assert after_attacker.current_agent == AgentRole.WORKER

    # Worker responds; oversight is expected next.
    worker_move = SentinelAction(
        agent=AgentRole.WORKER,
        action_type="respond",
        response_text="Hello",
    )
    after_worker = arena.step(worker_move)
    assert after_worker.current_agent == AgentRole.OVERSIGHT

    # Oversight approves; the rotation wraps back to the attacker.
    oversight_move = SentinelAction(
        agent=AgentRole.OVERSIGHT,
        action_type="approve",
        flag=False,
    )
    after_oversight = arena.step(oversight_move)
    assert after_oversight.current_agent == AgentRole.ATTACKER
    assert arena.tick == 1  # one full rotation completed
    print("PASS: test_turn_order")


def test_full_episode():
    """Play a complete episode with passive actions and check its length.

    On every step the agent whose turn it is takes a fixed, passive action
    (attacker passes, worker responds "Done", oversight approves without
    flagging). The episode is expected to end at tick 30 after exactly 90
    steps with a terminal observation.

    The loop is capped so that an environment which never reports
    ``done`` fails this test with a clear message instead of hanging the
    whole run.
    """
    # Generous upper bound, far above the 90 steps asserted below.
    max_steps = 1000

    env = SentinelOpsArena()
    obs = env.reset(seed=42)
    steps = 0
    while not obs.done:
        # Fail fast rather than spin forever on a non-terminating episode.
        assert steps < max_steps, (
            f"Episode did not terminate within {max_steps} steps"
        )

        agent = obs.current_agent
        if agent == AgentRole.ATTACKER:
            action = SentinelAction(agent=AgentRole.ATTACKER, action_type="pass")
        elif agent == AgentRole.WORKER:
            action = SentinelAction(
                agent=AgentRole.WORKER,
                action_type="respond",
                response_text="Done",
            )
        else:
            # Any remaining role is treated as the oversight agent.
            action = SentinelAction(
                agent=AgentRole.OVERSIGHT, action_type="approve", flag=False
            )
        obs = env.step(action)
        steps += 1

    assert env.tick == 30, f"Expected tick=30, got {env.tick}"
    assert steps == 90, f"Expected 90 steps, got {steps}"
    assert obs.done is True
    print("PASS: test_full_episode")


def test_wrong_turn_rejected():
    """Check that acting out of turn yields a reward of -1.0.

    Immediately after reset a worker action is submitted, although the
    first turn is expected to belong to the attacker.
    """
    arena = SentinelOpsArena()
    arena.reset(seed=42)

    # Worker tries to act before the attacker has moved.
    out_of_turn = SentinelAction(
        agent=AgentRole.WORKER,
        action_type="respond",
        response_text="Wrong turn",
    )
    outcome = arena.step(out_of_turn)

    assert outcome.reward == -1.0
    print("PASS: test_wrong_turn_rejected")


# -------------------------------------------------------------------
# MCP routing tests
# -------------------------------------------------------------------

def test_mcp_list_tools():
    """Check which tool names a ListToolsAction step reports.

    Four domain tools must be listed, and the names ``reset``, ``step``,
    ``state`` and ``close`` must be absent from the listing.
    """
    from openenv.core.env_server.mcp_types import ListToolsAction

    arena = SentinelOpsArena()
    arena.reset(seed=42)

    listing = arena.step(ListToolsAction())
    tool_names = [tool.name for tool in listing.tools]

    # Tools that must be exposed.
    for required in (
        "lookup_customer",
        "launch_attack",
        "issue_refund",
        "flag_action",
    ):
        assert required in tool_names

    # Reserved names must NOT appear.
    for reserved in ("reset", "step", "state", "close"):
        assert reserved not in tool_names

    print(f"PASS: test_mcp_list_tools ({len(tool_names)} tools)")


def test_mcp_call_tool():
    """Check that a CallToolAction step echoes the tool name and has a result.

    Calls ``lookup_customer`` for customer ``C000`` and verifies that the
    returned observation names that tool and carries a non-None result.
    """
    from openenv.core.env_server.mcp_types import CallToolAction

    arena = SentinelOpsArena()
    arena.reset(seed=42)

    call = CallToolAction(
        tool_name="lookup_customer",
        arguments={"customer_id": "C000"},
    )
    reply = arena.step(call)

    assert reply.tool_name == "lookup_customer"
    assert reply.result is not None
    print("PASS: test_mcp_call_tool")


# -------------------------------------------------------------------
# Attack tests
# -------------------------------------------------------------------

def test_attacker_launch_attack():
    """Check that a schema-drift attack renames a CRM field.

    The attacker renames ``name`` to ``full_name`` on the ``crm`` system.
    Afterwards the turn must pass to the worker, and the CRM schema's
    ``fields`` must contain the new name and no longer contain the old one.
    """
    arena = SentinelOpsArena()
    arena.reset(seed=42)

    drift_params = {
        "attack_type": "schema_drift",
        "target_system": "crm",
        "old_field": "name",
        "new_field": "full_name",
    }
    attack = SentinelAction(
        agent=AgentRole.ATTACKER,
        action_type="launch_attack",
        parameters=drift_params,
    )
    after_attack = arena.step(attack)

    # The attacker's turn is over; the worker goes next.
    assert after_attack.current_agent == AgentRole.WORKER

    # The rename must be visible in the CRM schema.
    crm_schema = arena.crm.get_schema()
    assert "full_name" in crm_schema["fields"]
    assert "name" not in crm_schema["fields"]
    print("PASS: test_attacker_launch_attack")


def test_worker_lookup_after_drift():
    """Check that a worker lookup still yields a result after schema drift.

    The attacker first renames the CRM field ``name`` to ``full_name``;
    the worker then looks up customer ``C000``. Only the presence of
    ``last_action_result`` on the returned observation is verified.
    """
    arena = SentinelOpsArena()
    arena.reset(seed=42)

    # Step 1: attacker applies the schema drift.
    drift = SentinelAction(
        agent=AgentRole.ATTACKER,
        action_type="launch_attack",
        parameters={
            "attack_type": "schema_drift",
            "target_system": "crm",
            "old_field": "name",
            "new_field": "full_name",
        },
    )
    arena.step(drift)

    # Step 2: worker looks up a customer.
    lookup = SentinelAction(
        agent=AgentRole.WORKER,
        action_type="lookup_customer",
        parameters={"customer_id": "C000"},
    )
    after_lookup = arena.step(lookup)

    # Expected to succeed despite the rename; per the original author's note
    # lookup_customer goes through _apply_field_map (not visible in this file).
    assert after_lookup.last_action_result is not None
    print("PASS: test_worker_lookup_after_drift")


# -------------------------------------------------------------------
# State tests
# -------------------------------------------------------------------

def test_state_tracking():
    """Check the state counters before and after one full rotation.

    Right after reset the state must report tick 0, step count 0 and 30
    total tasks. After attacker, worker and oversight have each acted
    once, it must report tick 1 and step count 3.
    """
    arena = SentinelOpsArena()
    arena.reset(seed=42)

    # Counters immediately after reset.
    assert arena.state.tick == 0
    assert arena.state.step_count == 0
    assert arena.state.tasks_total == 30

    # One full rotation: attacker, then worker, then oversight.
    attacker_move = SentinelAction(agent=AgentRole.ATTACKER, action_type="pass")
    arena.step(attacker_move)

    worker_move = SentinelAction(
        agent=AgentRole.WORKER,
        action_type="respond",
        response_text="ok",
    )
    arena.step(worker_move)

    oversight_move = SentinelAction(
        agent=AgentRole.OVERSIGHT,
        action_type="approve",
        flag=False,
    )
    arena.step(oversight_move)

    # Counters after the rotation.
    assert arena.state.tick == 1
    assert arena.state.step_count == 3
    print("PASS: test_state_tracking")


# -------------------------------------------------------------------
# HTTP server test
# -------------------------------------------------------------------

def test_create_app():
    """Check that ``create_app`` returns an app object for this environment.

    The environment class plus its action and observation models are
    passed to ``create_app`` under the name ``sentinelops_arena``; the
    only verification is that the returned value is not None.
    """
    from openenv.core.env_server.http_server import create_app
    from sentinelops_arena.models import SentinelAction, SentinelObservation

    app_kwargs = {"env_name": "sentinelops_arena"}
    http_app = create_app(
        SentinelOpsArena,
        SentinelAction,
        SentinelObservation,
        **app_kwargs,
    )

    assert http_app is not None
    print("PASS: test_create_app")


# -------------------------------------------------------------------
# Run all
# -------------------------------------------------------------------

if __name__ == "__main__":
    # Standalone runner: executes every test in order, prints a PASS/FAIL
    # line per test plus a summary, and exits with status 1 if any test
    # failed so that shells and CI jobs can detect the failure.
    tests = [
        test_reset,
        test_turn_order,
        test_full_episode,
        test_wrong_turn_rejected,
        test_mcp_list_tools,
        test_mcp_call_tool,
        test_attacker_launch_attack,
        test_worker_lookup_after_drift,
        test_state_tracking,
        test_create_app,
    ]

    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:
            # Top-level boundary: keep running the remaining tests. The
            # exception type is included because a bare ``assert`` produces
            # an empty message, which would leave the FAIL line uninformative.
            print(f"FAIL: {test.__name__}: {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1

    print(f"\n{'='*50}")
    print(f"Results: {passed}/{passed + failed} passed")
    if failed == 0:
        print("ALL PHASE 2 TESTS PASSED")
    else:
        print(f"{failed} test(s) FAILED")
        # Non-zero exit status signals failure to the calling process.
        raise SystemExit(1)