# Source listing metadata: file size 7,256 bytes; revisions 4732653, f28440f, 4732653.
import json
from pathlib import Path
from fastapi.testclient import TestClient
from sysadmin_env.server import create_app
class FakeSandbox:
    """Test double that stands in for the server's ``Sandbox`` class.

    No command is actually executed: the fake records lifecycle calls and
    every submitted command so tests can inspect them, and it answers
    :meth:`execute_async` with canned result objects.
    """

    # The one command string that produces the canned "timed out" result.
    _TIMEOUT_COMMAND = "sleep 999"

    def __init__(self, lowerdir, *, timeout=30.0, isolate_network=True, overlay_base_dir=None):
        """Store the constructor arguments and reset the bookkeeping state.

        Args:
            lowerdir: Path-like value; stored as a ``Path``.
            timeout: Stored as-is; the fake never enforces it.
            isolate_network: Stored as-is.
            overlay_base_dir: Stored as-is.
        """
        self.lowerdir = Path(lowerdir)
        self.timeout = timeout
        self.isolate_network = isolate_network
        self.overlay_base_dir = overlay_base_dir
        # The fake exposes the lowerdir itself as the merged root.
        self.merged_root = self.lowerdir
        self.created = False
        self.destroyed = False
        # Every command passed to execute_async, in call order.
        self.commands = []

    @staticmethod
    def _result(*, stdout, exit_code, execution_time, timed_out):
        """Bundle canned result fields into an attribute-access object.

        ``stderr`` is always the empty string. A ``SimpleNamespace`` is used
        instead of a class built on the fly so the fields are instance
        attributes and show up in the repr when an assertion fails.
        """
        # Function-scope import: keeps the class usable without touching the
        # module's import block.
        from types import SimpleNamespace

        return SimpleNamespace(
            stdout=stdout,
            stderr="",
            exit_code=exit_code,
            execution_time=execution_time,
            timed_out=timed_out,
        )

    def create(self) -> None:
        """Mark the sandbox as created."""
        self.created = True

    async def execute_async(self, command: str):
        """Record ``command`` and return a canned result for it.

        Returns:
            An object with ``stdout``, ``stderr``, ``exit_code``,
            ``execution_time`` and ``timed_out`` attributes. The command
            ``"sleep 999"`` yields a timed-out result (exit code -1); any
            other command yields exit code 0 with stdout ``"ran <command>"``.
        """
        self.commands.append(command)
        if command == self._TIMEOUT_COMMAND:
            return self._result(stdout="", exit_code=-1, execution_time=1.5, timed_out=True)
        return self._result(
            stdout=f"ran {command}",
            exit_code=0,
            execution_time=0.02,
            timed_out=False,
        )

    def destroy(self) -> None:
        """Mark the sandbox as destroyed."""
        self.destroyed = True
def _make_client(monkeypatch):
    """Return a TestClient for a new app with the server's ``Sandbox`` replaced by FakeSandbox."""
    import sysadmin_env.server as patched_server

    # Patch first, then build the app and wrap it in a client.
    monkeypatch.setattr(patched_server, "Sandbox", FakeSandbox)
    return TestClient(create_app())
def test_health_endpoint(monkeypatch):
    """GET /health answers 200 with the JSON body ``{"status": "ok"}``."""
    health = _make_client(monkeypatch).get("/health")

    assert health.status_code == 200
    assert health.json() == {"status": "ok"}
def test_tasks_endpoint_returns_registered_scenarios(monkeypatch):
    """GET /tasks returns three tasks whose ids are exactly the expected set."""
    expected_ids = {"nginx_crash", "disk_full", "network_broken"}

    listing = _make_client(monkeypatch).get("/tasks")
    assert listing.status_code == 200

    body = listing.json()
    assert "tasks" in body
    tasks = body["tasks"]
    assert len(tasks) == 3
    assert {entry["task_id"] for entry in tasks} == expected_ids
def test_http_reset_step_and_state_contract(monkeypatch):
    """Walk /reset -> /step -> /state and check the payload at each stage."""
    http = _make_client(monkeypatch)

    # Stage 1: reset starts the requested task at step zero, not done.
    reset = http.post("/reset", json={"task_id": "nginx_crash"})
    assert reset.status_code == 200
    after_reset = reset.json()
    initial_state = after_reset["state"]
    assert initial_state["task_id"] == "nginx_crash"
    assert initial_state["step_count"] == 0
    assert after_reset["observation"]["done"] is False

    # Stage 2: one step echoes the fake sandbox output and bumps the counter.
    step = http.post("/step", json={"action": {"command": "echo hello"}})
    assert step.status_code == 200
    after_step = step.json()
    assert after_step["observation"]["stdout"] == "ran echo hello"
    assert after_step["state"]["step_count"] == 1

    # Stage 3: /state reports the same counter.
    state = http.get("/state")
    assert state.status_code == 200
    assert state.json()["step_count"] == 1
