# permit-pathfinder / tests/test_fsm.py
# Uploaded by yashppawar via huggingface_hub (commit b22b2e7, verified).
"""Tests for the PermitPathfinder FSM — transitions, optimal policies, edge cases."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from server.permit_env_environment import (
PermitEnvironment,
TASKS,
STAGE_ISSUED,
)
from models import PermitAction
def _step(env, action_type, permit_id=None):
    """Build a PermitAction from the arguments, apply it to *env*, and return the observation.

    ``permit_id`` may be omitted for actions (such as ``list``) that do not
    target a specific permit.
    """
    action = PermitAction(action_type=action_type, permit_id=permit_id)
    return env.step(action)
# ---------- Optimal policy ----------
def test_optimal_easy_foodtruck():
    """Drive every easy_foodtruck permit through submit -> pay -> inspect.

    Permits are processed in the order the reset observation lists them.
    Every action must succeed, the episode must end with reward >= 0.9,
    and no submission may be wasted.
    """
    environment = PermitEnvironment()
    observation = environment.reset(seed=42, task_name="easy_foodtruck")
    assert observation.task_name == "easy_foodtruck"

    # Snapshot the permit ids up front; `observation` is rebound every step.
    permit_ids = list(observation.permits)
    lifecycle = ("submit", "pay", "inspect")
    for permit_id in permit_ids:
        for verb in lifecycle:
            observation = _step(environment, verb, permit_id)
            assert observation.last_action_error is None, (
                f"{verb}({permit_id}) failed: {observation.last_action_error}"
            )

    assert observation.done is True
    assert observation.reward >= 0.9
    assert observation.wasted_submissions == 0
def test_optimal_medium_cafe():
    """Run the medium_cafe permits through submit/pay/inspect in dependency order.

    Every action must succeed, the episode must finish with reward >= 0.9,
    and no submission may be wasted.
    """
    environment = PermitEnvironment()
    observation = environment.reset(seed=100, task_name="medium_cafe")

    # Prerequisites come before the permits that depend on them.
    topo_order = (
        "business_license",
        "zoning_approval",
        "signage_permit",
        "health_permit",
        "fire_inspection",
        "food_service_license",
    )
    lifecycle = ("submit", "pay", "inspect")
    for permit_id in topo_order:
        for verb in lifecycle:
            observation = _step(environment, verb, permit_id)
            assert observation.last_action_error is None, (
                f"{verb}({permit_id}) failed: {observation.last_action_error}"
            )

    assert observation.done is True
    assert observation.reward >= 0.9
    assert observation.wasted_submissions == 0
def test_optimal_hard_restaurant():
    """Walk the optimal policy for hard_restaurant, tolerating mid-episode events.

    Events on this task (missing-document, regulation) may revert or block
    permits, so one straight pass may leave some permits un-issued. The test:

    1. Makes one pass in dependency order, retrying each failed action once.
    2. Makes up to three clean-up passes over any permit not yet at
       ``STAGE_ISSUED``.

    Stepping stops as soon as the episode reports ``done``, so no actions are
    sent to an already-finished episode.
    """
    env = PermitEnvironment()
    obs = env.reset(seed=999, task_name="hard_restaurant")
    # Prerequisites come before the permits that depend on them.
    order = [
        "business_license",
        "zoning_variance",
        "building_permit",
        "liquor_license",
        "plumbing_permit",
        "electrical_permit",
        "hvac_permit",
        "health_permit",
        "fire_certificate",
        "food_service_license",
    ]
    lifecycle = ("submit", "pay", "inspect")

    def run_lifecycle(current, pid):
        """Apply submit/pay/inspect to *pid*, retrying a failed action once.

        Returns the latest observation and stops early if the episode ends.
        """
        for action in lifecycle:
            current = _step(env, action, pid)
            # A missing-doc event may knock the permit back; this is not a
            # test failure, so give the action one more try.
            if current.last_action_error:
                current = _step(env, action, pid)
            if current.done:
                break
        return current

    # Pass 1: straight walk in dependency order.
    for pid in order:
        obs = run_lifecycle(obs, pid)
        if obs.done:
            break

    # Clean-up passes: re-process any permit an event left un-issued.
    for _pass in range(3):
        if obs.done:
            break
        for pid in order:
            if obs.done:
                break
            if obs.permits.get(pid, {}).get("stage") != STAGE_ISSUED:
                obs = run_lifecycle(obs, pid)

    assert obs.done is True
    assert obs.reward >= 0.70
# ---------- Illegal actions ----------
def test_submit_locked_permit_is_wasted():
    """Submitting a permit whose prerequisites are unmet must fail and count as waste."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="medium_cafe")

    # food_service_license needs health_permit + fire_inspection, neither of
    # which has been processed yet, so it is still locked.
    result = _step(environment, "submit", "food_service_license")

    assert result.last_action_error is not None
    assert result.wasted_submissions == 1
def test_pay_before_submit_is_wasted():
    """Paying for a permit that was never submitted must fail and count as waste."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="easy_foodtruck")

    result = _step(environment, "pay", "business_license")

    assert result.last_action_error is not None
    assert result.wasted_submissions == 1
def test_inspect_before_pay_is_wasted():
    """Requesting inspection of an unpaid permit must fail and count as waste."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="easy_foodtruck")

    result = _step(environment, "inspect", "business_license")

    assert result.last_action_error is not None
    assert result.wasted_submissions == 1
def test_unknown_permit_is_wasted():
    """Targeting a permit id that does not exist must be rejected as waste."""
    environment = PermitEnvironment()
    environment.reset(seed=1, task_name="easy_foodtruck")

    result = _step(environment, "submit", "nonexistent_permit_99")
    error = result.last_action_error

    assert error is not None
    assert "Unknown permit" in error
    assert result.wasted_submissions == 1
