cache-env / tests /test_phase1.py
Parv Pareek
done
e75c8ce
"""Phase 1 gates: OpenEnv HTTP, three tasks, graders in [0,1], reproducible seed."""
import pytest
from fastapi.testclient import TestClient
from env.grader import clamp_unit_interval, evaluate_episode
from env.task_graders import TASK_AGENT_GRADERS
from server.app import app
@pytest.fixture
def client():
    """Yield a ``TestClient`` bound to the FastAPI ``app``.

    The client is entered as a context manager so the app's startup/shutdown
    (lifespan) handlers run around each test and the client is closed
    afterwards; a bare ``TestClient(app)`` skips both.
    """
    with TestClient(app) as test_client:
        yield test_client
def test_tasks_endpoint_three_graders(client):
    """``GET /tasks`` lists at least three tasks that have graders.

    Also requires the response's ``grader_registry`` to hold at least three
    entries.
    """
    resp = client.get("/tasks")
    assert resp.status_code == 200
    payload = resp.json()
    tasks = payload["tasks"]
    assert len(tasks) >= 3
    # A task counts as graded when its "grader" field is truthy.
    graded = sum(1 for task in tasks if task.get("grader"))
    assert graded >= 3
    assert len(payload["grader_registry"]) >= 3
def test_each_task_grader_returns_unit_interval():
    """Every grader in ``TASK_AGENT_GRADERS`` scores a small history in [0, 1]."""
    # Two-step history: a "keep" on a non-stale entry and an "invalidate"
    # on a stale entry.
    history = [
        {"action": "keep", "is_stale": False},
        {"action": "invalidate", "is_stale": True},
    ]
    # Guard against a vacuous pass: an empty registry would skip the loop.
    assert TASK_AGENT_GRADERS, "no task graders registered"
    for name, grader in TASK_AGENT_GRADERS.items():
        score = grader(history)
        assert 0.0 <= score <= 1.0, (name, score)
def test_reset_step_openenv_shape(client):
    """``/reset`` and ``/step`` return OpenEnv-shaped payloads for "medium".

    ``/reset`` must return at least ``observation``, ``reward`` and ``done``,
    and its observation must echo the requested ``task_id``. A ``keep`` step
    on the first item must succeed and return an ``observation``.
    """
    r = client.post("/reset", json={"seed": 123, "task_id": "medium"})
    assert r.status_code == 200
    body = r.json()
    # Superset check: extra top-level keys are allowed.
    assert set(body.keys()) >= {"observation", "reward", "done"}
    obs = body["observation"]
    assert obs["task_id"] == "medium"
    # Fail with a clear message instead of an IndexError if reset yields no items.
    assert obs["items"], "reset returned no items to act on"
    key = obs["items"][0]["key"]
    s = client.post("/step", json={"action": {"type": "keep", "key": key}})
    assert s.status_code == 200, s.text
    assert "observation" in s.json()
def test_reproducible_reset_seed(client):
    """Two resets with the same seed and task produce identical ``items``."""
    payload = {"seed": 999, "task_id": "easy"}
    # Issue the same reset twice, in order, and keep each observation.
    first, second = (
        client.post("/reset", json=payload).json()["observation"]
        for _ in range(2)
    )
    assert first["items"] == second["items"]
def test_final_score_in_range(client):
    """Drive an "easy" episode with ``keep`` actions; final score is in [0, 1].

    Steps until the observation carries a non-None ``final_score`` or the
    step response reports ``done``, for at most ``max_steps`` steps.
    """
    # Upper bound on steps; assumed to exceed the "easy" episode length — TODO confirm.
    max_steps = 12
    r = client.post("/reset", json={"seed": 0, "task_id": "easy"})
    assert r.status_code == 200
    obs = r.json()["observation"]
    final = None
    for step in range(max_steps):
        assert obs["items"], f"no items to act on at step {step}"
        key = obs["items"][0]["key"]
        s = client.post("/step", json={"action": {"type": "keep", "key": key}})
        assert s.status_code == 200, (step, s.text)
        body = s.json()
        obs = body["observation"]
        final = obs.get("final_score")
        # Stop once scored, or once the episode ends (stepping further is meaningless).
        if final is not None or body.get("done"):
            break
    assert final is not None, f"no final_score within {max_steps} steps"
    assert 0.0 <= final <= 1.0
def test_clamp_unit_interval():
    """``clamp_unit_interval`` maps out-of-range values to the nearest bound."""
    assert clamp_unit_interval(-1) == 0.0
    assert clamp_unit_interval(2) == 1.0


def test_evaluate_episode_empty_history():
    """``evaluate_episode`` scores an empty history as 0.0."""
    # Kept separate from the clamp test so a failure names the right function.
    assert evaluate_episode([]) == 0.0