# clarify-rl/tests/test_environment.py
# Anurag Agarwal
# ClarifyRL: initial HF Space deploy
# 2414d31
"""Integration tests for ClarifyEnvironment — full episode flows.

Requires the full openenv stack to be installed. Run with:

    python -m pytest tests/test_environment.py -v

Skip with:

    python -m pytest -m 'not integration'
"""
from __future__ import annotations
import json
import pytest
# Module-level mark: applies the ``integration`` marker to every test in this
# file so the suite can be deselected with ``-m 'not integration'``.
# NOTE(review): the openenv/server imports below still execute when pytest
# collects this module, so deselecting by marker presumably does not avoid an
# ImportError when the stack is missing — consider ``pytest.importorskip``.
pytestmark = pytest.mark.integration
from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction
from server.clarify_environment import ClarifyEnvironment
from server.scenarios import REQUIRED_KEYS_BY_FAMILY
@pytest.fixture
def env():
    """Provide a ClarifyEnvironment reset with seed=7 on the "medium" task."""
    environment = ClarifyEnvironment()
    environment.reset(seed=7, task_id="medium")
    return environment
def _call(env, tool, **kwargs):
    """Step ``env`` with a tool call and return the resulting observation.

    Args:
        env: Environment object exposing ``step``.
        tool: Name of the tool to invoke.
        **kwargs: Forwarded verbatim as the tool's ``arguments`` mapping.
    """
    return env.step(CallToolAction(tool_name=tool, arguments=kwargs))
class TestReset:
    """Checks on the observation and state produced by ``reset``."""

    def test_returns_observation(self):
        """The reset observation is a non-terminal, zero-reward task payload."""
        fresh = ClarifyEnvironment()
        first_obs = fresh.reset(seed=42, task_id="easy")

        assert first_obs.done is False
        assert first_obs.reward == 0.0
        assert first_obs.result["type"] == "task"

    def test_state_after_reset(self):
        """State starts at step 0 with 6 questions and nothing submitted."""
        fresh = ClarifyEnvironment()
        fresh.reset(seed=42, task_id="medium")
        state = fresh.state

        assert state.step_count == 0
        assert state.questions_remaining == 6
        assert state.episode_done is False
        assert state.plan_submitted is False

    @pytest.mark.parametrize("diff", ["easy", "medium", "hard"])
    def test_all_difficulties(self, diff):
        """Each difficulty is echoed back as the observation's task_id."""
        fresh = ClarifyEnvironment()
        first_obs = fresh.reset(seed=1, task_id=diff)

        assert first_obs.result["task_id"] == diff
class TestListTools:
    """Checks on the tool listing exposed by the environment."""

    def test_returns_three_tools(self, env):
        """Listing tools yields exactly the three expected tool names."""
        listing = env.step(ListToolsAction())
        advertised = {tool.name for tool in listing.tools}

        assert advertised == {"get_task_info", "ask_question", "propose_plan"}
class TestGetTaskInfo:
    """Checks on the ``get_task_info`` tool."""

    def test_free_action(self, env):
        """The call carries no reward, does not end the episode, and returns task data."""
        info_obs = _call(env, "get_task_info")

        assert info_obs.reward == 0.0
        assert info_obs.done is False

        payload = info_obs.result.data
        assert "request" in payload
        assert "family" in payload

    def test_step_count_increments(self, env):
        """A single call advances the step counter to 1."""
        _call(env, "get_task_info")

        assert env.state.step_count == 1
class TestAskQuestion:
    """Checks on the ``ask_question`` tool: rewards, budget, and caps."""

    def test_reveals_field(self, env):
        """A question about the order id reveals a field and earns positive reward."""
        answer = _call(env, "ask_question", question="what is the order id?")
        payload = answer.result.data

        assert payload["field_revealed"] is not None
        assert answer.reward > 0

    def test_budget_decrements(self, env):
        """One question lowers the remaining budget from 6 to 5."""
        _call(env, "ask_question", question="what is the order id?")

        assert env.state.questions_remaining == 5

    def test_duplicate_penalty(self, env):
        """Re-asking about the same field is flagged as duplicate and penalised."""
        _call(env, "ask_question", question="what is the order id?")
        repeat = _call(env, "ask_question", question="tell me order id again?")

        assert repeat.reward < 0
        assert repeat.result.data["duplicate"] is True

    def test_unknown_question_small_reward(self, env):
        """An off-topic question reveals nothing but still has positive reward."""
        answer = _call(env, "ask_question", question="tell me about your cat")

        assert answer.reward > 0
        assert answer.result.data["field_revealed"] is None

    def test_over_cap(self, env):
        """The seventh question exceeds the cap and ends the episode."""
        # Use up the six-question budget first.
        for index in range(6):
            _call(env, "ask_question", question=f"question {index}")

        overflow = _call(env, "ask_question", question="one more")

        assert overflow.result.data["over_cap"] is True
        assert overflow.done is True

    def test_truncates_long_question(self, env):
        """A 500-character question is accepted without ending the episode."""
        oversized = "x" * 500
        answer = _call(env, "ask_question", question=oversized)

        assert answer.done is False