# ---------- Waste penalty ----------
def test_waste_penalty_reduces_reward():
    """Repeated illegal actions must lower the reward through the waste penalty."""
    environment = PermitEnvironment()
    observation = environment.reset(seed=1, task_name="easy_foodtruck")
    # Capture the baseline before stepping, in case the observation is reused.
    baseline_reward = observation.reward

    illegal_attempts = 5
    for _attempt in range(illegal_attempts):
        observation = _step(environment, "submit", "nonexistent_permit")

    assert observation.reward < baseline_reward
    assert observation.wasted_submissions == 5
# ---------- List and query are safe ----------
def test_list_does_not_advance_state():
    """A list action must leave every permit's stage untouched and not be wasted."""

    def stage_map(observation):
        # Map each permit id to its current stage string.
        return {pid: info["stage"] for pid, info in observation.permits.items()}

    environment = PermitEnvironment()
    # Take the snapshot before stepping so later mutation cannot affect it.
    stages_before = stage_map(environment.reset(seed=1, task_name="easy_foodtruck"))
    after = _step(environment, "list")
    stages_after = stage_map(after)

    assert stages_before == stages_after
    assert after.wasted_submissions == 0
    assert after.last_action_error is None
def test_query_returns_info():
    """Querying a permit must succeed and mention that permit in the message."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=1, task_name="easy_foodtruck")
    target_pid = list(initial.permits)[0]

    response = _step(environment, "query", target_pid)

    assert response.last_action_error is None
    assert target_pid in response.message
# ---------- Empty reset (validator path) ----------
def test_empty_reset():
    """A bare reset() (the validator's POST /reset with {}) must yield a usable episode."""
    environment = PermitEnvironment()
    initial = environment.reset()

    assert initial.task_name in TASKS
    assert initial.budget_remaining > 0
    assert len(initial.permits) > 0
def test_reset_with_kwargs():
    """reset() must honour explicit seed and task_name keyword arguments."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=42, task_name="hard_restaurant")
    expected_permit_count = 10

    assert initial.task_name == "hard_restaurant"
    assert len(initial.permits) == expected_permit_count
# ---------- Episode termination ----------
def test_max_steps_terminates():
    """Issuing harmless list actions repeatedly must eventually end the episode."""
    environment = PermitEnvironment()
    observation = environment.reset(seed=1, task_name="easy_foodtruck")

    # Take up to 25 list() steps, stopping once the episode reports done.
    steps_taken = 0
    while True:
        observation = _step(environment, "list")
        steps_taken += 1
        if observation.done or steps_taken == 25:
            break

    assert observation.done is True
# ---------- Regulation event ----------
def test_regulation_event_fires():
    """A regulatory-update event must fire on hard_restaurant within 30 steps.

    Each step advances the first actionable permit (in observation order) by
    one stage. If no permit is actionable, the step is spent on list(). This
    builds up issued permits and reaches the event window (step >= 15).
    """
    environment = PermitEnvironment()
    observation = environment.reset(seed=42, task_name="hard_restaurant")

    # Which action moves a permit out of each actionable stage.
    next_action = {"available": "submit", "approved": "pay", "paid": "inspect"}
    transcript = []
    for _ in range(30):
        candidate = next(
            (
                (pid, next_action[info["stage"]])
                for pid, info in observation.permits.items()
                if info["stage"] in next_action
            ),
            None,
        )
        if candidate is None:
            # Nothing actionable: burn a step with list().
            observation = _step(environment, "list")
        else:
            pid, verb = candidate
            observation = _step(environment, verb, pid)
        transcript.append(observation.message)
        if observation.done:
            break

    assert "[EVENT] Regulatory update" in " ".join(transcript), (
        "Expected a regulation event to fire during hard_restaurant steps 15+"
    )
# ---------- Partial observability ----------
def test_hidden_prereqs_medium():
    """medium_cafe must start with exactly two permits whose prereqs are masked as ['???']."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=42, task_name="medium_cafe")

    masked = [info for info in initial.permits.values() if info["prereqs"] == ["???"]]
    hidden_count = len(masked)

    assert hidden_count == 2, f"Expected 2 hidden prereqs, got {hidden_count}"
def test_query_reveals_prereqs():
    """Querying a permit with masked prereqs must unmask them in later observations."""
    environment = PermitEnvironment()
    initial = environment.reset(seed=42, task_name="hard_restaurant")

    # First permit (in observation order) whose prereqs are hidden, if any.
    hidden_pid = next(
        (pid for pid, info in initial.permits.items() if info["prereqs"] == ["???"]),
        None,
    )
    assert hidden_pid is not None, "Expected at least one hidden permit in hard task"

    response = _step(environment, "query", hidden_pid)
    assert "REVEALED" in response.message

    revealed = response.permits[hidden_pid]["prereqs"]
    assert revealed != ["???"]
    assert len(revealed) > 0
def test_inquiry_budget_deducts():
    """Once the free queries are used up, each further query must cost budget.

    Per the test's expectation, hard_restaurant grants 3 free queries and
    charges $50 per query after that.
    """
    environment = PermitEnvironment()
    observation = environment.reset(seed=42, task_name="hard_restaurant")
    # Capture the starting budget before any step can change it.
    starting_budget = observation.budget_remaining
    free_queries = 3
    query_cost = 50
    target_pid = list(observation.permits)[0]

    for _ in range(free_queries):
        observation = _step(environment, "query", target_pid)
        assert observation.budget_remaining == starting_budget, "First 3 queries should be free"

    # The first paid query.
    observation = _step(environment, "query", target_pid)
    assert observation.budget_remaining < starting_budget, "4th query should deduct from budget"
    assert abs(observation.budget_remaining - (starting_budget - query_cost)) < 0.01