# open_env_meta / baseline_agent.py
# arrow072 — "Upload 14 files" — commit 5516cba (verified)
"""
baseline_agent.py — Rule-Based Traffic Signal Controller
=========================================================
A deterministic agent that makes signal decisions using handcrafted
heuristics. Acts as the reproducible baseline for comparison against
trained RL policies.
Decision hierarchy (highest priority first):
1. Emergency vehicle preemption — switch if an emergency vehicle is
stuck at a red light and minimum green time has been served.
2. Minimum green time — never switch before a floor number of steps
to prevent rapid oscillation.
3. Queue-imbalance trigger — switch when the queued-vehicle disparity
between NS and EW exceeds a configurable threshold.
4. Maximum green cap — force a switch if one direction has been green
for too long (fairness guard).
5. Default — keep current phase.
Usage
-----
from baseline_agent import RuleBasedAgent
agent = RuleBasedAgent(min_green_time=5, imbalance_threshold=5)
action = agent.select_action(state) # 0 or 1
"""
from __future__ import annotations
from typing import Any, Dict
class RuleBasedAgent:
    """
    Rule-based traffic signal controller.

    Parameters
    ----------
    min_green_time : int
        Minimum number of steps to hold a phase before a non-emergency
        switch is allowed. Prevents oscillatory behaviour.
    imbalance_threshold : int
        A pressure-based switch happens only when the red direction's
        pressure exceeds the green direction's by *more than* this amount.
    max_green_time : int
        Number of steps after which a phase change is forced, provided
        vehicles are queued on the red direction. Acts as a starvation
        safety net.
    emergency_min_green : int
        Reduced minimum green time used when an emergency vehicle is
        waiting on a red lane.
    emergency_weight : int
        Pressure bonus added to an axis (NS or EW) that has at least one
        emergency vehicle. Defaults to 20, the previously hard-coded value.
    """

    def __init__(
        self,
        min_green_time: int = 5,
        imbalance_threshold: int = 5,
        max_green_time: int = 20,
        emergency_min_green: int = 2,
        emergency_weight: int = 20,
    ) -> None:
        self.min_green_time = min_green_time
        self.imbalance_threshold = imbalance_threshold
        self.max_green_time = max_green_time
        self.emergency_min_green = emergency_min_green
        self.emergency_weight = emergency_weight
        # Number of select_action() calls since the last switch (or reset).
        self._steps_since_switch: int = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def select_action(self, state: Dict[str, Any]) -> int:
        """
        Choose an action given the current environment state.

        Parameters
        ----------
        state : dict
            State dictionary as returned by ``TrafficEnv.get_state()``.
            Must contain ``north_cars``, ``south_cars``, ``east_cars``,
            ``west_cars``, ``phase`` (0 = NS green, 1 = EW green) and
            ``emergency_flags`` (a dict keyed by direction, or a legacy
            sequence ordered north, south, east, west).

        Returns
        -------
        int
            0 → keep current signal phase
            1 → switch signal phase

        Raises
        ------
        KeyError
            If a required key is missing from ``state``.
        """
        self._steps_since_switch += 1

        ns_total = state["north_cars"] + state["south_cars"]
        ew_total = state["east_cars"] + state["west_cars"]
        phase = state["phase"]
        ev_north, ev_south, ev_east, ev_west = self._parse_emergency_flags(
            state["emergency_flags"]
        )
        ev_ns = ev_north or ev_south
        ev_ew = ev_east or ev_west

        # ── Rule 1: Emergency preemption ──────────────────────────────
        # Switch if an EV is blocked on the red axis, once the reduced
        # emergency_min_green has been served (avoids rapid jitter).
        if phase == 0:
            ev_on_red, ev_on_green = ev_ew, ev_ns
        elif phase == 1:
            ev_on_red, ev_on_green = ev_ns, ev_ew
        else:
            ev_on_red = ev_on_green = False
        # If the green axis has its own EV, preempting would only strand
        # that EV on red and ping-pong the signal every
        # emergency_min_green steps; defer to the regular rules instead.
        if (
            ev_on_red
            and not ev_on_green
            and self._steps_since_switch >= self.emergency_min_green
        ):
            return self._switch()

        # ── Rule 2: Oscillation Damping (Minimum Green Time) ──────────
        if self._steps_since_switch < self.min_green_time:
            return 0

        # ── Rule 3: Congestion/Pressure Trigger ───────────────────────
        # Pressure = queued vehicles + a bonus when the axis has an EV.
        ns_pressure = ns_total + (self.emergency_weight if ev_ns else 0)
        ew_pressure = ew_total + (self.emergency_weight if ev_ew else 0)
        if phase == 0:  # NS currently green
            green_pressure, red_pressure, red_total = ns_pressure, ew_pressure, ew_total
        else:  # EW currently green
            green_pressure, red_pressure, red_total = ew_pressure, ns_pressure, ns_total
        # Only switch if the red side's pressure is significantly higher.
        if red_pressure > green_pressure + self.imbalance_threshold:
            return self._switch()

        # ── Rule 4: Fairness Guard (Maximum Green Time) ──────────────
        # Only force a switch if someone is actually waiting on red.
        if self._steps_since_switch >= self.max_green_time and red_total > 0:
            return self._switch()

        # ── Rule 5: Default — hold current phase ─────────────────────
        return 0

    def reset(self) -> None:
        """Reset internal step counter (call at the start of each episode)."""
        self._steps_since_switch = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_emergency_flags(flags: Any) -> tuple[bool, bool, bool, bool]:
        """Normalise emergency flags to (north, south, east, west) booleans.

        Accepts a dict keyed by direction (TrafficEnv) or a legacy
        four-element sequence in north, south, east, west order.
        """
        if isinstance(flags, dict):
            return (
                bool(flags["north"]),
                bool(flags["south"]),
                bool(flags["east"]),
                bool(flags["west"]),
            )
        north, south, east, west = (bool(x) for x in flags)
        return north, south, east, west

    def _switch(self) -> int:
        """Record a switch and reset the step counter."""
        self._steps_since_switch = 0
        return 1