File size: 2,320 Bytes
e75c8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Phase 1 gates: OpenEnv HTTP, three tasks, graders in [0,1], reproducible seed."""

import pytest
from fastapi.testclient import TestClient

from env.grader import clamp_unit_interval, evaluate_episode
from env.task_graders import TASK_AGENT_GRADERS
from server.app import app


@pytest.fixture
def client():
    """Provide a synchronous test HTTP client wrapping the server's ``app``."""
    test_client = TestClient(app)
    return test_client


def test_tasks_endpoint_three_graders(client):
    """``GET /tasks`` lists >= 3 tasks, >= 3 with a grader, and a registry of >= 3."""
    response = client.get("/tasks")
    assert response.status_code == 200
    payload = response.json()
    tasks = payload["tasks"]
    assert len(tasks) >= 3
    # Only tasks whose "grader" field is truthy count as graded.
    graded = [task for task in tasks if task.get("grader")]
    assert len(graded) >= 3
    assert len(payload["grader_registry"]) >= 3


def test_each_task_grader_returns_unit_interval():
    """Every registered task grader scores a small history within [0, 1].

    The history holds one ``keep`` on a non-stale item and one
    ``invalidate`` on a stale item.
    """
    history = [
        {"action": "keep", "is_stale": False},
        {"action": "invalidate", "is_stale": True},
    ]
    # Guard against a vacuous pass: an empty registry would skip the loop.
    assert TASK_AGENT_GRADERS, "no task graders registered"
    for name, fn in TASK_AGENT_GRADERS.items():
        score = fn(history)
        # NaN fails this chained comparison, so it is rejected too.
        assert 0.0 <= score <= 1.0, (name, score)


def test_reset_step_openenv_shape(client):
    """``/reset`` followed by one ``/step`` returns OpenEnv-shaped JSON."""
    reset_resp = client.post("/reset", json={"seed": 123, "task_id": "medium"})
    assert reset_resp.status_code == 200
    reset_body = reset_resp.json()
    # Subset check: extra top-level keys in the response are allowed.
    assert {"observation", "reward", "done"} <= set(reset_body.keys())
    observation = reset_body["observation"]
    assert observation["task_id"] == "medium"
    first_key = observation["items"][0]["key"]
    step_resp = client.post(
        "/step", json={"action": {"type": "keep", "key": first_key}}
    )
    assert step_resp.status_code == 200
    assert "observation" in step_resp.json()


def test_reproducible_reset_seed(client):
    """Resetting twice with the same seed and task yields identical items."""
    payload = {"seed": 999, "task_id": "easy"}

    def _reset_observation():
        resp = client.post("/reset", json=payload)
        # Surface HTTP failures directly instead of as a KeyError on the body.
        assert resp.status_code == 200, resp.text
        return resp.json()["observation"]

    a = _reset_observation()
    b = _reset_observation()
    # An empty item list would make the equality check pass vacuously.
    assert a["items"], "reset returned no items"
    assert a["items"] == b["items"]


def test_final_score_in_range(client):
    """Keeping the first item each step eventually yields a final score in [0, 1].

    Starts an ``easy`` episode with seed 0. It then repeatedly posts a
    ``keep`` action on the first listed item until the observation carries a
    non-``None`` ``final_score``. It gives up after ``max_steps`` steps.
    """
    # Assumes an easy episode finishes within this many steps.
    max_steps = 12
    r = client.post("/reset", json={"seed": 0, "task_id": "easy"})
    assert r.status_code == 200, r.text
    obs = r.json()["observation"]
    final = None
    for _ in range(max_steps):
        k = obs["items"][0]["key"]
        s = client.post("/step", json={"action": {"type": "keep", "key": k}})
        # Fail with the server's response text rather than a KeyError below.
        assert s.status_code == 200, s.text
        obs = s.json()["observation"]
        if obs.get("final_score") is not None:
            final = obs["final_score"]
            break
    assert final is not None, f"no final_score within {max_steps} steps"
    assert 0.0 <= final <= 1.0, final


def test_clamp_unit_interval():
    """Out-of-range inputs clamp to the nearest bound; an empty episode scores 0."""
    cases = ((-1, 0.0), (2, 1.0))
    for raw, expected in cases:
        assert clamp_unit_interval(raw) == expected
    assert evaluate_episode([]) == 0.0