# traffic_light_env / client.py
# rishabh16196's picture
# Upload folder using huggingface_hub
# 5920175 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Traffic Light Environment Client."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import (
NUM_DIRECTIONS,
NUM_LANES,
VEHICLE_TYPE_NAMES,
TrafficLightAction,
TrafficLightObservation,
)
class TrafficLightEnv(
    EnvClient[TrafficLightAction, TrafficLightObservation, State]
):
    """
    Client for the Traffic Light Environment.

    Controls a single 4-way intersection traffic light via WebSocket.
    Observes per-direction vehicle counts at 100 m and 500 m (4 directions,
    2 lanes each), plus per-direction light states.

    Use reset(task="task_name") to select a scenario:
        balanced, rush_hour_ns, rush_hour_ew, alternating_surge,
        random_spikes, gridlock, emergency_vehicle, or "random".

    Example:
        >>> async with TrafficLightEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset(task="emergency_vehicle")
        ...     print(f"Task: {result.observation.task_name}")
        ...
        ...     result = await client.step(TrafficLightAction(phase=0))
        ...     print(f"NS 100m: {result.observation.ns_100m}")
    """

    def _step_payload(self, action: TrafficLightAction) -> Dict:
        """Serialize *action* into the dict sent for a step.

        Only the ``phase`` attribute of the action is transmitted.
        """
        return {
            "phase": action.phase,
        }

    def _parse_result(self, payload: Dict) -> StepResult[TrafficLightObservation]:
        """Convert a server response dict into a typed ``StepResult``.

        Reads ``observation``, ``reward`` and ``done`` from *payload*. Every
        observation field that is missing from the response falls back to a
        default: zeros for counts, lights and timing, ``-1`` for the
        emergency direction/lane, ``"balanced"`` for the task name, ``None``
        for the grading fields and an empty dict for metadata.

        A missing or ``null`` ``observation`` entry is treated as an empty
        observation, so the result is built entirely from defaults.
        """
        # ``dict.get(key, default)`` only uses the default when the key is
        # absent. An explicit ``null`` would come through as None and make
        # every per-field lookup below raise AttributeError, so normalize it.
        obs_data = payload.get("observation") or {}

        # Read once; both the observation and the StepResult carry them.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = TrafficLightObservation(
            task_name=obs_data.get("task_name", "balanced"),
            # Per-direction 100 m
            ns_100m=obs_data.get("ns_100m", 0),
            sn_100m=obs_data.get("sn_100m", 0),
            ew_100m=obs_data.get("ew_100m", 0),
            we_100m=obs_data.get("we_100m", 0),
            # Per-direction 500 m
            ns_500m=obs_data.get("ns_500m", 0),
            sn_500m=obs_data.get("sn_500m", 0),
            ew_500m=obs_data.get("ew_500m", 0),
            we_500m=obs_data.get("we_500m", 0),
            # Lights
            light_ns=obs_data.get("light_ns", 0),
            light_sn=obs_data.get("light_sn", 0),
            light_ew=obs_data.get("light_ew", 0),
            light_we=obs_data.get("light_we", 0),
            # Emergency
            # NOTE(review): -1 presumably means "no emergency vehicle";
            # confirm against the observation model.
            emergency_direction=obs_data.get("emergency_direction", -1),
            emergency_lane=obs_data.get("emergency_lane", -1),
            emergency_wait=obs_data.get("emergency_wait", 0),
            # Phase / timing
            active_phase=obs_data.get("active_phase", 0),
            yellow_remaining=obs_data.get("yellow_remaining", 0),
            time_in_phase=obs_data.get("time_in_phase", 0),
            step_number=obs_data.get("step_number", 0),
            # Aggregates
            total_waiting=obs_data.get("total_waiting", 0),
            total_throughput=obs_data.get("total_throughput", 0),
            arrivals=obs_data.get("arrivals", [0] * NUM_DIRECTIONS),
            departures=obs_data.get("departures", [0] * NUM_DIRECTIONS),
            # Per-lane detail
            lanes_100m=obs_data.get("lanes_100m", [0] * NUM_LANES),
            lanes_500m=obs_data.get("lanes_500m", [0] * NUM_LANES),
            # Vehicle composition: one zero-filled list per vehicle type.
            # Each default is built separately so the two fields never
            # share list objects.
            vehicles_100m=obs_data.get(
                "vehicles_100m",
                {vt: [0] * NUM_DIRECTIONS for vt in VEHICLE_TYPE_NAMES},
            ),
            vehicles_500m=obs_data.get(
                "vehicles_500m",
                {vt: [0] * NUM_DIRECTIONS for vt in VEHICLE_TYPE_NAMES},
            ),
            # Dilemma zone
            dilemma_risk=obs_data.get("dilemma_risk", 0.0),
            total_dilemma_vehicles=obs_data.get("total_dilemma_vehicles", 0.0),
            # Grading (None when the server did not include a grade)
            grade_score=obs_data.get("grade_score"),
            grade_details=obs_data.get("grade_details"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Convert a state response dict into a ``State``.

        ``episode_id`` is ``None`` and ``step_count`` is ``0`` when the
        corresponding key is missing from *payload*.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )