File size: 5,024 Bytes
6fac95b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Client for SUMO-RL environment.

This module provides a client to interact with the SUMO traffic signal
control environment via WebSocket for persistent sessions.
"""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from .models import SumoAction, SumoObservation, SumoState


class SumoRLEnv(EnvClient[SumoAction, SumoObservation, SumoState]):
    """
    Client for SUMO-RL traffic signal control environment.

    This client maintains a persistent WebSocket connection to a SUMO
    environment server to control traffic signals using reinforcement learning.

    The class itself only implements the (de)serialization hooks used by the
    base client: building the step payload from a ``SumoAction`` and turning
    response payloads into ``StepResult`` / ``SumoState`` objects.

    Example:
        >>> # Start container and connect
        >>> env = SumoRLEnv.from_docker_image("sumo-rl-env:latest")
        >>> try:
        ...     # Reset environment
        ...     result = env.reset()
        ...     print(f"Observation shape: {result.observation.observation_shape}")
        ...     print(f"Action space: {result.observation.action_mask}")
        ...
        ...     # Take action
        ...     result = env.step(SumoAction(phase_id=1))
        ...     print(f"Reward: {result.reward}, Done: {result.done}")
        ...
        ...     # Get state
        ...     state = env.state()
        ...     print(f"Sim time: {state.sim_time}, Total vehicles: {state.total_vehicles}")
        ... finally:
        ...     env.close()

    Example with custom network:
        >>> # Use custom SUMO network via volume mount
        >>> env = SumoRLEnv.from_docker_image(
        ...     "sumo-rl-env:latest",
        ...     port=8000,
        ...     volumes={
        ...         "/path/to/my/nets": {"bind": "/nets", "mode": "ro"}
        ...     },
        ...     environment={
        ...         "SUMO_NET_FILE": "/nets/my-network.net.xml",
        ...         "SUMO_ROUTE_FILE": "/nets/my-routes.rou.xml",
        ...     }
        ... )

    Example with configuration:
        >>> # Adjust simulation parameters
        >>> env = SumoRLEnv.from_docker_image(
        ...     "sumo-rl-env:latest",
        ...     environment={
        ...         "SUMO_NUM_SECONDS": "10000",
        ...         "SUMO_DELTA_TIME": "10",
        ...         "SUMO_REWARD_FN": "queue",
        ...         "SUMO_SEED": "123",
        ...     }
        ... )
    """

    def _step_payload(self, action: SumoAction) -> Dict[str, Any]:
        """
        Convert a SumoAction into the JSON-serializable step payload.

        Args:
            action: SumoAction whose ``phase_id`` and ``ts_id`` are sent.

        Returns:
            Dictionary with the keys ``phase_id`` and ``ts_id``.
        """
        return {
            "phase_id": action.phase_id,
            "ts_id": action.ts_id,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SumoObservation]:
        """
        Parse a step/reset response payload into a StepResult.

        Args:
            payload: Decoded JSON response. The nested ``observation`` mapping
                supplies the observation fields; the top-level ``reward`` and
                ``done`` keys supply the step outcome.

        Returns:
            StepResult containing the parsed SumoObservation.
        """
        # ``dict.get(key, default)`` only uses the default when the key is
        # absent; an explicit ``"observation": null`` would yield None and
        # break the ``.get`` calls below, so fall back to an empty mapping.
        obs_data = payload.get("observation") or {}

        observation = SumoObservation(
            observation=obs_data.get("observation", []),
            observation_shape=obs_data.get("observation_shape", []),
            action_mask=obs_data.get("action_mask", []),
            sim_time=obs_data.get("sim_time", 0.0),
            done=obs_data.get("done", False),
            reward=obs_data.get("reward"),
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SumoState:
        """
        Parse a state response payload into a SumoState.

        Every field falls back to the default given below when its key is
        missing from the payload.

        Args:
            payload: Decoded JSON response describing the environment state.

        Returns:
            SumoState object.
        """
        return SumoState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            net_file=payload.get("net_file", ""),
            route_file=payload.get("route_file", ""),
            num_seconds=payload.get("num_seconds", 20000),
            delta_time=payload.get("delta_time", 5),
            yellow_time=payload.get("yellow_time", 2),
            min_green=payload.get("min_green", 5),
            max_green=payload.get("max_green", 50),
            reward_fn=payload.get("reward_fn", "diff-waiting-time"),
            sim_time=payload.get("sim_time", 0.0),
            total_vehicles=payload.get("total_vehicles", 0),
            total_waiting_time=payload.get("total_waiting_time", 0.0),
            mean_waiting_time=payload.get("mean_waiting_time", 0.0),
            mean_speed=payload.get("mean_speed", 0.0),
        )