def test_http_step_requires_reset(monkeypatch):
    """POST /step from a client that never posted /reset is answered with 409."""
    step_request = {"action": {"command": "echo hello"}}

    reply = _make_client(monkeypatch).post("/step", json=step_request)

    assert reply.status_code == 409
def test_web_routes_expose_compatibility_shim_and_delegate_to_current_environment(monkeypatch):
    """Exercise the /web routes: landing page, metadata, and a reset/step/state cycle."""
    web = _make_client(monkeypatch)

    # Landing page.
    landing = web.get("/web")
    assert landing.status_code == 200
    assert "sysadmin-env web compatibility shim" in landing.text

    # Metadata document.
    metadata = web.get("/web/metadata")
    assert metadata.status_code == 200
    meta = metadata.json()
    assert meta["name"] == "sysadmin-env"
    assert "OpenEnv-compatible web shim routes" in meta["description"]

    # State before any reset: nothing initialized, all counters at zero.
    idle_state = web.get("/web/state")
    assert idle_state.status_code == 200
    expected_idle = {
        "episode_id": None,
        "task_id": None,
        "step_count": 0,
        "max_steps": 0,
        "done": False,
        "reward": 0.0,
        "initialized": False,
    }
    assert idle_state.json() == expected_idle

    # Reset through the web route.
    reset = web.post("/web/reset", json={"task_id": "nginx_crash"})
    assert reset.status_code == 200
    after_reset = reset.json()
    assert after_reset["state"]["task_id"] == "nginx_crash"
    assert after_reset["observation"]["step_number"] == 0
    assert after_reset["reward"] == 0.0
    assert after_reset["done"] is False

    # Step through the web route; top-level reward mirrors the observation's.
    step = web.post("/web/step", json={"action": {"command": "echo hello"}})
    assert step.status_code == 200
    after_step = step.json()
    assert after_step["observation"]["stdout"] == "ran echo hello"
    assert after_step["state"]["step_count"] == 1
    assert after_step["reward"] == after_step["observation"]["reward"]
    assert after_step["done"] is False

    # State after the step reflects the episode started by the reset above.
    live_state = web.get("/web/state")
    assert live_state.status_code == 200
    expected_live = {
        "episode_id": after_reset["state"]["episode_id"],
        "task_id": "nginx_crash",
        "step_count": 1,
        "max_steps": 40,
        "done": False,
        "reward": after_step["reward"],
        "initialized": True,
    }
    assert live_state.json() == expected_live
def test_websocket_handles_valid_invalid_and_timeout_actions(monkeypatch):
    """Drive one /ws session through a malformed frame, a normal command and a timeout."""
    client = _make_client(monkeypatch)

    with client.websocket_connect("/ws?task_id=nginx_crash") as session:
        # The first frame announces the episode.
        hello = session.receive_json()
        assert hello["type"] == "episode_started"
        assert hello["task"]["task_id"] == "nginx_crash"

        # A non-JSON frame is answered with an error frame.
        session.send_text("not json")
        expected_error = {
            "type": "error",
            "code": "invalid_action",
            "message": "malformed action json",
        }
        assert session.receive_json() == expected_error

        # A well-formed command yields an observation frame for step 1.
        session.send_text(json.dumps({"command": "echo hello"}))
        reply = session.receive_json()
        assert reply["type"] == "observation"
        assert reply["task_id"] == "nginx_crash"
        first_obs = reply["observation"]
        assert first_obs["stdout"] == "ran echo hello"
        assert first_obs["exit_code"] == 0
        assert first_obs["step_number"] == 1
        assert first_obs["max_steps"] == 40
        assert first_obs["done"] is False

        # The fake sandbox reports "sleep 999" as timed out; that is step 2.
        session.send_text(json.dumps({"command": "sleep 999"}))
        second_obs = session.receive_json()["observation"]
        assert second_obs["exit_code"] == -1
        assert "command execution timed out" in second_obs["stderr"]
        assert second_obs["step_number"] == 2
def test_websocket_disconnect_cleans_up_sandbox(monkeypatch):
    """The first sandbox built for a /ws session is destroyed after the socket closes."""
    import sysadmin_env.server as patched_server

    instances = []

    class TrackingSandbox(FakeSandbox):
        """FakeSandbox that appends every constructed instance to ``instances``."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            instances.append(self)

    monkeypatch.setattr(patched_server, "Sandbox", TrackingSandbox)
    client = TestClient(create_app())

    # Open the session, read the first frame, then let the context close it.
    with client.websocket_connect("/ws?task_id=disk_full") as session:
        session.receive_json()

    assert instances
    assert instances[0].destroyed is True
|