Spaces:
Sleeping
Sleeping
| """Tests for the PermitPathfinder FSM — transitions, optimal policies, edge cases.""" | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from server.permit_env_environment import ( | |
| PermitEnvironment, | |
| TASKS, | |
| STAGE_ISSUED, | |
| ) | |
| from models import PermitAction | |
def _step(env, action_type, permit_id=None):
    """Wrap the arguments in a PermitAction and return ``env.step``'s result.

    ``permit_id`` is optional so permit-less actions (the tests use it for
    ``"list"``) can be issued through the same helper.
    """
    action = PermitAction(action_type=action_type, permit_id=permit_id)
    return env.step(action)
| # ---------- Optimal policy ---------- | |
def test_optimal_easy_foodtruck():
    """Run submit -> pay -> inspect on every easy_foodtruck permit, error-free."""
    task = "easy_foodtruck"
    environment = PermitEnvironment()
    result = environment.reset(seed=42, task_name=task)
    assert result.task_name == task

    # Snapshot the permit ids up front: `result` is rebound on every step.
    permit_ids = list(result.permits)
    plan = [
        (verb, permit_id)
        for permit_id in permit_ids
        for verb in ("submit", "pay", "inspect")
    ]
    for verb, permit_id in plan:
        result = _step(environment, verb, permit_id)
        failure = result.last_action_error
        assert failure is None, f"{verb}({permit_id}) failed: {failure}"

    # A clean run must finish the episode with a high reward and no waste.
    assert result.done is True
    assert result.reward >= 0.9
    assert result.wasted_submissions == 0
def test_optimal_medium_cafe():
    """Drive medium_cafe to completion in a dependency-respecting order."""
    environment = PermitEnvironment()
    result = environment.reset(seed=100, task_name="medium_cafe")

    # Hand-written order intended to place each permit after its
    # prerequisites (a topological order of the task's dependency graph).
    sequence = (
        "business_license",
        "zoning_approval",
        "signage_permit",
        "health_permit",
        "fire_inspection",
        "food_service_license",
    )
    lifecycle = ("submit", "pay", "inspect")

    for permit_id in sequence:
        for verb in lifecycle:
            result = _step(environment, verb, permit_id)
            failure = result.last_action_error
            assert failure is None, f"{verb}({permit_id}) failed: {failure}"

    assert result.done is True
    assert result.reward >= 0.9
    assert result.wasted_submissions == 0
def test_optimal_hard_restaurant():
    """Drive hard_restaurant to completion, tolerating event-driven setbacks."""
    environment = PermitEnvironment()
    result = environment.reset(seed=999, task_name="hard_restaurant")

    sequence = (
        "business_license",
        "zoning_variance",
        "building_permit",
        "liquor_license",
        "plumbing_permit",
        "electrical_permit",
        "hvac_permit",
        "health_permit",
        "fire_certificate",
        "food_service_license",
    )
    lifecycle = ("submit", "pay", "inspect")

    def run_lifecycle(permit_id):
        """Issue submit/pay/inspect for one permit, retrying each failed step once."""
        outcome = None
        for verb in lifecycle:
            outcome = _step(environment, verb, permit_id)
            # A missing-doc event may knock a permit back a stage; that is
            # not treated as a test failure, so the same action is retried.
            if outcome.last_action_error:
                outcome = _step(environment, verb, permit_id)
        return outcome

    # First sweep: every permit once, in the intended order.
    for permit_id in sequence:
        result = run_lifecycle(permit_id)

    # Events (missing-doc, regulation) may revert or block permits, so make
    # up to three more sweeps over anything that is not issued yet.
    for _ in range(3):
        for permit_id in sequence:
            record = result.permits.get(permit_id, {})
            if record.get("stage") == "issued":
                continue
            result = run_lifecycle(permit_id)

    assert result.done is True
    assert result.reward >= 0.70
| # ---------- Illegal actions ---------- | |
def test_submit_locked_permit_is_wasted():
    """A submit on a permit whose prerequisites are unmet errors and is wasted."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="medium_cafe")

    # Premise of this test: food_service_license depends on health_permit and
    # fire_inspection, so it is still locked right after reset.
    locked_permit = "food_service_license"
    result = _step(environment, "submit", locked_permit)

    assert result.last_action_error is not None
    assert result.wasted_submissions == 1
def test_pay_before_submit_is_wasted():
    """A pay action on a never-submitted permit errors and counts as wasted."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="easy_foodtruck")

    result = _step(environment, "pay", "business_license")

    assert result.last_action_error is not None
    assert result.wasted_submissions == 1
def test_inspect_before_pay_is_wasted():
    """An inspect action on an unpaid permit errors and counts as wasted."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="easy_foodtruck")

    result = _step(environment, "inspect", "business_license")

    assert result.last_action_error is not None
    assert result.wasted_submissions == 1
def test_unknown_permit_is_wasted():
    """An action naming a permit id that does not exist errors and is wasted."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="easy_foodtruck")

    bogus_id = "nonexistent_permit_99"
    result = _step(environment, "submit", bogus_id)
    error_text = result.last_action_error

    assert error_text is not None
    assert "Unknown permit" in error_text
    assert result.wasted_submissions == 1
| # ---------- Waste penalty ---------- | |
def test_waste_penalty_reduces_reward():
    """Repeated illegal actions must push the reward below its post-reset value."""
    environment = PermitEnvironment()
    baseline_reward = environment.reset(seed=1, task_name="easy_foodtruck").reward

    # Five submits against a permit id that does not exist.
    illegal_attempts = 5
    result = None
    for _ in range(illegal_attempts):
        result = _step(environment, "submit", "nonexistent_permit")

    assert result.reward < baseline_reward
    assert result.wasted_submissions == 5
