""" test_env.py — Simulation Runner & Sanity Tests ================================================ Provides two entry-points: run_simulation(mode) – Run one full episode and print a formatted report. run_all() – Run all three difficulty modes and compare. run_sanity_checks() – Fast correctness assertions (no pytest needed). Usage ----- python test_env.py # runs all modes + sanity checks python test_env.py easy # run a single mode """ from __future__ import annotations import sys import builtins from typing import Dict, Any from env import TrafficEnv from tasks import get_config from baseline_agent import RuleBasedAgent # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- _COL = 80 # separator width def _separator(char: str = "─") -> str: return char * _COL _ASCII_FALLBACKS = ( ("\u2550", "="), ("\u2500", "-"), ("\u2502", "|"), ("\u00b7", "-"), ("\U0001F6A8", "EV"), ("\u2713", "PASS"), ("\u2717", "FAIL"), ("\u26a0\ufe0f", "WARNING"), ("\u2705", "PASS"), ("\u2014", "-"), ("\u2265", ">="), ("\u2264", "<="), ("\u2208", "in"), ) def _safe_text(text: str) -> str: encoding = getattr(sys.stdout, "encoding", None) or "utf-8" try: text.encode(encoding) return text except UnicodeEncodeError: for src, dest in _ASCII_FALLBACKS: text = text.replace(src, dest) return text def print(*args, **kwargs) -> None: # type: ignore[override] """ Safe local print wrapper: - keeps rich Unicode output when supported - falls back to ASCII-safe glyphs on limited encodings (e.g. cp1252) """ file = kwargs.get("file", sys.stdout) if file is not sys.stdout: builtins.print(*args, **kwargs) return sep = kwargs.get("sep", " ") end = kwargs.get("end", "\n") flush = kwargs.get("flush", False) text = sep.join(str(arg) for arg in args) builtins.print(_safe_text(text), end=end, flush=flush, file=file) def _fmt_metric(key: str, value: Any) -> str: label = key.replace("_", " ").title() if isinstance(value, float): return f" {label:<30} {value:.4f}" return f" {label:<30} {value}" # --------------------------------------------------------------------------- # Single-mode simulation # --------------------------------------------------------------------------- def run_simulation(mode: str = "medium", verbose: bool = True) -> Dict[str, Any]: """ Run one complete episode in the specified difficulty mode. Parameters ---------- mode : str "easy", "medium", or "hard" verbose : bool Print step-by-step output if True. Returns ------- dict Final info metrics plus 'cumulative_reward' and 'mode'. """ config = get_config(mode) env = TrafficEnv(config) agent = RuleBasedAgent( min_green_time=5, imbalance_threshold=5, max_green_time=15, emergency_min_green=2, ) state = env.reset() agent.reset() done = False total_reward = 0.0 step_rewards = [] if verbose: print() print(_separator("═")) print(f" TRAFFIC SIGNAL SIMULATION · Mode: {mode.upper()}") print(_separator("═")) header = ( f"{'Step':<6} │ {'Phase':<4} │ " f"{'N':>4} {'S':>4} {'E':>4} {'W':>4} │ " f"{'NS':>4} {'EW':>4} │ " f"{'Reward':>8} │ EV" ) print(header) print(_separator()) while not done: action = agent.select_action(state) next_state, reward, done, info = env.step(action) total_reward += reward step_rewards.append(reward) if verbose: phase_str = "NS" if next_state["phase"] == 0 else "EW" ns_q = next_state["north_cars"] + next_state["south_cars"] ew_q = next_state["east_cars"] + next_state["west_cars"] ev_flags = next_state["emergency_flags"] ev_active = "🚨" if any(ev_flags.values()) else " " # Print every 5 steps, or whenever there's an emergency if env.step_count % 5 == 0 or any(ev_flags.values()): print( f"{env.step_count:<6} │ {phase_str:<4} │ " f"{next_state['north_cars']:>4} " f"{next_state['south_cars']:>4} " f"{next_state['east_cars']:>4} " f"{next_state['west_cars']:>4} │ " f"{ns_q:>4} {ew_q:>4} │ " f"{reward:>8.3f} │ {ev_active}" ) state = next_state if verbose: print(_separator()) print(f"\n FINAL METRICS ({mode.upper()})") print(_separator()) for k, v in info.items(): print(_fmt_metric(k, v)) print(_fmt_metric("cumulative_reward", total_reward)) if step_rewards: print(_fmt_metric("min_step_reward", min(step_rewards))) print(_fmt_metric("max_step_reward", max(step_rewards))) print() result = dict(info) result["cumulative_reward"] = total_reward result["mode"] = mode return result # --------------------------------------------------------------------------- # Run all modes and print comparison table # --------------------------------------------------------------------------- def run_all() -> None: """Run easy, medium and hard in sequence; print a comparison table.""" results = {} for mode in ("easy", "medium", "hard"): results[mode] = run_simulation(mode, verbose=True) print() print(_separator("═")) print(" CROSS-MODE COMPARISON") print(_separator("═")) metrics = [ "total_cleared", "avg_waiting_time", "max_queue_length", "signal_switch_count", "congestion_score", "avg_ev_clear_time", "fairness_score", "cumulative_reward", ] col_w = 18 header = f" {'Metric':<30}" + "".join(f"{m.upper():>{col_w}}" for m in ("easy", "medium", "hard")) print(header) print(_separator()) for m in metrics: row = f" {m.replace('_',' ').title():<30}" for mode in ("easy", "medium", "hard"): val = results[mode].get(m, "—") if isinstance(val, float): row += f"{val:>{col_w}.3f}" else: row += f"{val:>{col_w}}" print(row) print(_separator("═")) print() # --------------------------------------------------------------------------- # Sanity / correctness checks (no external test runner needed) # --------------------------------------------------------------------------- def run_sanity_checks() -> None: """Assert basic correctness invariants for all difficulty modes.""" print() print(_separator("═")) print(" SANITY CHECKS") print(_separator("═")) passed = 0 failed = 0 def check(name: str, condition: bool) -> None: nonlocal passed, failed status = "✓ PASS" if condition else "✗ FAIL" print(f" [{status}] {name}") if condition: passed += 1 else: failed += 1 for mode in ("easy", "medium", "hard"): cfg = get_config(mode) env = TrafficEnv(cfg) agent = RuleBasedAgent() # 1. reset() returns valid state state = env.reset() agent.reset() check( f"[{mode}] reset() returns all-zero queues", all(state[f"{d}_cars"] == 0 for d in ("north", "south", "east", "west")), ) # 2. Step returns correct tuple length action = agent.select_action(state) result = env.step(action) check(f"[{mode}] step() returns 4-tuple", len(result) == 4) ns, reward, done, info = result # 3. Reward is clipped check(f"[{mode}] reward in [-1, 1]", -1.0 <= reward <= 1.0) # 4. State keys present required_keys = { "north_cars", "south_cars", "east_cars", "west_cars", "waiting_times", "phase", "emergency_flags", "step_count", } check(f"[{mode}] state has required keys", required_keys.issubset(ns.keys())) # 5. Info keys present required_info = { "total_cleared", "avg_waiting_time", "max_queue_length", "signal_switch_count", "congestion_score", "avg_ev_clear_time", "fairness_score", } check(f"[{mode}] info has required keys", required_info.issubset(info.keys())) # 6. Queues never go negative for _ in range(cfg["max_steps"]): a = agent.select_action(ns) ns, _, done, _ = env.step(a) if done: break all_non_neg = all(v >= 0 for v in env.queues.values()) check(f"[{mode}] queues never go negative (full episode)", all_non_neg) # 7. Queues never exceed max_queue check( f"[{mode}] queues never exceed max_queue ({cfg['max_queue']})", all(v <= cfg["max_queue"] for v in env.queues.values()), ) # 8. Signal phase is always 0 or 1 check(f"[{mode}] phase is always 0 or 1", env.phase in (0, 1)) # 9. total_cleared is non-negative check(f"[{mode}] total_cleared ≥ 0", env.total_cleared >= 0) # 10. congestion_score in [0, 1] score = info["congestion_score"] check(f"[{mode}] congestion_score ∈ [0, 1]", 0.0 <= score <= 1.0) print() print(_separator()) print(f" Results: {passed} passed, {failed} failed") print(_separator("═")) if failed: print(" ⚠️ Some checks failed — review the environment logic.") else: print(" ✅ All sanity checks passed.") print() # --------------------------------------------------------------------------- # CLI entry-point # --------------------------------------------------------------------------- if __name__ == "__main__": if len(sys.argv) == 2 and sys.argv[1].lower() in ("easy", "medium", "hard"): run_simulation(sys.argv[1].lower(), verbose=True) else: run_all() run_sanity_checks()