autofarm / tests /test_riddle_workflows.py
isabelku's picture
AutoFarm Space deploy
826dd96
from __future__ import annotations
import unittest
from autofarm.contracts import (
ActionType,
EpisodeOutcomeStatus,
)
from autofarm.evaluation.baselines import HybridStrategy
from autofarm.sim.engine import build_environment, build_interactive_environment
try:
from autofarm.adapters.datasets import build_dataset_catalog
RIDDLE_WORKFLOW_IMPORT_ERROR: ModuleNotFoundError | None = None
except ModuleNotFoundError as exc: # pragma: no cover - local runtime may not have dataset deps
RIDDLE_WORKFLOW_IMPORT_ERROR = exc
def run_to_completion(env, *, max_steps: int = 40, strategy=None):
    """Step ``env`` with a strategy until it reports done, then snapshot it.

    Args:
        env: Environment exposing ``done``, ``event_history``,
            ``step(strategy)`` and ``snapshot()``.
        max_steps: Safety bound. Stepping stops once ``event_history`` holds
            this many events, and also after this many ``step`` calls. The
            second bound makes the loop terminate even if a step records no
            event.
        strategy: Strategy passed to every ``env.step`` call. Defaults to a
            fresh ``HybridStrategy()``.

    Returns:
        The value of ``env.snapshot()`` after the last step.
    """
    if strategy is None:
        strategy = HybridStrategy()
    # Bounding the iteration count (not just the history length) guarantees
    # termination. When every step records at least one event, this stops at
    # exactly the same point as a pure history-length bound.
    for _ in range(max_steps):
        if env.done or len(env.event_history) >= max_steps:
            break
        env.step(strategy)
    return env.snapshot()
@unittest.skipIf(
    RIDDLE_WORKFLOW_IMPORT_ERROR is not None,
    f"riddle workflow dependencies unavailable: {RIDDLE_WORKFLOW_IMPORT_ERROR}",
)
class RiddleWorkflowTest(unittest.TestCase):
    """End-to-end checks of riddle scenarios driven to completion.

    The whole class is skipped when the dataset adapter module could not be
    imported.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Build the dataset catalog once and share it across all tests.
        cls.catalog = build_dataset_catalog()

    def test_environment_shapes(self) -> None:
        """Riddle envs are a fixed 5x5 grid; interactive envs honour grid_size."""
        with self.subTest("riddle_environment_is_fixed_5x5"):
            # grid_size=3 is passed deliberately. The riddle environment is
            # expected to ignore it and still produce 25 zones.
            env = build_environment(
                "dense_weed_remove",
                catalog=self.catalog,
                grid_size=3,
            )
            self.assertEqual(len(env.zone_states), 25)
            self.assertEqual(env.home_neighbor_zone_id, "zone_r03_c01")
            self.assertEqual(env.current_zone_id, env.home_zone_id)
            self.assertEqual(env.max_energy, 100.0)
            self.assertEqual(env.energy_remaining, 100.0)
            # Report the offending zone indices rather than a bare False.
            missing_previews = [
                index
                for index, zone in enumerate(env.zone_states)
                if not zone.preview_image_path
            ]
            self.assertEqual(
                missing_previews,
                [],
                msg="zones without preview_image_path (by index)",
            )
        with self.subTest("interactive_environment_keeps_custom_grid_size"):
            env = build_interactive_environment(
                catalog=self.catalog,
                grid_size=6,
            )
            self.assertEqual(len(env.zone_states), 36)
            self.assertEqual(env.max_energy, 100.0)
            self.assertEqual(env.energy_remaining, 100.0)

    def test_representative_riddles_complete_expected_actions(self) -> None:
        """Each representative riddle succeeds and performs its key action."""
        # Each case holds: scenario name, an action that must appear in the
        # event history, and a scenario-specific check on (target zone,
        # actions). Keeping the check next to its scenario means a renamed or
        # misspelled scenario cannot silently skip its extra assertions.
        cases = (
            ("dense_weed_remove", ActionType.REMOVE_WEEDS, self._check_weeds_removed),
            (
                "soil_measure_medium_fertilizer",
                ActionType.APPLY_FERTILIZER,
                self._check_fertilizer_applied,
            ),
            ("irrigate_dry_zone", ActionType.APPLY_WATER, self._check_water_applied),
            ("patrol_recharge_resume", ActionType.PATROL, self._check_patrol_repeated),
        )
        for scenario_name, expected_action, check_scenario in cases:
            with self.subTest(scenario=scenario_name):
                env = build_environment(scenario_name, catalog=self.catalog)
                run_to_completion(env)
                zone = env._zone_state(env.riddle_spec.target_zone_id)
                actions = [event.action for event in env.event_history]
                self.assertEqual(
                    env.outcome.status,
                    EpisodeOutcomeStatus.SUCCESS,
                    msg=f"actions taken: {actions}",
                )
                self.assertIn(expected_action, actions)
                check_scenario(zone, actions)

    def _check_weeds_removed(self, zone, actions) -> None:
        """Weed-removal riddle: the target zone must be marked as weeded."""
        self.assertTrue(zone.weeds_removed)

    def _check_fertilizer_applied(self, zone, actions) -> None:
        """Fertilizer riddle: soil is measured, then 28 kg/ha is applied."""
        self.assertIn(ActionType.MEASURE_SOIL, actions)
        self.assertIsNotNone(zone.soil_measurement)
        self.assertAlmostEqual(zone.applied_fertilizer_amount_kg_ha, 28.0)

    def _check_water_applied(self, zone, actions) -> None:
        """Irrigation riddle: 18 mm of water is applied to the target zone."""
        self.assertAlmostEqual(zone.applied_water_amount_mm, 18.0)

    def _check_patrol_repeated(self, zone, actions) -> None:
        """Patrol riddle: PATROL occurs at least twice.

        Presumably one patrol happens before recharging and one after, going
        by the scenario name; that interpretation is not verified here.
        """
        self.assertGreaterEqual(actions.count(ActionType.PATROL), 2)
if __name__ == "__main__":
    # Let the module be executed directly as a script; "__main__" is
    # unittest.main's default module, spelled out here for clarity.
    unittest.main(module="__main__")