| # ---------- List and query are safe ---------- | |
def test_list_does_not_advance_state():
    """A list action leaves every permit's stage untouched and is not wasted."""

    def stage_snapshot(observation):
        """Map each permit id to its current stage."""
        snapshot = {}
        for permit_id, record in observation.permits.items():
            snapshot[permit_id] = record["stage"]
        return snapshot

    environment = PermitEnvironment()
    before = stage_snapshot(environment.reset(seed=1, task_name="easy_foodtruck"))

    listed = _step(environment, "list")
    after = stage_snapshot(listed)

    assert before == after
    assert listed.wasted_submissions == 0
    assert listed.last_action_error is None
def test_query_returns_info():
    """A query succeeds and its message mentions the queried permit id."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=1, task_name="easy_foodtruck")

    # Any permit will do; take the first one the observation lists.
    target = list(initial.permits)[0]
    reply = _step(environment, "query", target)

    assert reply.last_action_error is None
    assert target in reply.message
| # ---------- Empty reset (validator path) ---------- | |
def test_empty_reset():
    """reset() with no arguments (the validator's POST /reset with {}) works."""
    initial = PermitEnvironment().reset()

    # With no arguments, reset() must still yield a known task, a positive
    # budget and at least one permit.
    assert initial.task_name in TASKS
    assert initial.budget_remaining > 0
    assert len(initial.permits) > 0
def test_reset_with_kwargs():
    """reset() honours the seed and task_name keyword arguments."""
    requested_task = "hard_restaurant"
    initial = PermitEnvironment().reset(seed=42, task_name=requested_task)

    assert initial.task_name == requested_task
    # The hard task is expected to expose exactly ten permits.
    assert len(initial.permits) == 10
| # ---------- Episode termination ---------- | |
def test_max_steps_terminates():
    """The episode ends once enough steps have been consumed."""
    environment = PermitEnvironment()
    result = environment.reset(seed=1, task_name="easy_foodtruck")

    # Burn steps with list actions; 25 is assumed to be at least the task's
    # step limit. Stop as soon as the environment reports done.
    steps_taken = 0
    while steps_taken < 25:
        steps_taken += 1
        result = _step(environment, "list")
        if result.done:
            break

    assert result.done is True
| # ---------- Regulation event ---------- | |
def test_regulation_event_fires():
    """On hard_restaurant a regulation event should appear in the messages.

    The test greedily advances permits for up to 30 steps so the episode
    reaches the window (step 15 or later) where the event is expected.
    """
    environment = PermitEnvironment()
    result = environment.reset(seed=42, task_name="hard_restaurant")

    # Which action moves a permit forward from each actionable stage.
    advance_with = {"available": "submit", "approved": "pay", "paid": "inspect"}

    transcript = []
    for _ in range(30):
        # First permit (in observation order) that can be advanced, if any.
        candidate = next(
            (
                (permit_id, advance_with[record["stage"]])
                for permit_id, record in result.permits.items()
                if record["stage"] in advance_with
            ),
            None,
        )
        if candidate is None:
            # Nothing actionable: list just to consume a step.
            result = _step(environment, "list")
        else:
            permit_id, verb = candidate
            result = _step(environment, verb, permit_id)
        transcript.append(result.message)
        if result.done:
            break

    combined = " ".join(transcript)
    assert "[EVENT] Regulatory update" in combined, (
        "Expected a regulation event to fire during hard_restaurant steps 15+"
    )
| # ---------- Partial observability ---------- | |
def test_hidden_prereqs_medium():
    """medium_cafe should start with exactly two permits masked as ['???']."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=42, task_name="medium_cafe")

    masked = [
        record for record in initial.permits.values()
        if record["prereqs"] == ["???"]
    ]
    hidden_count = len(masked)

    assert hidden_count == 2, f"Expected 2 hidden prereqs, got {hidden_count}"
def test_query_reveals_prereqs():
    """Querying a permit with masked prereqs replaces ['???'] with real ones."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=42, task_name="hard_restaurant")

    # First permit whose prerequisites are still masked, or None.
    hidden_pid = next(
        (
            permit_id
            for permit_id, record in initial.permits.items()
            if record["prereqs"] == ["???"]
        ),
        None,
    )
    assert hidden_pid is not None, "Expected at least one hidden permit in hard task"

    reply = _step(environment, "query", hidden_pid)
    assert "REVEALED" in reply.message

    # The mask must be gone and replaced by a non-empty prerequisite list.
    revealed = reply.permits[hidden_pid]["prereqs"]
    assert revealed != ["???"]
    assert len(revealed) > 0
def test_inquiry_budget_deducts():
    """Queries beyond the free allowance must reduce the remaining budget."""
    environment = PermitEnvironment()
    result = environment.reset(seed=42, task_name="hard_restaurant")
    starting_budget = result.budget_remaining

    # Test premise: hard_restaurant grants 3 free inquiries, then charges 50
    # per query (inquiry_budget=3, inquiry_cost=50).
    target = list(result.permits)[0]

    free_queries = 3
    for _ in range(free_queries):
        result = _step(environment, "query", target)
    assert result.budget_remaining == starting_budget, "First 3 queries should be free"

    # The fourth query is the first paid one.
    result = _step(environment, "query", target)
    assert result.budget_remaining < starting_budget, "4th query should deduct from budget"
    expected_budget = starting_budget - 50
    assert abs(result.budget_remaining - expected_budget) < 0.01