Lars Talian commited on
Commit
cd25692
·
1 Parent(s): a49b769

Extend OpenEnv contract guardrails

Browse files
Files changed (1) hide show
  1. tests/test_openenv_contract.py +53 -0
tests/test_openenv_contract.py CHANGED
@@ -2,7 +2,9 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import importlib
 
6
  from typing import Any
7
 
8
  from fastapi import FastAPI
@@ -12,12 +14,30 @@ from pydantic import ValidationError
12
 
13
  from open_range.client.client import OpenRangeEnv
14
  from open_range.models import RangeAction, RangeObservation, RangeState
 
15
  from open_range.server.environment import RangeEnvironment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  class TestModelContract:
19
  def test_models_do_not_redeclare_inherited_openenv_fields(self):
20
  # These fields must stay inherited from OpenEnv base models.
 
 
 
21
  assert "metadata" not in RangeAction.__annotations__
22
  assert "done" not in RangeObservation.__annotations__
23
  assert "reward" not in RangeObservation.__annotations__
@@ -83,6 +103,39 @@ class TestAppFactoryContract:
83
  required_paths = {"/health", "/metadata", "/schema", "/reset", "/step", "/state", "/ws"}
84
  assert required_paths.issubset(paths)
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  class TestClientContract:
88
  def test_step_payload_matches_server_contract(self):
 
2
 
3
  from __future__ import annotations
4
 
5
+ import asyncio
6
  import importlib
7
+ import inspect
8
  from typing import Any
9
 
10
  from fastapi import FastAPI
 
14
 
15
  from open_range.client.client import OpenRangeEnv
16
  from open_range.models import RangeAction, RangeObservation, RangeState
17
+ from open_range.server.app import create_app
18
  from open_range.server.environment import RangeEnvironment
19
+ from openenv.core.env_server.types import Action, Observation, State
20
+
21
+
22
+ def _call_route_endpoint(app: FastAPI, path: str, method: str = "GET") -> Any:
23
+ """Invoke a route endpoint directly for lightweight schema checks."""
24
+ for route in app.routes:
25
+ methods = getattr(route, "methods", set())
26
+ if getattr(route, "path", None) != path or method not in methods:
27
+ continue
28
+ result = route.endpoint()
29
+ if inspect.isawaitable(result):
30
+ return asyncio.run(result)
31
+ return result
32
+ raise AssertionError(f"Route {method} {path} not found")
33
 
34
 
35
  class TestModelContract:
36
  def test_models_do_not_redeclare_inherited_openenv_fields(self):
37
  # These fields must stay inherited from OpenEnv base models.
38
+ assert issubclass(RangeAction, Action)
39
+ assert issubclass(RangeObservation, Observation)
40
+ assert issubclass(RangeState, State)
41
  assert "metadata" not in RangeAction.__annotations__
42
  assert "done" not in RangeObservation.__annotations__
43
  assert "reward" not in RangeObservation.__annotations__
 
103
  required_paths = {"/health", "/metadata", "/schema", "/reset", "/step", "/state", "/ws"}
104
  assert required_paths.issubset(paths)
105
 
106
+ def test_create_app_exposes_openenv_server_state(self, monkeypatch):
107
+ monkeypatch.delenv("OPENRANGE_ENABLE_MANAGED_RUNTIME", raising=False)
108
+ monkeypatch.delenv("OPENRANGE_RUNTIME_MANIFEST", raising=False)
109
+
110
+ app = create_app()
111
+
112
+ assert hasattr(app.state, "openenv_server")
113
+ assert isinstance(app.state.env, RangeEnvironment)
114
+ assert hasattr(app.state.openenv_server, "_sessions")
115
+ assert app.state.openenv_server.active_sessions == 0
116
+
117
+ def test_schema_endpoints_expose_expected_contract_shapes(self, monkeypatch):
118
+ monkeypatch.delenv("OPENRANGE_ENABLE_MANAGED_RUNTIME", raising=False)
119
+ monkeypatch.delenv("OPENRANGE_RUNTIME_MANIFEST", raising=False)
120
+
121
+ app = create_app()
122
+
123
+ health_payload = _call_route_endpoint(app, "/health")
124
+ assert health_payload.model_dump() == {"status": "healthy"}
125
+
126
+ metadata_payload = _call_route_endpoint(app, "/metadata").model_dump()
127
+ assert metadata_payload["name"] == "open_range"
128
+ assert isinstance(metadata_payload["version"], str) and metadata_payload["version"]
129
+ assert isinstance(metadata_payload["description"], str) and metadata_payload["description"]
130
+
131
+ payload = _call_route_endpoint(app, "/schema").model_dump()
132
+ assert payload["action"]["properties"]["command"]["type"] == "string"
133
+ assert payload["action"]["properties"]["mode"]["enum"] == ["red", "blue"]
134
+ assert "stdout" in payload["observation"]["properties"]
135
+ assert "done" in payload["observation"]["properties"]
136
+ assert "episode_id" in payload["state"]["properties"]
137
+ assert "step_count" in payload["state"]["properties"]
138
+
139
 
140
  class TestClientContract:
141
  def test_step_payload_matches_server_contract(self):