class TestProposePlan:
    """Checks on the ``propose_plan`` tool: termination and scoring."""

    # Complete plan payload shared by the tests that submit a valid plan.
    _FULL_PLAN = json.dumps(
        {"order_id": "#1", "item_issue": "late", "refund_or_replace": "refund"}
    )

    def test_terminates_episode(self, env):
        """Submitting a plan ends the episode and records a final score."""
        outcome = _call(env, "propose_plan", plan=self._FULL_PLAN)

        assert outcome.done is True
        assert env.state.episode_done is True
        assert env.state.plan_submitted is True
        assert env.state.final_score is not None

    def test_bad_json_zero_score(self, env):
        """A plan that is not JSON ends the episode with a score of 0.0."""
        outcome = _call(env, "propose_plan", plan="not json")

        assert outcome.done is True
        assert env.state.final_score == 0.0

    def test_missing_keys_zero_score(self, env):
        """A plan with only ``order_id`` ends the episode with a score of 0.0."""
        outcome = _call(env, "propose_plan", plan='{"order_id": "#1"}')

        assert outcome.done is True
        assert env.state.final_score == 0.0

    def test_breakdown_populated(self, env):
        """After a plan is submitted the breakdown has a FormatCheckRubric entry."""
        _call(env, "propose_plan", plan=self._FULL_PLAN)
        breakdown = env.state.score_breakdown

        assert "FormatCheckRubric" in breakdown
class TestEpisodeDoneGuard:
    """Checks that tool calls made after a plan submission stay terminal."""

    def _submit_plan(self, env):
        """Submit a complete plan through ``propose_plan`` and return its JSON."""
        payload = json.dumps(
            {"order_id": "#1", "item_issue": "late", "refund_or_replace": "refund"}
        )
        _call(env, "propose_plan", plan=payload)
        return payload

    def test_no_ask_after_plan(self, env):
        """Asking a question after the plan reports the episode-ended error."""
        self._submit_plan(env)
        late = _call(env, "ask_question", question="more info?")

        assert late.done is True
        assert late.result.data.get("error") == "episode already ended"

    def test_no_plan_after_plan(self, env):
        """Submitting the same plan a second time reports the episode-ended error."""
        payload = self._submit_plan(env)
        late = _call(env, "propose_plan", plan=payload)

        assert late.done is True
        assert late.result.data.get("error") == "episode already ended"

    def test_no_info_after_plan(self, env):
        """Requesting task info after the plan still yields a done observation."""
        self._submit_plan(env)
        late = _call(env, "get_task_info")

        assert late.done is True
class TestMaxSteps:
    """Checks that an episode cannot run past its step limit."""

    def test_max_steps_enforced(self):
        """Repeated info calls end the episode within 10 calls, at step 8 or earlier."""
        episode = ClarifyEnvironment()
        episode.reset(seed=7, task_id="easy")

        for _ in range(10):
            if _call(episode, "get_task_info").done:
                break
        else:
            # The loop ran out without the environment ever reporting done.
            pytest.fail("max_steps not enforced within 10 calls")

        assert episode.state.step_count <= 8
class TestStepCount:
    """Checks on the step counter kept in the environment state."""

    def test_increments_each_call(self, env):
        """The counter reads 1, 2, 3 after three consecutive tool calls."""
        for expected_steps in (1, 2, 3):
            _call(env, "get_task_info")
            assert env.state.step_count == expected_steps
class TestOraclePolicy:
    """Checks that a policy with access to the hidden profile scores well."""

    # One test case per seed, so a failing seed is reported on its own
    # instead of aborting the remaining seeds.
    @pytest.mark.parametrize("seed", range(10))
    def test_oracle_scores_high(self, seed):
        """Ask about every critical field, submit the hidden profile, expect > 0.5.

        Args:
            seed: Scenario seed passed to ``reset`` on the "medium" task.
        """
        e = ClarifyEnvironment()
        obs = e.reset(seed=seed, task_id="medium")
        family = obs.result["family"]

        # The scenario family must be one the required-key table knows about.
        assert family in REQUIRED_KEYS_BY_FAMILY, f"seed={seed} unknown family={family}"

        # The oracle reads the environment's private scenario to get the answers.
        profile = e._scenario["hidden_profile"]
        critical = e._scenario["critical_fields"]

        # Ask one question per critical field, using its name with spaces.
        for cf in critical:
            kw = cf.replace("_", " ")
            _call(e, "ask_question", question=f"what is the {kw}?")

        # Submit the hidden profile itself as the plan.
        obs = _call(e, "propose_plan", plan=json.dumps(profile))
        assert obs.done is True

        score = e.state.final_score
        assert score is not None
        assert score > 0.5, f"seed={seed} family={family} oracle score={score} too low"