Spaces:
Runtime error
Runtime error
File size: 3,081 Bytes
cb054fe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | """
Overflow Environment Client.
Provides the client for connecting to an Overflow Environment server
via WebSocket for persistent sessions.
"""
from typing import Any, Dict, List
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import (
CarStateData,
LaneOccupancyData,
OverflowAction,
OverflowObservation,
OverflowState,
Position,
ProximityData,
)
class OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState]):
    """
    WebSocket client for the Overflow Environment.

    Translates between the typed Overflow models and plain JSON-style dicts:
    ``_step_payload`` serializes an ``OverflowAction``, ``_parse_result``
    decodes a response into ``StepResult[OverflowObservation]`` and
    ``_parse_state`` decodes a response into ``OverflowState``. These are
    presumably the hooks invoked by the ``EnvClient`` base class — confirm
    against ``openenv.core.env_client``.

    Example:
        >>> with OverflowEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset()
        ...     print(result.observation.scene_description)
        ...     print(result.observation.cars)  # structured car data
        ...     action = OverflowAction(decision="maintain", reasoning="Safe for now")
        ...     result = env.step(action)
    """

    def _step_payload(self, action: OverflowAction) -> Dict[str, Any]:
        """Convert an ``OverflowAction`` into the JSON payload for a step request.

        Args:
            action: The agent's action; only ``decision`` and ``reasoning``
                are sent.

        Returns:
            A dict with keys ``"decision"`` and ``"reasoning"``.
        """
        return {
            "decision": action.decision,
            "reasoning": action.reasoning,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[OverflowObservation]:
        """Parse a server response into ``StepResult[OverflowObservation]``.

        Args:
            payload: Decoded response dict. Recognized keys are
                ``"observation"`` (a dict), ``"reward"`` and ``"done"``; all
                are optional, and missing or null container fields are
                treated as empty.

        Returns:
            A ``StepResult`` wrapping the parsed observation. The same
            ``reward``/``done`` values are stored both on the observation and
            on the result.

        Raises:
            KeyError: If a car entry lacks ``carId``, ``lane``, ``position``
                or ``speed``.
        """
        # Use ``or`` rather than a ``.get`` default so that keys which are
        # present but explicitly ``null`` in the JSON also fall back to an
        # empty value instead of crashing on iteration / attribute access.
        obs_data = payload.get("observation") or {}
        reward = payload.get("reward")
        done = payload.get("done", False)

        cars = [self._parse_car(c) for c in (obs_data.get("cars") or [])]
        proximities = [
            ProximityData(**p) for p in (obs_data.get("proximities") or [])
        ]
        lane_occupancies = [
            LaneOccupancyData(**lo)
            for lo in (obs_data.get("lane_occupancies") or [])
        ]

        observation = OverflowObservation(
            scene_description=obs_data.get("scene_description") or "",
            incident_report=obs_data.get("incident_report") or "",
            done=done,
            reward=reward,
            cars=cars,
            proximities=proximities,
            lane_occupancies=lane_occupancies,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    @staticmethod
    def _parse_car(car: Dict[str, Any]) -> CarStateData:
        """Build a ``CarStateData`` from one entry of the observation's ``cars`` list.

        ``carId``, ``lane``, ``position`` (a dict unpacked into ``Position``)
        and ``speed`` are required; a missing key raises ``KeyError``.
        ``acceleration`` is optional and defaults to ``0.0``.
        """
        return CarStateData(
            carId=car["carId"],
            lane=car["lane"],
            position=Position(**car["position"]),
            speed=car["speed"],
            acceleration=car.get("acceleration", 0.0),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> OverflowState:
        """Parse a server response into ``OverflowState``.

        Args:
            payload: Decoded response dict; every key is optional.

        Returns:
            An ``OverflowState``. Missing counters default to ``0``,
            ``episode_id`` defaults to ``None`` and ``total_cars`` defaults
            to ``5``.
        """
        return OverflowState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            crash_count=payload.get("crash_count", 0),
            near_miss_count=payload.get("near_miss_count", 0),
            cars_reached_goal=payload.get("cars_reached_goal", 0),
            # NOTE(review): 5 presumably mirrors the server's default car
            # count — confirm it stays in sync with the server config.
            total_cars=payload.get("total_cars", 5),
        )
|