Buckets:
| """Tests for the CropRL REST API (extras/api).""" | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from extras.api.app import app | |
@pytest.fixture(autouse=True)
def _reset_global_state():
    """Reset the API module's global environment state around every test.

    Applied automatically to every test (``autouse=True``). Each test
    therefore starts with no active environment (``_env``), no cached
    observation (``_last_obs``) and a zero step counter (``_episode_step``).
    The same state is cleared again on teardown, after the ``yield``, so an
    episode started by one test cannot leak into tests that run later.
    """
    import extras.api.app as api_mod

    def _clear():
        # Mirror the "no episode started" state the endpoints check for.
        api_mod._env = None
        api_mod._last_obs = None
        api_mod._episode_step = 0

    _clear()
    yield
    _clear()
@pytest.fixture
def client():
    """Return a FastAPI ``TestClient`` bound to the CropRL API ``app``."""
    return TestClient(app)
| # ── Health & Info ────────────────────────────────────────────── | |
class TestInfo:
    """Discovery endpoints: ``/``, ``/tasks`` and ``/actions``."""

    def test_root(self, client):
        """The root endpoint answers 200 with ``status == "ok"``."""
        response = client.get("/")
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ok"

    def test_list_tasks(self, client):
        """The easy, medium and hard tasks are all advertised."""
        response = client.get("/tasks")
        assert response.status_code == 200
        advertised = [task["task_id"] for task in response.json()]
        for expected in ("easy", "medium", "hard"):
            assert expected in advertised

    def test_list_actions(self, client):
        """The action catalogue has 11 entries, the first being id 0 "Wait"."""
        response = client.get("/actions")
        assert response.status_code == 200
        catalogue = response.json()["actions"]
        assert len(catalogue) == 11
        first = catalogue[0]
        assert first["id"] == 0
        assert first["name"] == "Wait"
| # ── Reset ────────────────────────────────────────────────────── | |
class TestReset:
    """``POST /reset``: task selection, starting balances and text mode."""

    def test_reset_default(self, client):
        """With no body, the ``medium`` task starts at month 1 with 10000.0 cash."""
        r = client.post("/reset")
        assert r.status_code == 200
        data = r.json()
        assert data["task_id"] == "medium"
        assert "observation" in data
        obs = data["observation"]
        assert obs["current_month"] == 1
        assert obs["cash_balance"] == 10000.0
        assert obs["done"] is False

    def test_reset_easy(self, client):
        """The ``easy`` task (seeded) starts with 15000.0 cash."""
        r = client.post("/reset", json={"task": "easy", "seed": 1})
        assert r.status_code == 200
        obs = r.json()["observation"]
        assert obs["cash_balance"] == 15000.0

    def test_reset_hard(self, client):
        """The ``hard`` task starts with 7000.0 cash."""
        r = client.post("/reset", json={"task": "hard"})
        assert r.status_code == 200
        obs = r.json()["observation"]
        assert obs["cash_balance"] == 7000.0

    def test_reset_unknown_task(self, client):
        """An unrecognised task id is rejected with 400."""
        r = client.post("/reset", json={"task": "impossible"})
        assert r.status_code == 400

    def test_reset_with_text_mode(self, client):
        """``text_mode=True`` populates a textual farm dashboard summary."""
        r = client.post("/reset", json={"task": "easy", "text_mode": True})
        # Check the status first so a failed reset is reported as such,
        # not as a KeyError on the missing "observation" field.
        assert r.status_code == 200
        obs = r.json()["observation"]
        assert obs["text_summary"] is not None
        assert "Farm Dashboard" in obs["text_summary"]

    def test_reset_without_text_mode(self, client):
        """``text_mode=False`` leaves ``text_summary`` as null."""
        r = client.post("/reset", json={"task": "easy", "text_mode": False})
        assert r.status_code == 200
        obs = r.json()["observation"]
        assert obs["text_summary"] is None
| # ── State ────────────────────────────────────────────────────── | |
class TestState:
    """``GET /state`` before and after an episode has been started."""

    def test_state_before_reset(self, client):
        """Querying state with no active episode is a client error (400)."""
        assert client.get("/state").status_code == 400

    def test_state_after_reset(self, client):
        """A fresh episode reports step 0, not done, and carries an observation."""
        client.post("/reset")
        response = client.get("/state")
        assert response.status_code == 200
        snapshot = response.json()
        assert snapshot["step"] == 0
        assert snapshot["done"] is False
        assert "observation" in snapshot
