Spaces:
Sleeping
Sleeping
| """ | |
| baseline_agent.py — Rule-Based Traffic Signal Controller | |
| ========================================================= | |
| A deterministic agent that makes signal decisions using handcrafted | |
| heuristics. Acts as the reproducible baseline for comparison against | |
| trained RL policies. | |
| Decision hierarchy (highest priority first): | |
| 1. Emergency vehicle preemption — switch if an emergency vehicle is | |
| stuck at a red light and the reduced emergency minimum green time | |
| (``emergency_min_green``) has been served. | |
| 2. Minimum green time — never switch before a floor number of steps | |
| to prevent rapid oscillation. | |
| 3. Queue-imbalance trigger — switch when the queued-vehicle disparity | |
| between NS and EW exceeds a configurable threshold. | |
| 4. Maximum green cap — force a switch if one direction has been green | |
| for too long (fairness guard). | |
| 5. Default — keep current phase. | |
| Usage | |
| ----- | |
| from baseline_agent import RuleBasedAgent | |
| agent = RuleBasedAgent(min_green_time=5, imbalance_threshold=5) | |
| action = agent.select_action(state) # 0 or 1 | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
class RuleBasedAgent:
    """
    Rule-based traffic signal controller.

    A deterministic agent that decides, once per step, whether to keep or
    switch the current signal phase.  Rules are evaluated in a fixed
    priority order: emergency preemption, minimum green time, pressure
    imbalance, maximum green time, and finally "hold".

    Parameters
    ----------
    min_green_time : int
        Minimum number of steps to hold a phase before switching.
        Prevents oscillatory behaviour.
    imbalance_threshold : int
        Minimum pressure difference (NS vs EW) required to trigger a switch.
    max_green_time : int
        Maximum consecutive steps before forcing a phase change.
        Acts as a starvation safety net.
    emergency_min_green : int
        Reduced minimum green time used when an emergency vehicle is
        waiting on a red lane.
    emergency_pressure_bonus : int
        Extra pressure credited to an axis (NS or EW) that has at least one
        emergency vehicle when comparing pressures in the imbalance rule.
        Defaults to 20, the value that used to be hard-coded.
    """

    def __init__(
        self,
        min_green_time: int = 5,
        imbalance_threshold: int = 5,
        max_green_time: int = 20,
        emergency_min_green: int = 2,
        emergency_pressure_bonus: int = 20,
    ) -> None:
        self.min_green_time = min_green_time
        self.imbalance_threshold = imbalance_threshold
        self.max_green_time = max_green_time
        self.emergency_min_green = emergency_min_green
        self.emergency_pressure_bonus = emergency_pressure_bonus
        # Number of select_action() calls since the last switch this agent
        # issued (or since construction / reset()).
        self._steps_since_switch: int = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def select_action(self, state: Dict[str, Any]) -> int:
        """
        Choose an action given the current environment state.

        Every call counts as one step: the internal step counter is
        incremented before any rule is evaluated, and reset to zero
        whenever this method returns 1.

        Parameters
        ----------
        state : dict
            State dictionary as returned by ``TrafficEnv.get_state()``.
            The keys read here are ``north_cars``, ``south_cars``,
            ``east_cars``, ``west_cars``, ``phase`` and
            ``emergency_flags``.  ``phase == 0`` is treated as
            "NS currently green".

        Returns
        -------
        int
            0 → keep current signal phase
            1 → switch signal phase
        """
        self._steps_since_switch += 1

        north = state["north_cars"]
        south = state["south_cars"]
        east = state["east_cars"]
        west = state["west_cars"]
        phase = state["phase"]
        ev_ns, ev_ew = self._emergency_by_axis(state["emergency_flags"])

        ns_total = north + south
        ew_total = east + west

        # ── Rule 1: Emergency preemption ──────────────────────────────
        # Switch if an emergency vehicle is blocked on a red lane, but only
        # after the (reduced) emergency minimum green has been served so the
        # signal does not jitter.  An emergency vehicle on the axis that is
        # already green does not inhibit this rule.
        emergency_on_red = (phase == 0 and ev_ew) or (phase == 1 and ev_ns)
        if (
            emergency_on_red
            and self._steps_since_switch >= self.emergency_min_green
        ):
            return self._switch()

        # ── Rule 2: Oscillation damping (minimum green time) ──────────
        if self._steps_since_switch < self.min_green_time:
            return 0

        # ── Rule 3: Congestion / pressure trigger ─────────────────────
        # Pressure = queued vehicles on the axis, plus a fixed bonus when
        # any emergency vehicle is present on that axis.
        bonus = self.emergency_pressure_bonus
        ns_pressure = ns_total + (bonus if ev_ns else 0)
        ew_pressure = ew_total + (bonus if ev_ew else 0)

        if phase == 0:
            # NS currently green: green side is NS, waiting side is EW.
            green_pressure, red_pressure = ns_pressure, ew_pressure
            red_total = ew_total
        else:
            # Any other phase value is treated as EW currently green.
            green_pressure, red_pressure = ew_pressure, ns_pressure
            red_total = ns_total

        # Only switch if the waiting side's pressure is significantly higher.
        if red_pressure > green_pressure + self.imbalance_threshold:
            return self._switch()

        # ── Rule 4: Fairness guard (maximum green time) ───────────────
        # Force a change after max_green_time steps, but only if somebody is
        # actually waiting on the other axis.
        if self._steps_since_switch >= self.max_green_time and red_total > 0:
            return self._switch()

        # ── Rule 5: Default — hold current phase ──────────────────────
        return 0

    def reset(self) -> None:
        """Reset internal step counter (call at the start of each episode)."""
        self._steps_since_switch = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _emergency_by_axis(flags: Any):
        """
        Collapse per-lane emergency flags into ``(ev_ns, ev_ew)`` booleans.

        ``flags`` may be a dict keyed by ``north``/``south``/``east``/``west``
        (TrafficEnv) or a four-item sequence (legacy), which is unpacked in
        the order north, south, east, west.
        """
        if isinstance(flags, dict):
            north, south = flags["north"], flags["south"]
            east, west = flags["east"], flags["west"]
        else:
            north, south, east, west = flags
        return bool(north or south), bool(east or west)

    def _switch(self) -> int:
        """Record a switch and reset the step counter."""
        self._steps_since_switch = 0
        return 1