| from __future__ import annotations |
|
|
| import numpy as np |
|
|
| from driftwm.sim.env import SurfaceBoatEnv |
| from experiments.evaluate_image_planning import reset_task, task_goals |
|
|
|
|
# Flow-field families exercised by the tests in this module (presumably the
# set evaluated in the paper, per the name). Each entry is passed verbatim as
# ``flow_type`` to both ``SurfaceBoatEnv`` and ``reset_task``.
PAPER_FLOW_FAMILIES = """
    noflow
    uniform
    vortex_center
    double_gyre
    source_sink
    source_sink_pair
    gradient
    shear
    turbulent_patch
    random_fourier
""".split()
|
|
|
|
def test_station_keeping_starts_inside_target_region() -> None:
    """Check that a station-keeping reset starts the boat close to the goal.

    For every flow family in ``PAPER_FLOW_FAMILIES``, build a
    ``SurfaceBoatEnv`` with ``boat="twin"``, reset it for the
    ``station_keeping`` task, and assert that the Euclidean distance between
    ``env.state[:2]`` and the first goal from ``task_goals`` is below
    ``max_start_offset``.
    """
    # Maximum allowed distance between the reset position and the goal.
    # Units are whatever env.state[:2] uses (presumably position) -- confirm.
    max_start_offset = 0.65
    for idx, flow_type in enumerate(PAPER_FLOW_FAMILIES):
        # Seed per family so each flow family gets a distinct but
        # reproducible random stream.
        rng = np.random.default_rng(500 + idx)
        env = SurfaceBoatEnv(
            boat="twin", flow_type=flow_type, boundary="terminate", seed=idx
        )
        reset_task(env, "station_keeping", flow_type, rng)
        # NOTE(review): reset_task may already have consumed draws from
        # ``rng``. The goal below matches the one used by the reset only if
        # the station-keeping goal does not depend on those draws -- confirm.
        goal = task_goals("station_keeping", rng)[0]
        distance = float(np.linalg.norm(env.state[:2] - goal))
        # Include the flow family and distance so a failure inside the loop
        # is attributable without re-running under a debugger.
        assert distance < max_start_offset, (
            f"flow_type={flow_type!r}: start is {distance:.4f} from goal "
            f"{goal}, expected < {max_start_offset}"
        )
|
|
|
|
def test_task_reset_refreshes_local_flow_velocity() -> None:
    """Check that a task reset leaves ``last_flow_velocity`` consistent.

    For every (flow family, task) pair, build a ``SurfaceBoatEnv`` with
    ``boat="triangle"``, reset it via ``reset_task``, and require
    ``env.last_flow_velocity`` to match ``env.flow_at(env.state[:2])``
    within an absolute tolerance of 1e-6.
    """
    # Loop-invariant: the set of tasks checked for every flow family.
    tasks = ("reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag")
    for idx, flow_type in enumerate(PAPER_FLOW_FAMILIES):
        for task in tasks:
            # The seed depends only on the flow family, so every task in a
            # family starts from an identical random stream.
            rng = np.random.default_rng(700 + 17 * idx)
            env = SurfaceBoatEnv(
                boat="triangle", flow_type=flow_type, boundary="terminate", seed=idx
            )
            reset_task(env, task, flow_type, rng)
            np.testing.assert_allclose(
                env.last_flow_velocity,
                env.flow_at(env.state[:2]),
                atol=1.0e-6,
                # Identify the failing combination within the 2-level loop.
                err_msg=f"flow_type={flow_type!r}, task={task!r}",
            )
|
|