Spaces:
Sleeping
Sleeping
"""
Test API — proves step()/reset()/state() return correctly typed models.
"""
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from fastapi.testclient import TestClient | |
| from models import GridAction, GridObservation, ActionType | |
| from server.nexusgrid_environment import NexusgridEnvironment | |
| from server.app import app | |
class TestEnvironmentAPI:
    """Test the OpenEnv API interface.

    Exercises the in-process environment (``reset``/``step``/``state``) and
    the HTTP / WebSocket routes served by the FastAPI ``app``.
    """

    # An edge is treated as overloaded once its load reaches this share of
    # its capacity (used by the cascade-overload tests).
    _OVERLOAD_RATIO = 0.95

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names do not start with ``test``)
    # ------------------------------------------------------------------

    def _make_env(self) -> NexusgridEnvironment:
        """Build a new environment instance for a single test."""
        return NexusgridEnvironment()

    def _make_client(self) -> TestClient:
        """Build a test client bound to the FastAPI ``app``."""
        return TestClient(app)

    @classmethod
    def _overloaded_edge_ids(cls, obs) -> list:
        """Return the ids of LIVE edges loaded to at least 95% of capacity.

        Args:
            obs: An observation whose ``topology_graph["edges"]`` entries
                carry ``id``, ``status``, ``current_load_mw`` and
                ``capacity_mw`` keys.
        """
        return [
            edge["id"]
            for edge in obs.topology_graph["edges"]
            if edge["status"] == "LIVE"
            and edge["current_load_mw"] >= cls._OVERLOAD_RATIO * edge["capacity_mw"]
        ]

    # ------------------------------------------------------------------
    # reset()
    # ------------------------------------------------------------------

    def test_reset_returns_observation(self):
        """reset() should return a GridObservation."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=0)
        assert isinstance(obs, GridObservation)
        assert obs.tick == 0
        assert obs.task_id == 0
        assert obs.done is False
        assert 58.0 <= obs.grid_frequency_hz <= 62.0

    def test_reset_topology_present(self):
        """reset() observation should contain topology with nodes and edges."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=0)
        topo = obs.topology_graph
        assert "nodes" in topo
        assert "edges" in topo
        assert len(topo["nodes"]) == 20  # 20-node grid
        assert len(topo["edges"]) == 40  # 40 edges

    def test_reset_idempotent(self):
        """Calling reset(42) multiple times should produce identical observations."""
        env = self._make_env()
        obs1 = env.reset(seed=42, task_id=0)
        obs2 = env.reset(seed=42, task_id=0)
        assert obs1.grid_frequency_hz == obs2.grid_frequency_hz
        assert obs1.tick == obs2.tick
        assert obs1.task_id == obs2.task_id

    def test_weather_present(self):
        """Observation should include weather data."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=0)
        assert len(obs.weather_forecast_matrix) > 0
        assert obs.weather_summary != ""

    # ------------------------------------------------------------------
    # step() / state
    # ------------------------------------------------------------------

    def test_step_returns_observation(self):
        """step() should return a GridObservation with reward and done."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        action = GridAction(
            action_type=ActionType.DISPATCH_GENERATION,
            node_id="NODE_01",
            mw=100.0,
        )
        obs = env.step(action)
        assert isinstance(obs, GridObservation)
        assert obs.tick == 1
        assert obs.grid_frequency_hz > 0
        assert "reward_breakdown" in obs.metadata
        assert "rubric_breakdown" in obs.metadata

    def test_step_advance_tick(self):
        """advance_tick should increment the tick counter."""
        env = self._make_env()
        env.reset(seed=42, task_id=1)
        action = GridAction(action_type=ActionType.ADVANCE_TICK)
        obs = env.step(action)
        assert obs.tick == 1

    def test_state_returns_state(self):
        """state property should return an OpenEnv State."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        state = env.state
        assert state.episode_id is not None
        assert state.step_count == 0

    def test_all_action_types_accepted(self):
        """All action types should be accepted without crashing."""
        env = self._make_env()
        env.reset(seed=42, task_id=1)
        actions = [
            GridAction(action_type=ActionType.ADVANCE_TICK),
            GridAction(action_type=ActionType.DISPATCH_GENERATION, node_id="NODE_01", mw=50),
            GridAction(
                action_type=ActionType.TOGGLE_CIRCUIT_BREAKER,
                edge_id="LINE_01",
                status="OPEN",
            ),
            GridAction(
                action_type=ActionType.RUN_STATE_ESTIMATION,
                subgraph=["NODE_01", "NODE_02"],
            ),
            GridAction(action_type=ActionType.QUARANTINE_SCADA_NODE, node_id="NODE_01"),
        ]
        for action in actions:
            obs = env.step(action)
            assert isinstance(obs, GridObservation)

    def test_dispatch_updates_frequency_without_advance_tick(self):
        """Control actions should update the observable grid frequency immediately."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=1)
        initial_frequency = obs.grid_frequency_hz
        obs = env.step(
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_04",
                mw=200.0,
            )
        )
        assert obs.grid_frequency_hz != initial_frequency

    # ------------------------------------------------------------------
    # HTTP / WebSocket routes
    # ------------------------------------------------------------------

    def test_health_endpoint_returns_200(self):
        """The standalone health endpoint should respond immediately."""
        client = self._make_client()
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"

    def test_web_dashboard_route_available(self):
        """The mounted dashboard route should remain available for HF Spaces."""
        client = self._make_client()
        response = client.get("/web", follow_redirects=True)
        assert response.status_code == 200

    def test_websocket_reset_step_state_flow(self):
        """The persistent OpenEnv WebSocket API should support reset, step, and state."""
        client = self._make_client()
        with client.websocket_connect("/ws") as ws:
            # reset -> first observation at tick 0
            ws.send_json({"type": "reset", "data": {"seed": 42, "task_id": 0}})
            reset_message = ws.receive_json()
            assert reset_message["type"] == "observation"
            assert reset_message["data"]["observation"]["tick"] == 0

            # step -> observation at tick 1
            ws.send_json(
                {
                    "type": "step",
                    "data": {
                        "action_type": "dispatch_generation",
                        "node_id": "NODE_01",
                        "mw": 100,
                    },
                }
            )
            step_message = ws.receive_json()
            assert step_message["type"] == "observation"
            assert step_message["data"]["observation"]["tick"] == 1

            # state -> step counter reflects the single step taken above
            ws.send_json({"type": "state"})
            state_message = ws.receive_json()
            assert state_message["type"] == "state"
            assert state_message["data"]["step_count"] == 1

    # ------------------------------------------------------------------
    # Termination and scoring
    # ------------------------------------------------------------------

    def test_done_when_frequency_below_59(self):
        """Episode should terminate when frequency drops below 59.0Hz.

        The check is an invariant rather than a forced termination: after up
        to five no-op ticks, either the episode has ended or the reported
        frequency is still at or above 59.0 Hz.
        """
        env = self._make_env()
        env.reset(seed=42, task_id=5)  # Black start — frequency starts at 59.0
        # In black start, frequency can drop. Let's force it by doing nothing.
        action = GridAction(action_type=ActionType.ADVANCE_TICK)
        obs = None
        for _ in range(5):
            obs = env.step(action)
            if obs.done:
                break
        # Black start starts at 59.0 which is already at termination boundary;
        # the env should detect this. ``obs`` is the last observation seen, so
        # ``obs.done`` is True exactly when the loop stopped early.
        assert obs.done or obs.grid_frequency_hz >= 59.0

    def test_smoke_test_scores_1(self):
        """Task 0 with valid dispatch should score 1.0."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        action = GridAction(
            action_type=ActionType.DISPATCH_GENERATION,
            node_id="NODE_01",
            mw=100.0,
        )
        env.step(action)
        score = env.get_score()
        assert score == 1.0, f"Smoke test should score 1.0, got {score}"

    def test_done_observation_includes_episode_summary(self):
        """Final observations should expose a compact episode summary for training logs."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        obs = env.step(
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_01",
                mw=100.0,
            )
        )
        # Give the episode at most two extra no-op ticks to finish.
        for _ in range(2):
            if obs.done:
                break
            obs = env.step(GridAction(action_type=ActionType.ADVANCE_TICK))
        assert obs.done is True
        assert "episode_summary" in obs.metadata

    # ------------------------------------------------------------------
    # Scenario-specific behaviour
    # ------------------------------------------------------------------

    def test_cascade_overload_reset_exposes_overloaded_lines(self):
        """Task 2 should surface the reroute overload in the initial observation."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=2)
        assert "LINE_29" in self._overloaded_edge_ids(obs)

    def test_cascade_overload_persists_without_immediate_action(self):
        """Task 2 should remain overloaded after a no-op tick until the agent isolates it."""
        env = self._make_env()
        env.reset(seed=42, task_id=2)
        obs = env.step(GridAction(action_type=ActionType.ADVANCE_TICK))
        assert "LINE_29" in self._overloaded_edge_ids(obs)

    def test_black_start_can_reenergize_critical_loads(self):
        """Closing breakers from the hydro island should energize downstream critical loads."""
        env = self._make_env()
        env.reset(seed=42, task_id=5)
        actions = [
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_01",
                mw=600.0,
            ),
            GridAction(
                action_type=ActionType.TOGGLE_CIRCUIT_BREAKER,
                edge_id="LINE_02",
                status="CLOSED",
            ),
            GridAction(
                action_type=ActionType.TOGGLE_CIRCUIT_BREAKER,
                edge_id="LINE_28",
                status="CLOSED",
            ),
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_17",
                mw=400.0,
            ),
            GridAction(
                action_type=ActionType.TOGGLE_CIRCUIT_BREAKER,
                edge_id="LINE_22",
                status="CLOSED",
            ),
        ]
        obs = None
        for action in actions:
            obs = env.step(action)
        assert obs is not None
        # Inspect engine internals directly: the nodes' energized flags are
        # read from the engine's node table rather than from the observation.
        assert env._engine.nodes["NODE_03"]["energized"] is True
        assert env._engine.nodes["NODE_18"]["energized"] is True
        assert env._engine.nodes["NODE_18"]["consumption_mw"] > 0.0
        assert env.get_score() > 0.0

    def test_full_restoration_tick_records_first_recovery(self):
        """Task 2 should keep the first tick that crosses the restoration threshold."""
        env = self._make_env()
        env.reset(seed=42, task_id=2)
        env.step(
            GridAction(
                action_type=ActionType.TOGGLE_CIRCUIT_BREAKER,
                edge_id="LINE_29",
                status="OPEN",
            )
        )
        assert env._get_full_restoration_tick() is None
        env.step(
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_09",
                mw=100.0,
            )
        )
        assert env._get_full_restoration_tick() is None
        env.step(
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_13",
                mw=100.0,
            )
        )
        # The third action is expected to be the one that records restoration.
        assert env._get_full_restoration_tick() == 3