open-range / tests /test_openenv_contract.py
Lars Talian
Extend OpenEnv contract guardrails
cd25692
"""Guardrail tests for OpenEnv contract compatibility."""
from __future__ import annotations
import asyncio
import importlib
import inspect
from typing import Any
from fastapi import FastAPI
import openenv.core.env_server as openenv_env_server
import pytest
from pydantic import ValidationError
from open_range.client.client import OpenRangeEnv
from open_range.models import RangeAction, RangeObservation, RangeState
from open_range.server.app import create_app
from open_range.server.environment import RangeEnvironment
from openenv.core.env_server.types import Action, Observation, State
def _call_route_endpoint(app: FastAPI, path: str, method: str = "GET") -> Any:
"""Invoke a route endpoint directly for lightweight schema checks."""
for route in app.routes:
methods = getattr(route, "methods", set())
if getattr(route, "path", None) != path or method not in methods:
continue
result = route.endpoint()
if inspect.isawaitable(result):
return asyncio.run(result)
return result
raise AssertionError(f"Route {method} {path} not found")
class TestModelContract:
def test_models_do_not_redeclare_inherited_openenv_fields(self):
# These fields must stay inherited from OpenEnv base models.
assert issubclass(RangeAction, Action)
assert issubclass(RangeObservation, Observation)
assert issubclass(RangeState, State)
assert "metadata" not in RangeAction.__annotations__
assert "done" not in RangeObservation.__annotations__
assert "reward" not in RangeObservation.__annotations__
assert "metadata" not in RangeObservation.__annotations__
assert "episode_id" not in RangeState.__annotations__
assert "step_count" not in RangeState.__annotations__
def test_models_expose_inherited_openenv_fields(self):
assert "metadata" in RangeAction.model_fields
assert "done" in RangeObservation.model_fields
assert "reward" in RangeObservation.model_fields
assert "metadata" in RangeObservation.model_fields
assert "episode_id" in RangeState.model_fields
assert "step_count" in RangeState.model_fields
def test_action_and_observation_reject_unknown_fields(self):
with pytest.raises(ValidationError):
RangeAction(command="whoami", mode="red", unknown_field="x")
with pytest.raises(ValidationError):
RangeObservation(stdout="ok", extra_field="x")
def test_state_allows_unknown_fields(self):
state = RangeState(step_count=1, extra_field="ok")
assert state.extra_field == "ok"
dumped = state.model_dump()
assert dumped["extra_field"] == "ok"
class TestAppFactoryContract:
def test_create_app_wires_openenv_factory_with_expected_types(self, monkeypatch):
captured: dict[str, Any] = {}
def fake_create_app(env_factory, action_type, observation_type, *, env_name):
captured["env_factory"] = env_factory
captured["action_type"] = action_type
captured["observation_type"] = observation_type
captured["env_name"] = env_name
return FastAPI()
monkeypatch.delenv("OPENRANGE_ENABLE_MANAGED_RUNTIME", raising=False)
monkeypatch.delenv("OPENRANGE_RUNTIME_MANIFEST", raising=False)
monkeypatch.setattr(openenv_env_server, "create_app", fake_create_app)
app_module = importlib.import_module("open_range.server.app")
app = app_module.create_app()
assert isinstance(app, FastAPI)
assert captured["action_type"] is RangeAction
assert captured["observation_type"] is RangeObservation
assert captured["env_name"] == "open_range"
assert callable(captured["env_factory"])
assert isinstance(captured["env_factory"](), RangeEnvironment)
assert isinstance(app.state.env, RangeEnvironment)
def test_create_app_exposes_required_openenv_routes(self, monkeypatch):
monkeypatch.delenv("OPENRANGE_ENABLE_MANAGED_RUNTIME", raising=False)
monkeypatch.delenv("OPENRANGE_RUNTIME_MANIFEST", raising=False)
app_module = importlib.import_module("open_range.server.app")
app = app_module.create_app()
paths = {route.path for route in app.router.routes}
required_paths = {"/health", "/metadata", "/schema", "/reset", "/step", "/state", "/ws"}
assert required_paths.issubset(paths)
def test_create_app_exposes_openenv_server_state(self, monkeypatch):
monkeypatch.delenv("OPENRANGE_ENABLE_MANAGED_RUNTIME", raising=False)
monkeypatch.delenv("OPENRANGE_RUNTIME_MANIFEST", raising=False)
app = create_app()
assert hasattr(app.state, "openenv_server")
assert isinstance(app.state.env, RangeEnvironment)
assert hasattr(app.state.openenv_server, "_sessions")
assert app.state.openenv_server.active_sessions == 0
def test_schema_endpoints_expose_expected_contract_shapes(self, monkeypatch):
monkeypatch.delenv("OPENRANGE_ENABLE_MANAGED_RUNTIME", raising=False)
monkeypatch.delenv("OPENRANGE_RUNTIME_MANIFEST", raising=False)
app = create_app()
health_payload = _call_route_endpoint(app, "/health")
assert health_payload.model_dump() == {"status": "healthy"}
metadata_payload = _call_route_endpoint(app, "/metadata").model_dump()
assert metadata_payload["name"] == "open_range"
assert isinstance(metadata_payload["version"], str) and metadata_payload["version"]
assert isinstance(metadata_payload["description"], str) and metadata_payload["description"]
payload = _call_route_endpoint(app, "/schema").model_dump()
assert payload["action"]["properties"]["command"]["type"] == "string"
assert payload["action"]["properties"]["mode"]["enum"] == ["red", "blue"]
assert "stdout" in payload["observation"]["properties"]
assert "done" in payload["observation"]["properties"]
assert "episode_id" in payload["state"]["properties"]
assert "step_count" in payload["state"]["properties"]
class TestClientContract:
def test_step_payload_matches_server_contract(self):
client = OpenRangeEnv(base_url="http://localhost:8000")
payload = client._step_payload(
RangeAction(command="nmap -sV web", mode="red", metadata={"source": "test"})
)
assert payload == {"command": "nmap -sV web", "mode": "red"}
def test_parse_result_uses_observation_and_top_level_done_reward(self):
client = OpenRangeEnv(base_url="http://localhost:8000")
result = client._parse_result(
{
"observation": {
"stdout": "ok",
"stderr": "",
"done": False,
"reward": 0.1,
"flags_captured": ["FLAG{a}"],
},
"done": 1,
"reward": 0.75,
}
)
assert isinstance(result.observation, RangeObservation)
assert result.observation.stdout == "ok"
assert result.observation.flags_captured == ["FLAG{a}"]
assert result.done is True
assert result.reward == 0.75
def test_parse_state_round_trips_openenv_and_extended_state_fields(self):
client = OpenRangeEnv(base_url="http://localhost:8000")
state = client._parse_state(
{
"episode_id": "ep-123",
"step_count": 4,
"mode": "red",
"tier": 2,
"custom_key": "value",
}
)
assert isinstance(state, RangeState)
assert state.episode_id == "ep-123"
assert state.step_count == 4
assert state.mode == "red"
assert state.tier == 2
assert state.custom_key == "value"