Spaces:
Runtime error
Runtime error
| """ | |
| Overflow Environment Client. | |
| Provides the client for connecting to an Overflow Environment server | |
| via WebSocket for persistent sessions. | |
| """ | |
| from typing import Any, Dict, List | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from .models import ( | |
| CarStateData, | |
| LaneOccupancyData, | |
| OverflowAction, | |
| OverflowObservation, | |
| OverflowState, | |
| Position, | |
| ProximityData, | |
| ) | |
class OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState]):
    """
    WebSocket client for the Overflow Environment.

    This subclass of ``EnvClient`` defines three payload conversions:
    an ``OverflowAction`` into a JSON-ready dict, a step/reset response dict
    into a ``StepResult[OverflowObservation]``, and a state response dict
    into an ``OverflowState``.

    Example:
        >>> with OverflowEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset()
        ...     print(result.observation.scene_description)
        ...     print(result.observation.cars)  # structured car data
        ...     action = OverflowAction(decision="maintain", reasoning="Safe for now")
        ...     result = env.step(action)
    """

    def _step_payload(self, action: OverflowAction) -> Dict[str, Any]:
        """Build the JSON payload for a step request from ``action``.

        Args:
            action: The action whose ``decision`` and ``reasoning`` are sent.

        Returns:
            A dict with exactly the keys ``"decision"`` and ``"reasoning"``.
        """
        return dict(decision=action.decision, reasoning=action.reasoning)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[OverflowObservation]:
        """Turn a server response dict into ``StepResult[OverflowObservation]``.

        Args:
            payload: Response dict. ``"observation"`` (default ``{}``) holds
                the scene fields; ``"reward"`` (default ``None``) and
                ``"done"`` (default ``False``) are read from the top level.

        Returns:
            A ``StepResult`` wrapping the parsed observation. The same reward
            and done values are stored on both the observation and the result.

        Raises:
            KeyError: If a car entry lacks ``"carId"``, ``"lane"``,
                ``"position"`` or ``"speed"`` (these are indexed directly;
                only ``"acceleration"`` has a fallback of ``0.0``).
        """
        observation_fields = payload.get("observation", {})

        # Structured car data: each entry's nested "position" dict is
        # expanded into a Position instance.
        car_states = []
        for entry in observation_fields.get("cars", []):
            car_states.append(
                CarStateData(
                    carId=entry["carId"],
                    lane=entry["lane"],
                    position=Position(**entry["position"]),
                    speed=entry["speed"],
                    acceleration=entry.get("acceleration", 0.0),
                )
            )

        # Proximity and lane-occupancy entries are passed through as keyword
        # arguments, so their keys must match the model field names.
        proximity_items = []
        for entry in observation_fields.get("proximities", []):
            proximity_items.append(ProximityData(**entry))

        occupancy_items = []
        for entry in observation_fields.get("lane_occupancies", []):
            occupancy_items.append(LaneOccupancyData(**entry))

        # reward/done live at the top level of the payload, not inside
        # the "observation" sub-dict.
        is_done = payload.get("done", False)
        reward_value = payload.get("reward")

        parsed_observation = OverflowObservation(
            scene_description=observation_fields.get("scene_description", ""),
            incident_report=observation_fields.get("incident_report", ""),
            done=is_done,
            reward=reward_value,
            cars=car_states,
            proximities=proximity_items,
            lane_occupancies=occupancy_items,
        )

        return StepResult(
            observation=parsed_observation,
            reward=reward_value,
            done=is_done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> OverflowState:
        """Turn a server response dict into an ``OverflowState``.

        Args:
            payload: State dict from the server. Missing keys fall back to:
                ``episode_id`` -> ``None``; ``step_count``, ``crash_count``,
                ``near_miss_count`` and ``cars_reached_goal`` -> ``0``;
                ``total_cars`` -> ``5``.

        Returns:
            The populated ``OverflowState``.
        """
        state_kwargs = {
            "episode_id": payload.get("episode_id"),
            "step_count": payload.get("step_count", 0),
            "crash_count": payload.get("crash_count", 0),
            "near_miss_count": payload.get("near_miss_count", 0),
            "cars_reached_goal": payload.get("cars_reached_goal", 0),
            "total_cars": payload.get("total_cars", 5),
        }
        return OverflowState(**state_kwargs)