File size: 1,357 Bytes
ccf9f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from __future__ import annotations

import numpy as np

from driftwm.sim.env import SurfaceBoatEnv
from experiments.evaluate_image_planning import reset_task, task_goals


# Flow-field families iterated over by the tests in this module, in a fixed
# order (the loop index is also used to derive per-family seeds).
PAPER_FLOW_FAMILIES: list[str] = """
    noflow
    uniform
    vortex_center
    double_gyre
    source_sink
    source_sink_pair
    gradient
    shear
    turbulent_patch
    random_fourier
""".split()


def test_station_keeping_starts_inside_target_region() -> None:
    """Check that a station-keeping reset places the boat near the task goal.

    For every flow family in ``PAPER_FLOW_FAMILIES`` a fresh ``boat="twin"``
    environment is reset into the ``"station_keeping"`` task. The distance
    from ``env.state[:2]`` to the first goal returned by ``task_goals`` must
    then be strictly below ``max_start_distance``.
    """
    # Upper bound on the start-to-goal distance (same units as env.state).
    max_start_distance = 0.65
    for idx, flow_type in enumerate(PAPER_FLOW_FAMILIES):
        # Separate, reproducible RNG stream per flow family.
        rng = np.random.default_rng(500 + idx)
        env = SurfaceBoatEnv(boat="twin", flow_type=flow_type, boundary="terminate", seed=idx)
        reset_task(env, "station_keeping", flow_type, rng)
        # NOTE(review): task_goals receives the same rng *after* reset_task has
        # already used it. This only matches the goal reset_task targeted if
        # the station-keeping goal does not depend on rng state — confirm in
        # experiments.evaluate_image_planning.
        goal = task_goals("station_keeping", rng)[0]
        # Assumes env.state[:2] is the boat's planar position — TODO confirm.
        distance = float(np.linalg.norm(env.state[:2] - goal))
        assert distance < max_start_distance, (
            f"flow_type={flow_type!r}: station-keeping start is {distance:.3f} "
            f"from goal (limit {max_start_distance})"
        )


def test_task_reset_refreshes_local_flow_velocity() -> None:
    """Check that ``reset_task`` leaves ``last_flow_velocity`` consistent with the new state.

    For every (flow family, task) pair, ``env.last_flow_velocity`` must match
    ``env.flow_at(env.state[:2])`` after ``reset_task``, within ``atol=1e-6``.
    In other words, the reset must not leave a stale flow sample from before
    the reset.
    """
    # Hoisted: the task list does not depend on the flow family.
    tasks = ("reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag")
    for idx, flow_type in enumerate(PAPER_FLOW_FAMILIES):
        for task in tasks:
            # The seed depends only on the flow family, so every task under a
            # given family is reset from an identically seeded RNG.
            rng = np.random.default_rng(700 + 17 * idx)
            env = SurfaceBoatEnv(boat="triangle", flow_type=flow_type, boundary="terminate", seed=idx)
            reset_task(env, task, flow_type, rng)
            np.testing.assert_allclose(
                env.last_flow_velocity,
                env.flow_at(env.state[:2]),
                atol=1.0e-6,
                err_msg=f"stale flow velocity after reset: flow_type={flow_type!r}, task={task!r}",
            )