| from __future__ import annotations |
|
|
| import unittest |
|
|
| from autofarm.contracts import ( |
| ActionType, |
| EpisodeOutcomeStatus, |
| ) |
| from autofarm.evaluation.baselines import HybridStrategy |
| from autofarm.sim.engine import build_environment, build_interactive_environment |
|
|
| try: |
| from autofarm.adapters.datasets import build_dataset_catalog |
| RIDDLE_WORKFLOW_IMPORT_ERROR: ModuleNotFoundError | None = None |
| except ModuleNotFoundError as exc: |
| RIDDLE_WORKFLOW_IMPORT_ERROR = exc |
|
|
|
|
def run_to_completion(env, *, max_steps: int = 40):
    """Drive ``env`` with a fresh ``HybridStrategy`` until it is done or the budget runs out.

    Args:
        env: An environment exposing ``done``, ``event_history``, ``step(strategy)``
            and ``snapshot()``.
        max_steps: Upper bound on both the number of ``env.step`` calls and the
            length of ``env.event_history``; whichever limit is hit first stops
            the run.

    Returns:
        Whatever ``env.snapshot()`` returns after the loop stops.
    """
    strategy = HybridStrategy()
    # Iterate a fixed number of times so the helper always terminates, even if
    # a step records no event. The event-history cap is kept as well, so a run
    # never goes longer than the previous event-bounded loop allowed.
    for _ in range(max_steps):
        if env.done or len(env.event_history) >= max_steps:
            break
        env.step(strategy)
    return env.snapshot()
|
|
|
|
@unittest.skipIf(
    RIDDLE_WORKFLOW_IMPORT_ERROR is not None,
    f"riddle workflow dependencies unavailable: {RIDDLE_WORKFLOW_IMPORT_ERROR}",
)
class RiddleWorkflowTest(unittest.TestCase):
    """End-to-end checks that riddle environments build correctly and can be solved.

    Skipped when ``build_dataset_catalog`` could not be imported at module load.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Build the dataset catalog once and share it across all tests.
        cls.catalog = build_dataset_catalog()

    def test_environment_shapes(self) -> None:
        """Check grid size and starting energy for riddle and interactive environments."""
        with self.subTest("riddle_environment_is_fixed_5x5"):
            # grid_size=3 is passed on purpose: the assertions below expect a
            # 25-zone (5x5) grid regardless of the requested size.
            env = build_environment(
                "dense_weed_remove",
                catalog=self.catalog,
                grid_size=3,
            )
            self.assertEqual(len(env.zone_states), 25)
            self.assertEqual(env.home_neighbor_zone_id, "zone_r03_c01")
            self.assertEqual(env.current_zone_id, env.home_zone_id)
            self.assertEqual(env.max_energy, 100.0)
            self.assertEqual(env.energy_remaining, 100.0)
            self.assertTrue(all(zone.preview_image_path for zone in env.zone_states))

        with self.subTest("interactive_environment_keeps_custom_grid_size"):
            env = build_interactive_environment(
                catalog=self.catalog,
                grid_size=6,
            )
            # Unlike riddle environments, the requested 6x6 size is honoured.
            self.assertEqual(len(env.zone_states), 36)
            self.assertEqual(env.max_energy, 100.0)
            self.assertEqual(env.energy_remaining, 100.0)

    def test_representative_riddles_complete_expected_actions(self) -> None:
        """Each representative riddle ends in SUCCESS and performs its key action."""
        cases = (
            ("dense_weed_remove", ActionType.REMOVE_WEEDS),
            ("soil_measure_medium_fertilizer", ActionType.APPLY_FERTILIZER),
            ("irrigate_dry_zone", ActionType.APPLY_WATER),
            ("patrol_recharge_resume", ActionType.PATROL),
        )
        for scenario_name, expected_action in cases:
            with self.subTest(scenario=scenario_name):
                env = build_environment(scenario_name, catalog=self.catalog)
                run_to_completion(env)
                zone = env._zone_state(env.riddle_spec.target_zone_id)
                actions = [event.action for event in env.event_history]

                # Include run state in the failure message so a step-budget
                # overrun is distinguishable from a genuinely wrong outcome.
                self.assertEqual(
                    env.outcome.status,
                    EpisodeOutcomeStatus.SUCCESS,
                    msg=f"done={env.done}, events={len(env.event_history)}",
                )
                self.assertIn(expected_action, actions)

                # Scenario-specific checks on the target zone's final state.
                if scenario_name == "dense_weed_remove":
                    self.assertTrue(zone.weeds_removed)
                elif scenario_name == "soil_measure_medium_fertilizer":
                    self.assertIn(ActionType.MEASURE_SOIL, actions)
                    self.assertIsNotNone(zone.soil_measurement)
                    # Applied amounts are floats; compare with a tolerance.
                    self.assertAlmostEqual(zone.applied_fertilizer_amount_kg_ha, 28.0)
                elif scenario_name == "irrigate_dry_zone":
                    self.assertAlmostEqual(zone.applied_water_amount_mm, 18.0)
                elif scenario_name == "patrol_recharge_resume":
                    # Patrolling must happen at least twice (per the scenario
                    # name, presumably before and after a recharge).
                    self.assertGreaterEqual(actions.count(ActionType.PATROL), 2)
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly with the stdlib unittest runner.
    unittest.main()
|
|