Nexus-Grid / tests /test_api.py
Abineshsdata's picture
Add manifest.json endpoint, update dashboard and app
74965f9
Raw
History Blame Contribute Delete
10.9 kB
"""
Test API — proves step()/reset()/state() return correctly typed models.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from fastapi.testclient import TestClient
from models import GridAction, GridObservation, ActionType
from server.nexusgrid_environment import NexusgridEnvironment
from server.app import app
class TestEnvironmentAPI:
    """Test the OpenEnv API interface.

    Covers the in-process environment (``reset()`` / ``step()`` / ``state``),
    the HTTP routes mounted on ``server.app`` and the persistent WebSocket
    protocol served at ``/ws``.
    """

    def _make_env(self) -> NexusgridEnvironment:
        """Return a fresh environment so tests never share episode state."""
        return NexusgridEnvironment()

    @staticmethod
    def _overloaded_line_ids(obs: GridObservation, threshold: float = 0.95) -> list:
        """Return the ids of LIVE edges loaded to at least ``threshold`` x capacity.

        An edge counts as overloaded when
        ``current_load_mw >= threshold * capacity_mw``. The 0.95 default is
        the criterion these tests use; it is a test-side heuristic and is not
        necessarily the engine's own trip threshold.
        """
        return [
            edge["id"]
            for edge in obs.topology_graph["edges"]
            if edge["status"] == "LIVE"
            and edge["current_load_mw"] >= threshold * edge["capacity_mw"]
        ]

    def test_reset_returns_observation(self):
        """reset() should return a GridObservation."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=0)
        assert isinstance(obs, GridObservation)
        assert obs.tick == 0
        assert obs.task_id == 0
        assert obs.done is False
        assert 58.0 <= obs.grid_frequency_hz <= 62.0

    def test_reset_topology_present(self):
        """reset() observation should contain topology with nodes and edges."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=0)
        topo = obs.topology_graph
        assert "nodes" in topo
        assert "edges" in topo
        assert len(topo["nodes"]) == 20  # 20-node grid
        assert len(topo["edges"]) == 40  # 40 edges

    def test_step_returns_observation(self):
        """step() should return a GridObservation with reward and done."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        action = GridAction(
            action_type=ActionType.DISPATCH_GENERATION,
            node_id="NODE_01",
            mw=100.0,
        )
        obs = env.step(action)
        assert isinstance(obs, GridObservation)
        assert obs.tick == 1
        assert obs.grid_frequency_hz > 0
        assert "reward_breakdown" in obs.metadata
        assert "rubric_breakdown" in obs.metadata

    def test_step_advance_tick(self):
        """advance_tick should increment the tick counter."""
        env = self._make_env()
        env.reset(seed=42, task_id=1)
        action = GridAction(action_type=ActionType.ADVANCE_TICK)
        obs = env.step(action)
        assert obs.tick == 1

    def test_state_returns_state(self):
        """state property should return an OpenEnv State."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        state = env.state
        assert state.episode_id is not None
        assert state.step_count == 0

    def test_health_endpoint_returns_200(self):
        """The standalone health endpoint should respond immediately."""
        client = TestClient(app)
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"

    def test_web_dashboard_route_available(self):
        """The mounted dashboard route should remain available for HF Spaces."""
        client = TestClient(app)
        response = client.get("/web", follow_redirects=True)
        assert response.status_code == 200

    def test_websocket_reset_step_state_flow(self):
        """The persistent OpenEnv WebSocket API should support reset, step, and state."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_json({"type": "reset", "data": {"seed": 42, "task_id": 0}})
            reset_message = ws.receive_json()
            assert reset_message["type"] == "observation"
            assert reset_message["data"]["observation"]["tick"] == 0

            ws.send_json(
                {
                    "type": "step",
                    "data": {
                        "action_type": "dispatch_generation",
                        "node_id": "NODE_01",
                        "mw": 100,
                    },
                }
            )
            step_message = ws.receive_json()
            assert step_message["type"] == "observation"
            assert step_message["data"]["observation"]["tick"] == 1

            ws.send_json({"type": "state"})
            state_message = ws.receive_json()
            assert state_message["type"] == "state"
            assert state_message["data"]["step_count"] == 1

    def test_done_when_frequency_below_59(self):
        """An observation below 59.0 Hz must always be terminal.

        Black start (task 5) is set up at the 59.0 Hz termination boundary,
        so with only no-op ticks the frequency is free to sag. The invariant
        "frequency < 59.0 Hz implies done" is checked on EVERY step. Checking
        only the final observation would let a non-terminal sub-59 Hz
        intermediate step slip through.
        """
        env = self._make_env()
        env.reset(seed=42, task_id=5)
        action = GridAction(action_type=ActionType.ADVANCE_TICK)
        for _ in range(5):
            obs = env.step(action)
            assert obs.done or obs.grid_frequency_hz >= 59.0, (
                f"tick {obs.tick}: frequency {obs.grid_frequency_hz} Hz is "
                "below 59.0 Hz but the episode was not terminated"
            )
            if obs.done:
                break

    def test_smoke_test_scores_1(self):
        """Task 0 with valid dispatch should score 1.0."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        action = GridAction(
            action_type=ActionType.DISPATCH_GENERATION,
            node_id="NODE_01",
            mw=100.0,
        )
        env.step(action)
        score = env.get_score()
        assert score == 1.0, f"Smoke test should score 1.0, got {score}"

    def test_dispatch_updates_frequency_without_advance_tick(self):
        """Control actions should update the observable grid frequency immediately."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=1)
        initial_frequency = obs.grid_frequency_hz
        obs = env.step(
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_04",
                mw=200.0,
            )
        )
        assert obs.grid_frequency_hz != initial_frequency

    def test_cascade_overload_reset_exposes_overloaded_lines(self):
        """Task 2 should surface the reroute overload in the initial observation."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=2)
        assert "LINE_29" in self._overloaded_line_ids(obs)

    def test_cascade_overload_persists_without_immediate_action(self):
        """Task 2 should remain overloaded after a no-op tick until the agent isolates it."""
        env = self._make_env()
        env.reset(seed=42, task_id=2)
        obs = env.step(GridAction(action_type=ActionType.ADVANCE_TICK))
        assert "LINE_29" in self._overloaded_line_ids(obs)

    def test_black_start_can_reenergize_critical_loads(self):
        """Closing breakers from the hydro island should energize downstream critical loads."""
        env = self._make_env()
        env.reset(seed=42, task_id=5)
        actions = [
            GridAction(action_type=ActionType.DISPATCH_GENERATION, node_id="NODE_01", mw=600.0),
            GridAction(action_type=ActionType.TOGGLE_CIRCUIT_BREAKER, edge_id="LINE_02", status="CLOSED"),
            GridAction(action_type=ActionType.TOGGLE_CIRCUIT_BREAKER, edge_id="LINE_28", status="CLOSED"),
            GridAction(action_type=ActionType.DISPATCH_GENERATION, node_id="NODE_17", mw=400.0),
            GridAction(action_type=ActionType.TOGGLE_CIRCUIT_BREAKER, edge_id="LINE_22", status="CLOSED"),
        ]
        obs = None
        for action in actions:
            obs = env.step(action)
        assert obs is not None
        # Inspect engine internals directly: energization is not asserted via the observation.
        assert env._engine.nodes["NODE_03"]["energized"] is True
        assert env._engine.nodes["NODE_18"]["energized"] is True
        assert env._engine.nodes["NODE_18"]["consumption_mw"] > 0.0
        assert env.get_score() > 0.0

    def test_full_restoration_tick_records_first_recovery(self):
        """Task 2 should keep the first tick that crosses the restoration threshold."""
        env = self._make_env()
        env.reset(seed=42, task_id=2)
        # Isolating the overloaded line alone is not enough for full restoration.
        env.step(GridAction(action_type=ActionType.TOGGLE_CIRCUIT_BREAKER, edge_id="LINE_29", status="OPEN"))
        assert env._get_full_restoration_tick() is None
        env.step(GridAction(action_type=ActionType.DISPATCH_GENERATION, node_id="NODE_09", mw=100.0))
        assert env._get_full_restoration_tick() is None
        # The third action crosses the threshold, so tick 3 must be recorded.
        env.step(GridAction(action_type=ActionType.DISPATCH_GENERATION, node_id="NODE_13", mw=100.0))
        assert env._get_full_restoration_tick() == 3

    def test_reset_idempotent(self):
        """Calling reset(42) multiple times should produce identical observations."""
        env = self._make_env()
        obs1 = env.reset(seed=42, task_id=0)
        obs2 = env.reset(seed=42, task_id=0)
        assert obs1.grid_frequency_hz == obs2.grid_frequency_hz
        assert obs1.tick == obs2.tick
        assert obs1.task_id == obs2.task_id
        # The seeded grid itself (nodes, edges, loads) must also be reproducible.
        assert obs1.topology_graph == obs2.topology_graph

    def test_done_observation_includes_episode_summary(self):
        """Final observations should expose a compact episode summary for training logs."""
        env = self._make_env()
        env.reset(seed=42, task_id=0)
        obs = env.step(
            GridAction(
                action_type=ActionType.DISPATCH_GENERATION,
                node_id="NODE_01",
                mw=100.0,
            )
        )
        # Task 0 is expected to terminate within three steps in total:
        # allow up to two extra no-op ticks after the dispatch.
        for _ in range(2):
            if obs.done:
                break
            obs = env.step(GridAction(action_type=ActionType.ADVANCE_TICK))
        assert obs.done is True
        assert "episode_summary" in obs.metadata

    def test_all_action_types_accepted(self):
        """All action types should be accepted without crashing."""
        env = self._make_env()
        env.reset(seed=42, task_id=1)
        actions = [
            GridAction(action_type=ActionType.ADVANCE_TICK),
            GridAction(action_type=ActionType.DISPATCH_GENERATION, node_id="NODE_01", mw=50),
            GridAction(action_type=ActionType.TOGGLE_CIRCUIT_BREAKER, edge_id="LINE_01", status="OPEN"),
            GridAction(action_type=ActionType.RUN_STATE_ESTIMATION, subgraph=["NODE_01", "NODE_02"]),
            GridAction(action_type=ActionType.QUARANTINE_SCADA_NODE, node_id="NODE_01"),
        ]
        for action in actions:
            obs = env.step(action)
            assert isinstance(obs, GridObservation)

    def test_weather_present(self):
        """Observation should include weather data."""
        env = self._make_env()
        obs = env.reset(seed=42, task_id=0)
        assert len(obs.weather_forecast_matrix) > 0
        assert obs.weather_summary != ""