| # ── Step ─────────────────────────────────────────────────────── | |
class TestStep:
    """``POST /step``: action handling, validation and episode termination."""

    def test_step_before_reset(self, client):
        """Stepping with no active episode is rejected with 400."""
        r = client.post("/step", json={"action_id": 0})
        assert r.status_code == 400

    def test_step_wait(self, client):
        """Action 0 ("Wait") advances the step counter without ending the episode."""
        client.post("/reset")
        r = client.post("/step", json={"action_id": 0})
        assert r.status_code == 200
        data = r.json()
        assert data["step"] == 1
        assert data["action_id"] == 0
        assert data["action_name"] == "Wait"
        assert "reward" in data
        assert data["done"] is False

    def test_step_plant_chickpea(self, client):
        """Action 3 plants chickpea, which is reflected in ``active_crop_type``."""
        client.post("/reset")
        r = client.post("/step", json={"action_id": 3})
        assert r.status_code == 200
        data = r.json()
        assert data["action_name"] == "Plant Chickpea (Legume)"
        assert data["observation"]["active_crop_type"] == 3
        assert "planted" in data["message"].lower()

    def test_step_invalid_action_id(self, client):
        """An out-of-range action id is rejected (400 or validation 422)."""
        client.post("/reset")
        r = client.post("/step", json={"action_id": 99})
        assert r.status_code in (400, 422)

    def test_step_after_done(self, client):
        """Stepping after episode is done should return 400."""
        client.post("/reset", json={"task": "easy"})
        # Advance with "Wait" until the environment reports done. The 70-step
        # cap is only a safety bound; if ``done`` is never reached, say so
        # explicitly instead of failing confusingly on the final 400 check.
        reached_done = False
        for _ in range(70):
            r = client.post("/step", json={"action_id": 0})
            assert r.status_code == 200
            if r.json()["done"]:
                reached_done = True
                break
        assert reached_done, "episode did not finish within 70 steps"
        # One more step on a finished episode must be refused.
        r = client.post("/step", json={"action_id": 0})
        assert r.status_code == 400

    def test_observation_structure(self, client):
        """Verify the observation dict has all expected keys."""
        client.post("/reset")
        r = client.post("/step", json={"action_id": 0})
        assert r.status_code == 200
        obs = r.json()["observation"]
        required_keys = [
            "current_month", "current_step", "expected_rainfall",
            "active_crop_type", "crop_age_months", "soil_nitrogen",
            "cash_balance", "current_debt", "market_prices",
            "costs", "storage", "done",
        ]
        for key in required_keys:
            assert key in obs, f"Missing key: {key}"
| # ── Full Episode ─────────────────────────────────────────────── | |
class TestFullEpisode:
    """Multi-request flows: plant/harvest cycles, state sync and re-resets."""

    def test_plant_harvest_cycle(self, client):
        """Run a full plant → wait → harvest cycle via API."""
        r = client.post("/reset", json={"task": "easy"})
        assert r.status_code == 200
        # Plant chickpea
        r = client.post("/step", json={"action_id": 3})
        assert r.status_code == 200
        assert r.json()["observation"]["active_crop_type"] == 3
        # Wait
        r = client.post("/step", json={"action_id": 0})
        assert r.status_code == 200
        assert r.json()["observation"]["crop_age_months"] >= 1
        # Harvest & sell
        r = client.post("/step", json={"action_id": 7})
        assert r.status_code == 200
        data = r.json()
        assert data["observation"]["active_crop_type"] == 0
        assert data["reward"] > 0

    def test_state_updates_after_step(self, client):
        """GET /state should reflect changes from POST /step."""
        assert client.post("/reset").status_code == 200
        # Take action
        assert client.post("/step", json={"action_id": 3}).status_code == 200
        # Verify state matches
        r = client.get("/state")
        assert r.status_code == 200
        data = r.json()
        assert data["step"] == 1
        assert data["observation"]["active_crop_type"] == 3

    def test_multiple_resets(self, client):
        """Resetting should fully clear previous episode."""
        client.post("/reset", json={"task": "easy"})
        client.post("/step", json={"action_id": 3})
        client.post("/step", json={"action_id": 0})
        # Reset again
        r = client.post("/reset", json={"task": "hard"})
        assert r.status_code == 200
        obs = r.json()["observation"]
        assert obs["cash_balance"] == 7000.0
        assert obs["active_crop_type"] == 0
        assert obs["current_step"] == 0
Xet Storage Details
- Size: 7.6 kB
- Xet hash: c4e812d3fefc5ca7805d15cc3788bb9d2e727329db2f6e41780018b5d3da39c3
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.