# FlowMo-WM / tests/test_planning_tasks.py
# Last commit by cccat6: "Update FlowMo-WM code and static flow protocol" (ccf9f1b, verified)
from __future__ import annotations
import numpy as np
from driftwm.sim.env import SurfaceBoatEnv
from experiments.evaluate_image_planning import reset_task, task_goals
# Flow-field families exercised by the planning-task tests in this module.
# Each name is passed verbatim as ``flow_type`` to ``SurfaceBoatEnv`` and
# ``reset_task``.
PAPER_FLOW_FAMILIES = [
    "noflow",
    "uniform",
    "vortex_center",
    "double_gyre",
    "source_sink",
    "source_sink_pair",
    "gradient",
    "shear",
    "turbulent_patch",
    "random_fourier",
]
def test_station_keeping_starts_inside_target_region() -> None:
    """Station-keeping resets must start the boat close to its goal.

    For every entry of ``PAPER_FLOW_FAMILIES``, build a ``boat="twin"``
    ``SurfaceBoatEnv``, reset it into the ``station_keeping`` task, and check
    that the first two state components (presumably the x/y position; confirm
    against ``SurfaceBoatEnv``) are within ``max_start_distance`` of the first
    goal returned by ``task_goals``.
    """
    max_start_distance = 0.65
    for idx, flow_type in enumerate(PAPER_FLOW_FAMILIES):
        # Deterministic, per-family seeds for both the task RNG and the env.
        rng = np.random.default_rng(500 + idx)
        env = SurfaceBoatEnv(
            boat="twin", flow_type=flow_type, boundary="terminate", seed=idx
        )
        reset_task(env, "station_keeping", flow_type, rng)
        # NOTE(review): task_goals receives the same rng *after* reset_task has
        # used it; this assumes the station-keeping goal does not depend on
        # those later draws (i.e. matches the goal reset_task targeted) -
        # confirm.
        goal = task_goals("station_keeping", rng)[0]
        distance = float(np.linalg.norm(env.state[:2] - goal))
        # Report which family failed; a bare assert inside the loop would not.
        assert distance < max_start_distance, (
            f"flow_type={flow_type!r}: start is {distance:.4f} from goal "
            f"{goal!r}, expected < {max_start_distance}"
        )
def test_task_reset_refreshes_local_flow_velocity() -> None:
    """``reset_task`` must leave ``last_flow_velocity`` in sync with the flow.

    For every entry of ``PAPER_FLOW_FAMILIES`` and every planning task, build a
    ``boat="triangle"`` ``SurfaceBoatEnv``, reset it into the task, and check
    that ``env.last_flow_velocity`` matches ``env.flow_at`` evaluated at the
    post-reset position ``env.state[:2]`` (to within ``atol=1e-6``).
    """
    tasks = ("reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag")
    for idx, flow_type in enumerate(PAPER_FLOW_FAMILIES):
        for task in tasks:
            # Seed depends only on the flow family, so every task in a family
            # starts from the same RNG stream.
            rng = np.random.default_rng(700 + 17 * idx)
            env = SurfaceBoatEnv(
                boat="triangle", flow_type=flow_type, boundary="terminate", seed=idx
            )
            reset_task(env, task, flow_type, rng)
            # err_msg pinpoints the failing (flow_type, task) combination.
            np.testing.assert_allclose(
                env.last_flow_velocity,
                env.flow_at(env.state[:2]),
                atol=1.0e-6,
                err_msg=f"flow_type={flow_type!r}, task={task!r}",
            )