# Source page header (Hugging Face): open_env_meta / test_env.py —
# "Upload 14 files" by arrow072, commit 5516cba (verified).
"""
test_env.py — Simulation Runner & Sanity Tests
================================================
Provides three entry-points:
run_simulation(mode) – Run one full episode and print a formatted report.
run_all() – Run all three difficulty modes and compare.
run_sanity_checks() – Fast correctness assertions (no pytest needed).
Usage
-----
python test_env.py # runs all modes + sanity checks
python test_env.py easy # run a single mode
"""
from __future__ import annotations
import sys
import builtins
from typing import Dict, Any
from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_COL = 80 # separator width
def _separator(char: str = "─") -> str:
return char * _COL
_ASCII_FALLBACKS = (
("\u2550", "="),
("\u2500", "-"),
("\u2502", "|"),
("\u00b7", "-"),
("\U0001F6A8", "EV"),
("\u2713", "PASS"),
("\u2717", "FAIL"),
("\u26a0\ufe0f", "WARNING"),
("\u2705", "PASS"),
("\u2014", "-"),
("\u2265", ">="),
("\u2264", "<="),
("\u2208", "in"),
)
def _safe_text(text: str) -> str:
encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
try:
text.encode(encoding)
return text
except UnicodeEncodeError:
for src, dest in _ASCII_FALLBACKS:
text = text.replace(src, dest)
return text
def print(*args, **kwargs) -> None: # type: ignore[override]
"""
Safe local print wrapper:
- keeps rich Unicode output when supported
- falls back to ASCII-safe glyphs on limited encodings (e.g. cp1252)
"""
file = kwargs.get("file", sys.stdout)
if file is not sys.stdout:
builtins.print(*args, **kwargs)
return
sep = kwargs.get("sep", " ")
end = kwargs.get("end", "\n")
flush = kwargs.get("flush", False)
text = sep.join(str(arg) for arg in args)
builtins.print(_safe_text(text), end=end, flush=flush, file=file)
def _fmt_metric(key: str, value: Any) -> str:
label = key.replace("_", " ").title()
if isinstance(value, float):
return f" {label:<30} {value:.4f}"
return f" {label:<30} {value}"
# ---------------------------------------------------------------------------
# Single-mode simulation
# ---------------------------------------------------------------------------
def run_simulation(mode: str = "medium", verbose: bool = True) -> Dict[str, Any]:
"""
Run one complete episode in the specified difficulty mode.
Parameters
----------
mode : str
"easy", "medium", or "hard"
verbose : bool
Print step-by-step output if True.
Returns
-------
dict
Final info metrics plus 'cumulative_reward' and 'mode'.
"""
config = get_config(mode)
env = TrafficEnv(config)
agent = RuleBasedAgent(
min_green_time=5,
imbalance_threshold=5,
max_green_time=15,
emergency_min_green=2,
)
state = env.reset()
agent.reset()
done = False
total_reward = 0.0
step_rewards = []
if verbose:
print()
print(_separator("═"))
print(f" TRAFFIC SIGNAL SIMULATION · Mode: {mode.upper()}")
print(_separator("═"))
header = (
f"{'Step':<6}{'Phase':<4} │ "
f"{'N':>4} {'S':>4} {'E':>4} {'W':>4} │ "
f"{'NS':>4} {'EW':>4} │ "
f"{'Reward':>8} │ EV"
)
print(header)
print(_separator())
while not done:
action = agent.select_action(state)
next_state, reward, done, info = env.step(action)
total_reward += reward
step_rewards.append(reward)
if verbose:
phase_str = "NS" if next_state["phase"] == 0 else "EW"
ns_q = next_state["north_cars"] + next_state["south_cars"]
ew_q = next_state["east_cars"] + next_state["west_cars"]
ev_flags = next_state["emergency_flags"]
ev_active = "🚨" if any(ev_flags.values()) else " "
# Print every 5 steps, or whenever there's an emergency
if env.step_count % 5 == 0 or any(ev_flags.values()):
print(
f"{env.step_count:<6}{phase_str:<4} │ "
f"{next_state['north_cars']:>4} "
f"{next_state['south_cars']:>4} "
f"{next_state['east_cars']:>4} "
f"{next_state['west_cars']:>4} │ "
f"{ns_q:>4} {ew_q:>4} │ "
f"{reward:>8.3f}{ev_active}"
)
state = next_state
if verbose:
print(_separator())
print(f"\n FINAL METRICS ({mode.upper()})")
print(_separator())
for k, v in info.items():
print(_fmt_metric(k, v))
print(_fmt_metric("cumulative_reward", total_reward))
if step_rewards:
print(_fmt_metric("min_step_reward", min(step_rewards)))
print(_fmt_metric("max_step_reward", max(step_rewards)))
print()
result = dict(info)
result["cumulative_reward"] = total_reward
result["mode"] = mode
return result
# ---------------------------------------------------------------------------
# Run all modes and print comparison table
# ---------------------------------------------------------------------------
def run_all() -> None:
"""Run easy, medium and hard in sequence; print a comparison table."""
results = {}
for mode in ("easy", "medium", "hard"):
results[mode] = run_simulation(mode, verbose=True)
print()
print(_separator("═"))
print(" CROSS-MODE COMPARISON")
print(_separator("═"))
metrics = [
"total_cleared", "avg_waiting_time",
"max_queue_length", "signal_switch_count",
"congestion_score", "avg_ev_clear_time",
"fairness_score", "cumulative_reward",
]
col_w = 18
header = f" {'Metric':<30}" + "".join(f"{m.upper():>{col_w}}" for m in ("easy", "medium", "hard"))
print(header)
print(_separator())
for m in metrics:
row = f" {m.replace('_',' ').title():<30}"
for mode in ("easy", "medium", "hard"):
val = results[mode].get(m, "—")
if isinstance(val, float):
row += f"{val:>{col_w}.3f}"
else:
row += f"{val:>{col_w}}"
print(row)
print(_separator("═"))
print()
# ---------------------------------------------------------------------------
# Sanity / correctness checks (no external test runner needed)
# ---------------------------------------------------------------------------
def run_sanity_checks() -> None:
"""Assert basic correctness invariants for all difficulty modes."""
print()
print(_separator("═"))
print(" SANITY CHECKS")
print(_separator("═"))
passed = 0
failed = 0
def check(name: str, condition: bool) -> None:
nonlocal passed, failed
status = "✓ PASS" if condition else "✗ FAIL"
print(f" [{status}] {name}")
if condition:
passed += 1
else:
failed += 1
for mode in ("easy", "medium", "hard"):
cfg = get_config(mode)
env = TrafficEnv(cfg)
agent = RuleBasedAgent()
# 1. reset() returns valid state
state = env.reset()
agent.reset()
check(
f"[{mode}] reset() returns all-zero queues",
all(state[f"{d}_cars"] == 0 for d in ("north", "south", "east", "west")),
)
# 2. Step returns correct tuple length
action = agent.select_action(state)
result = env.step(action)
check(f"[{mode}] step() returns 4-tuple", len(result) == 4)
ns, reward, done, info = result
# 3. Reward is clipped
check(f"[{mode}] reward in [-1, 1]", -1.0 <= reward <= 1.0)
# 4. State keys present
required_keys = {
"north_cars", "south_cars", "east_cars", "west_cars",
"waiting_times", "phase", "emergency_flags", "step_count",
}
check(f"[{mode}] state has required keys", required_keys.issubset(ns.keys()))
# 5. Info keys present
required_info = {
"total_cleared", "avg_waiting_time",
"max_queue_length", "signal_switch_count",
"congestion_score", "avg_ev_clear_time",
"fairness_score",
}
check(f"[{mode}] info has required keys", required_info.issubset(info.keys()))
# 6. Queues never go negative
for _ in range(cfg["max_steps"]):
a = agent.select_action(ns)
ns, _, done, _ = env.step(a)
if done:
break
all_non_neg = all(v >= 0 for v in env.queues.values())
check(f"[{mode}] queues never go negative (full episode)", all_non_neg)
# 7. Queues never exceed max_queue
check(
f"[{mode}] queues never exceed max_queue ({cfg['max_queue']})",
all(v <= cfg["max_queue"] for v in env.queues.values()),
)
# 8. Signal phase is always 0 or 1
check(f"[{mode}] phase is always 0 or 1", env.phase in (0, 1))
# 9. total_cleared is non-negative
check(f"[{mode}] total_cleared ≥ 0", env.total_cleared >= 0)
# 10. congestion_score in [0, 1]
score = info["congestion_score"]
check(f"[{mode}] congestion_score ∈ [0, 1]", 0.0 <= score <= 1.0)
print()
print(_separator())
print(f" Results: {passed} passed, {failed} failed")
print(_separator("═"))
if failed:
print(" ⚠️ Some checks failed — review the environment logic.")
else:
print(" ✅ All sanity checks passed.")
print()
# ---------------------------------------------------------------------------
# CLI entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
if len(sys.argv) == 2 and sys.argv[1].lower() in ("easy", "medium", "hard"):
run_simulation(sys.argv[1].lower(), verbose=True)
else:
run_all()
run_sanity_checks()