Spaces:
Sleeping
Sleeping
| """ | |
| test_env.py — Simulation Runner & Sanity Tests | |
| ================================================ | |
| Provides three entry-points: | |
| run_simulation(mode) – Run one full episode and print a formatted report. | |
| run_all() – Run all three difficulty modes and compare. | |
| run_sanity_checks() – Fast correctness assertions (no pytest needed). | |
| Usage | |
| ----- | |
| python test_env.py # runs all modes + sanity checks | |
| python test_env.py easy # run a single mode | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import builtins | |
| from typing import Dict, Any | |
| from env import TrafficEnv | |
| from tasks import get_config | |
| from baseline_agent import RuleBasedAgent | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
_COL = 80  # separator width


def _separator(char: str = "─") -> str:
    """Return a horizontal rule: ``char`` repeated ``_COL`` times."""
    rule_width = _COL
    return rule_width * char
# Ordered (unicode, ascii) substitutions. They are applied only when the
# target encoding cannot represent the text unchanged.
_ASCII_FALLBACKS = (
    ("\u2550", "="),          # double horizontal box line
    ("\u2500", "-"),          # light horizontal box line
    ("\u2502", "|"),          # light vertical box line
    ("\u00b7", "-"),          # middle dot
    ("\U0001F6A8", "EV"),     # police-light emoji (emergency vehicle marker)
    ("\u2713", "PASS"),       # check mark
    ("\u2717", "FAIL"),       # ballot x
    ("\u26a0\ufe0f", "WARNING"),  # warning sign + emoji variation selector
    ("\u2705", "PASS"),       # green check-mark emoji
    ("\u2014", "-"),          # em dash
    ("\u2265", ">="),         # greater-than-or-equal
    ("\u2264", "<="),         # less-than-or-equal
    ("\u2208", "in"),         # element-of
)


def _safe_text(text: str, encoding: "str | None" = None) -> str:
    """
    Return ``text`` in a form that can be encoded with ``encoding``.

    Parameters
    ----------
    text : str
        The string about to be written.
    encoding : str, optional
        Target codec name. When omitted, ``sys.stdout.encoding`` is used,
        falling back to "utf-8" if stdout has no usable encoding attribute.

    Returns
    -------
    str
        ``text`` unchanged when it already encodes cleanly. Otherwise a copy
        with every ``_ASCII_FALLBACKS`` substitution applied and any character
        that is still unencodable replaced by the codec's replacement marker
        (``?`` for most codecs).
    """
    if encoding is None:
        encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
    try:
        text.encode(encoding)
    except UnicodeEncodeError:
        pass
    else:
        return text

    for src, dest in _ASCII_FALLBACKS:
        text = text.replace(src, dest)
    # The table is not exhaustive: round-trip with errors="replace" so that a
    # character without a table entry can never make the later write raise.
    return text.encode(encoding, errors="replace").decode(encoding)
def print(*args, **kwargs) -> None:  # type: ignore[override]
    """
    Safe local print wrapper:
    - keeps rich Unicode output when supported
    - falls back to ASCII-safe glyphs on limited encodings (e.g. cp1252)

    Accepts the same keyword arguments as the builtin (``sep``, ``end``,
    ``file``, ``flush``). Output aimed at any stream other than
    ``sys.stdout`` is handed to the builtin untouched.
    """
    target = kwargs.get("file")
    # The builtin treats file=None as "use sys.stdout", so None must take the
    # sanitising path as well; only a genuinely different stream bypasses it.
    if target is not None and target is not sys.stdout:
        builtins.print(*args, **kwargs)
        return

    sep = kwargs.get("sep")
    if sep is None:
        # The builtin accepts sep=None as the default single space.
        sep = " "
    text = sep.join(str(arg) for arg in args)
    builtins.print(
        _safe_text(text),
        end=kwargs.get("end", "\n"),
        flush=kwargs.get("flush", False),
        file=target,
    )
def _fmt_metric(key: str, value: Any) -> str:
    """
    Format one ``key``/``value`` pair as a line of the metrics report.

    The key is turned into a title-cased label (underscores become spaces)
    padded to 30 columns; floats are shown with four decimals, anything else
    with its default formatting.
    """
    label = key.replace("_", " ").title()
    rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
    return f" {label:<30} {rendered}"
| # --------------------------------------------------------------------------- | |
| # Single-mode simulation | |
| # --------------------------------------------------------------------------- | |
def run_simulation(mode: str = "medium", verbose: bool = True) -> Dict[str, Any]:
    """
    Play a single episode with the rule-based agent and collect its metrics.

    Parameters
    ----------
    mode : str
        Difficulty name passed to ``get_config`` ("easy", "medium", or "hard").
    verbose : bool
        When True, print a banner, a table of selected steps and the final
        metrics.

    Returns
    -------
    dict
        Copy of the ``info`` dict from the last ``env.step`` call, extended
        with 'cumulative_reward' and 'mode'.
    """
    env = TrafficEnv(get_config(mode))
    agent = RuleBasedAgent(
        min_green_time=5,
        imbalance_threshold=5,
        max_green_time=15,
        emergency_min_green=2,
    )
    state = env.reset()
    agent.reset()

    column_gap = " │ "
    rewards = []
    cumulative = 0.0

    if verbose:
        double_rule = _separator("═")
        print()
        print(double_rule)
        print(f" TRAFFIC SIGNAL SIMULATION · Mode: {mode.upper()}")
        print(double_rule)
        header_columns = (
            f"{'Step':<6}",
            f"{'Phase':<4}",
            f"{'N':>4} {'S':>4} {'E':>4} {'W':>4}",
            f"{'NS':>4} {'EW':>4}",
            f"{'Reward':>8}",
            "EV",
        )
        print(column_gap.join(header_columns))
        print(_separator())

    finished = False
    while not finished:
        observation, reward, finished, info = env.step(agent.select_action(state))
        cumulative += reward
        rewards.append(reward)

        if verbose:
            phase_label = "NS" if observation["phase"] == 0 else "EW"
            north = observation["north_cars"]
            south = observation["south_cars"]
            east = observation["east_cars"]
            west = observation["west_cars"]
            emergency = any(observation["emergency_flags"].values())
            marker = "🚨" if emergency else " "
            # Show every fifth step, plus any step with an active emergency flag.
            if env.step_count % 5 == 0 or emergency:
                row_columns = (
                    f"{env.step_count:<6}",
                    f"{phase_label:<4}",
                    f"{north:>4} {south:>4} {east:>4} {west:>4}",
                    f"{north + south:>4} {east + west:>4}",
                    f"{reward:>8.3f}",
                    marker,
                )
                print(column_gap.join(row_columns))

        state = observation

    if verbose:
        print(_separator())
        print(f"\n FINAL METRICS ({mode.upper()})")
        print(_separator())
        extras = [("cumulative_reward", cumulative)]
        if rewards:
            extras.append(("min_step_reward", min(rewards)))
            extras.append(("max_step_reward", max(rewards)))
        for name, value in [*info.items(), *extras]:
            print(_fmt_metric(name, value))
        print()

    return {**info, "cumulative_reward": cumulative, "mode": mode}
| # --------------------------------------------------------------------------- | |
| # Run all modes and print comparison table | |
| # --------------------------------------------------------------------------- | |
def _fmt_cell(value: Any, width: int) -> str:
    """Right-align ``value`` in ``width`` columns; floats get three decimals."""
    if isinstance(value, float):
        return f"{value:>{width}.3f}"
    try:
        return f"{value:>{width}}"
    except TypeError:
        # None, lists, dicts, ... reject alignment specs; use their str() form
        # so one odd metric value cannot abort the whole comparison table.
        return f"{str(value):>{width}}"


def run_all() -> None:
    """
    Run easy, medium and hard in sequence; print a comparison table.

    Each mode is run through ``run_simulation`` with verbose output. The table
    has one row per metric and one column per mode; a metric missing from a
    mode's result is shown as "—".
    """
    modes = ("easy", "medium", "hard")
    results = {mode: run_simulation(mode, verbose=True) for mode in modes}

    print()
    print(_separator("═"))
    print(" CROSS-MODE COMPARISON")
    print(_separator("═"))

    metrics = [
        "total_cleared", "avg_waiting_time",
        "max_queue_length", "signal_switch_count",
        "congestion_score", "avg_ev_clear_time",
        "fairness_score", "cumulative_reward",
    ]
    col_w = 18

    header = f" {'Metric':<30}" + "".join(f"{m.upper():>{col_w}}" for m in modes)
    print(header)
    print(_separator())

    for metric in metrics:
        label = metric.replace("_", " ").title()
        cells = "".join(
            _fmt_cell(results[mode].get(metric, "—"), col_w) for mode in modes
        )
        print(f" {label:<30}" + cells)

    print(_separator("═"))
    print()
| # --------------------------------------------------------------------------- | |
| # Sanity / correctness checks (no external test runner needed) | |
| # --------------------------------------------------------------------------- | |
def run_sanity_checks() -> None:
    """
    Check basic correctness invariants for all difficulty modes.

    For each mode a fresh environment and default agent are created, one step
    is taken to validate the shapes of the returned values, and the episode is
    then played until ``done`` (or for at most ``cfg["max_steps"]`` further
    steps). The queue, phase and congestion-score invariants are evaluated
    after every step, so a violation in the middle of the episode is reported
    even if the final state looks valid.

    Results are printed as PASS/FAIL lines followed by a summary; nothing is
    raised or returned.
    """
    print()
    print(_separator("═"))
    print(" SANITY CHECKS")
    print(_separator("═"))

    passed = 0
    failed = 0

    def check(name: str, condition: bool) -> None:
        """Print one PASS/FAIL line and update the running totals."""
        nonlocal passed, failed
        status = "✓ PASS" if condition else "✗ FAIL"
        print(f" [{status}] {name}")
        if condition:
            passed += 1
        else:
            failed += 1

    required_keys = {
        "north_cars", "south_cars", "east_cars", "west_cars",
        "waiting_times", "phase", "emergency_flags", "step_count",
    }
    required_info = {
        "total_cleared", "avg_waiting_time",
        "max_queue_length", "signal_switch_count",
        "congestion_score", "avg_ev_clear_time",
        "fairness_score",
    }

    for mode in ("easy", "medium", "hard"):
        cfg = get_config(mode)
        env = TrafficEnv(cfg)
        agent = RuleBasedAgent()

        # 1. reset() returns valid state
        state = env.reset()
        agent.reset()
        check(
            f"[{mode}] reset() returns all-zero queues",
            all(state[f"{d}_cars"] == 0 for d in ("north", "south", "east", "west")),
        )

        # 2. Step returns correct tuple length
        result = env.step(agent.select_action(state))
        check(f"[{mode}] step() returns 4-tuple", len(result) == 4)
        ns, reward, done, info = result

        # 3. Reward is clipped
        check(f"[{mode}] reward in [-1, 1]", -1.0 <= reward <= 1.0)

        # 4. State keys present
        check(f"[{mode}] state has required keys", required_keys.issubset(ns.keys()))

        # 5. Info keys present
        check(f"[{mode}] info has required keys", required_info.issubset(info.keys()))

        # 6-8 and 10 are per-step invariants: evaluate them after the first
        # step and after every later step, not just on the final state.
        max_queue = cfg["max_queue"]
        non_negative = True
        within_cap = True
        phase_ok = True
        score_ok = True
        steps_left = cfg["max_steps"]
        while True:
            sizes = list(env.queues.values())
            non_negative = non_negative and all(v >= 0 for v in sizes)
            within_cap = within_cap and all(v <= max_queue for v in sizes)
            phase_ok = phase_ok and (env.phase in (0, 1))
            score_ok = score_ok and (0.0 <= info["congestion_score"] <= 1.0)
            # Stop before stepping an episode that has already finished.
            if done or steps_left <= 0:
                break
            steps_left -= 1
            ns, _reward, done, info = env.step(agent.select_action(ns))

        # 6. Queues never go negative
        check(f"[{mode}] queues never go negative (full episode)", non_negative)

        # 7. Queues never exceed max_queue
        check(
            f"[{mode}] queues never exceed max_queue ({cfg['max_queue']})",
            within_cap,
        )

        # 8. Signal phase is always 0 or 1
        check(f"[{mode}] phase is always 0 or 1", phase_ok)

        # 9. total_cleared is non-negative
        check(f"[{mode}] total_cleared ≥ 0", env.total_cleared >= 0)

        # 10. congestion_score in [0, 1]
        check(f"[{mode}] congestion_score ∈ [0, 1]", score_ok)

        print()

    print(_separator())
    print(f" Results: {passed} passed, {failed} failed")
    print(_separator("═"))
    if failed:
        print(" ⚠️ Some checks failed — review the environment logic.")
    else:
        print(" ✅ All sanity checks passed.")
    print()
| # --------------------------------------------------------------------------- | |
| # CLI entry-point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # One recognised mode name runs that mode alone; anything else runs the
    # full comparison followed by the sanity checks.
    cli_args = [arg.lower() for arg in sys.argv[1:]]
    single_mode = len(cli_args) == 1 and cli_args[0] in ("easy", "medium", "hard")
    if single_mode:
        run_simulation(cli_args[0], verbose=True)
    else:
        run_all()
        run_sanity_checks()