"""
Overflow Environment Client.

Provides the client for connecting to an Overflow Environment server
via WebSocket for persistent sessions.
"""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from .models import (
    CarStateData,
    LaneOccupancyData,
    OverflowAction,
    OverflowObservation,
    OverflowState,
    Position,
    ProximityData,
)


class OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState]):
    """
    WebSocket client for the Overflow Environment.

    Example:
        >>> with OverflowEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset()
        ...     print(result.observation.scene_description)
        ...     print(result.observation.cars)  # structured car data
        ...     action = OverflowAction(decision="maintain", reasoning="Safe for now")
        ...     result = env.step(action)
    """

    def _step_payload(self, action: OverflowAction) -> Dict[str, Any]:
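        """Convert an OverflowAction into the JSON payload for a step request.

        For example, ``OverflowAction(decision="maintain", reasoning="Safe for now")``
        becomes ``{"decision": "maintain", "reasoning": "Safe for now"}``.
        """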
        return {
            "decision": action.decision,
            "reasoning": action.reasoning,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[OverflowObservation]:
        """Parse server response into StepResult[OverflowObservation]."""
        obs_data = payload.get("observation", {})
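        # Sketch of the payload shape this parser expects, inferred from the
        # field accesses below rather than from a published server schema.
        # Keys inside "position", each proximity entry and each lane
        # occupancy entry must match the fields of Position, ProximityData
        # and LaneOccupancyData respectively.
        #
        # {
        #     "observation": {                  # optional, defaults to {}
        #         "scene_description": "...",   # optional, defaults to ""
        #         "incident_report": "...",     # optional, defaults to ""
        #         "cars": [
        #             {
        #                 "carId": ...,
        #                 "lane": ...,
        #                 "position": {...},
        #                 "speed": ...,
        #                 "acceleration": ...,  # optional, defaults to 0.0
        #             },
        #         ],
        #         "proximities": [{...}],
        #         "lane_occupancies": [{...}],
        #     },
        #     "reward": ...,                    # optional, defaults to None
        #     "done": False,                    # optional, defaults to False
        # }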

        # Parse structured car data
        cars = [
            CarStateData(
                carId=c["carId"],
                lane=c["lane"],
                position=Position(**c["position"]),
                speed=c["speed"],
                acceleration=c.get("acceleration", 0.0),
            )
            for c in obs_data.get("cars", [])
        ]

        proximities = [
            ProximityData(**p) for p in obs_data.get("proximities", [])
        ]

        lane_occupancies = [
            LaneOccupancyData(**lo) for lo in obs_data.get("lane_occupancies", [])
        ]

        observation = OverflowObservation(
            scene_description=obs_data.get("scene_description", ""),
            incident_report=obs_data.get("incident_report", ""),
            done=payload.get("done", False),
            reward=payload.get("reward"),
            cars=cars,
            proximities=proximities,
            lane_occupancies=lane_occupancies,
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> OverflowState:
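        """Parse the server's state response into an OverflowState.

        Missing fields fall back to defaults: counters default to 0 and
        ``total_cars`` defaults to 5.
        """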
        return OverflowState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            crash_count=payload.get("crash_count", 0),
            near_miss_count=payload.get("near_miss_count", 0),
            cars_reached_goal=payload.get("cars_reached_goal", 0),
            total_cars=payload.get("total_cars", 5),
        )