Spaces:
Sleeping
Sleeping
File size: 2,623 Bytes
"""Tests for episode endpoints."""
import pytest
from fastapi.testclient import TestClient
def test_reset_episode(client: TestClient, sample_task: dict) -> None:
    """Check that the reset endpoint answers 201 with the expected keys.

    Posts the fixture task's id to ``/api/episode/reset`` and verifies the
    JSON body exposes both ``episode_id`` and ``observation``.
    """
    payload = {"task_id": sample_task["task_id"]}
    resp = client.post("/api/episode/reset", json=payload)
    assert resp.status_code == 201

    body = resp.json()
    # Same order as the individual checks: episode_id first, then observation.
    for expected_key in ("episode_id", "observation"):
        assert expected_key in body
def test_step_episode(client: TestClient, sample_task: dict, sample_action: dict) -> None:
    """Check that an episode obtained from reset accepts a step request.

    The step response must be 200 and its JSON body must contain
    ``observation`` and ``reward``.
    """
    # A step needs an episode id, so obtain one through the reset endpoint.
    created = client.post(
        "/api/episode/reset",
        json={"task_id": sample_task["task_id"]},
    )
    assert created.status_code == 201
    episode_id = created.json()["episode_id"]

    # Submit the fixture action against that episode.
    stepped = client.post(
        "/api/episode/step",
        json={"episode_id": episode_id, "action": sample_action},
    )
    assert stepped.status_code == 200

    body = stepped.json()
    assert "observation" in body
    assert "reward" in body
def test_get_state(client: TestClient, sample_task: dict) -> None:
    """Test getting episode state.

    Resets an episode for the fixture task, then fetches
    ``/api/episode/state/{episode_id}`` and checks that the returned state
    reports the same ``episode_id``.
    """
    # First reset
    reset_request = {"task_id": sample_task["task_id"]}
    reset_response = client.post("/api/episode/reset", json=reset_request)
    # Fail here with a clear status mismatch instead of a KeyError on
    # "episode_id" below if the reset itself did not succeed; this mirrors
    # the precondition check the other reset-based tests in this file make.
    assert reset_response.status_code == 201
    episode_id = reset_response.json()["episode_id"]

    # Get state
    response = client.get(f"/api/episode/state/{episode_id}")
    assert response.status_code == 200
    data = response.json()
    assert data["episode_id"] == episode_id
def test_openenv_reset_alias(client: TestClient, sample_task: dict) -> None:
    """Check the root-level ``/reset`` alias used for OpenEnv compatibility.

    The request carries the task id under the ``task`` key; the response must
    be 200, include an ``episode_id``, and echo the task id as ``task_id``.
    """
    expected_task_id = sample_task["task_id"]

    resp = client.post("/reset", json={"task": expected_task_id})
    assert resp.status_code == 200

    body = resp.json()
    assert "episode_id" in body
    assert body["task_id"] == expected_task_id
def test_openenv_step_alias_with_string_action(client: TestClient, sample_task: dict) -> None:
    """Check the root-level ``/step`` alias with a plain-string action.

    An episode is obtained via ``/reset``, then ``/step`` is called with the
    action given as the string ``"done"``. The step must answer 200 and its
    JSON body must contain a ``done`` key.
    """
    # Obtain an episode id through the root-level reset alias.
    created = client.post("/reset", json={"task_id": sample_task["task_id"]})
    assert created.status_code == 200
    episode_id = created.json()["episode_id"]

    # The action is sent as a bare string rather than a structured object.
    step_payload = {"episode_id": episode_id, "action": "done"}
    stepped = client.post("/step", json=step_payload)
    assert stepped.status_code == 200

    body = stepped.json()
    assert "done" in body
|