File size: 4,609 Bytes
22b47f2
 
 
 
 
 
 
 
 
 
 
 
 
 
5920175
 
 
 
 
 
 
22b47f2
 
 
 
 
 
 
 
 
e46d8fe
 
22b47f2
 
 
 
 
 
e46d8fe
 
22b47f2
 
e46d8fe
 
22b47f2
 
 
 
 
 
 
 
 
 
 
e46d8fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22b47f2
 
e46d8fe
22b47f2
 
 
 
e46d8fe
22b47f2
 
e46d8fe
 
 
 
 
5920175
 
 
 
 
 
 
 
 
 
 
 
e46d8fe
9c3eb66
 
22b47f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Traffic Light Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import (
    NUM_DIRECTIONS,
    NUM_LANES,
    VEHICLE_TYPE_NAMES,
    TrafficLightAction,
    TrafficLightObservation,
)


class TrafficLightEnv(
    EnvClient[TrafficLightAction, TrafficLightObservation, State]
):
    """
    Client for the Traffic Light Environment.

    Controls a single 4-way intersection traffic light via WebSocket.
    Observes per-direction vehicle counts at 100 m and 500 m (4 directions,
    2 lanes each), plus per-direction light states.

    Use reset(task="task_name") to select a scenario:
        balanced, rush_hour_ns, rush_hour_ew, alternating_surge,
        random_spikes, gridlock, emergency_vehicle, or "random".

    Example:
        >>> async with TrafficLightEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset(task="emergency_vehicle")
        ...     print(f"Task: {result.observation.task_name}")
        ...
        ...     result = await client.step(TrafficLightAction(phase=0))
        ...     print(f"NS 100m: {result.observation.ns_100m}")
    """

    def _step_payload(self, action: TrafficLightAction) -> Dict:
        """Return the payload dict for ``action``; it carries only ``phase``."""
        return {
            "phase": action.phase,
        }

    def _parse_result(self, payload: Dict) -> StepResult[TrafficLightObservation]:
        """
        Build a ``StepResult`` from a response payload.

        Observation fields are read from ``payload["observation"]``; any field
        missing there falls back to the default given below. ``reward`` and
        ``done`` are read from the top level of ``payload`` and are copied onto
        both the observation and the returned ``StepResult``.

        Args:
            payload: Response dict, optionally containing ``"observation"``,
                ``"reward"`` and ``"done"`` keys.

        Returns:
            A ``StepResult`` wrapping the parsed ``TrafficLightObservation``.
        """
        # ``dict.get(key, {})`` substitutes the default only when the key is
        # absent. An explicit ``"observation": None`` would otherwise reach the
        # ``.get`` calls below and raise AttributeError, so treat it as empty.
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = TrafficLightObservation(
            task_name=obs_data.get("task_name", "balanced"),
            # Per-direction 100 m
            ns_100m=obs_data.get("ns_100m", 0),
            sn_100m=obs_data.get("sn_100m", 0),
            ew_100m=obs_data.get("ew_100m", 0),
            we_100m=obs_data.get("we_100m", 0),
            # Per-direction 500 m
            ns_500m=obs_data.get("ns_500m", 0),
            sn_500m=obs_data.get("sn_500m", 0),
            ew_500m=obs_data.get("ew_500m", 0),
            we_500m=obs_data.get("we_500m", 0),
            # Lights
            light_ns=obs_data.get("light_ns", 0),
            light_sn=obs_data.get("light_sn", 0),
            light_ew=obs_data.get("light_ew", 0),
            light_we=obs_data.get("light_we", 0),
            # Emergency
            emergency_direction=obs_data.get("emergency_direction", -1),
            emergency_lane=obs_data.get("emergency_lane", -1),
            emergency_wait=obs_data.get("emergency_wait", 0),
            # Phase / timing
            active_phase=obs_data.get("active_phase", 0),
            yellow_remaining=obs_data.get("yellow_remaining", 0),
            time_in_phase=obs_data.get("time_in_phase", 0),
            step_number=obs_data.get("step_number", 0),
            # Aggregates (fresh default lists on every call, never shared)
            total_waiting=obs_data.get("total_waiting", 0),
            total_throughput=obs_data.get("total_throughput", 0),
            arrivals=obs_data.get("arrivals", [0] * NUM_DIRECTIONS),
            departures=obs_data.get("departures", [0] * NUM_DIRECTIONS),
            # Per-lane detail
            lanes_100m=obs_data.get("lanes_100m", [0] * NUM_LANES),
            lanes_500m=obs_data.get("lanes_500m", [0] * NUM_LANES),
            # Vehicle composition: one zeroed per-direction list per type name
            vehicles_100m=obs_data.get(
                "vehicles_100m",
                {vt: [0] * NUM_DIRECTIONS for vt in VEHICLE_TYPE_NAMES},
            ),
            vehicles_500m=obs_data.get(
                "vehicles_500m",
                {vt: [0] * NUM_DIRECTIONS for vt in VEHICLE_TYPE_NAMES},
            ),
            # Dilemma zone
            dilemma_risk=obs_data.get("dilemma_risk", 0.0),
            total_dilemma_vehicles=obs_data.get("total_dilemma_vehicles", 0.0),
            # Grading (None when the payload does not include them)
            grade_score=obs_data.get("grade_score"),
            grade_details=obs_data.get("grade_details"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Build a ``State`` from a state payload.

        Only ``episode_id`` (``None`` if missing) and ``step_count`` (0 if
        missing) are read; other keys in ``payload`` are ignored.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )