from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import numpy as np

from driftwm.sim.boat import BoatSpec, get_boat_spec, sample_boat_params
from driftwm.sim.dynamics import step_dynamics
from driftwm.sim.flow import Flow, sample_flow
from driftwm.utils import obs_from_state


@dataclass
class EnvConfig:
    """Mutable configuration of a ``SurfaceBoatEnv``.

    ``SurfaceBoatEnv.reset`` overwrites ``boat``, ``flow_type`` and
    ``randomize_params`` in place when the matching keyword arguments are given.
    """

    boat: str = "twin"  # boat name passed to get_boat_spec / sample_boat_params
    flow_type: str = "noflow"  # flow name passed to sample_flow
    dt: float = 0.05  # step size passed to step_dynamics; also the per-step time increment
    episode_steps: int = 200  # step count at which an episode ends by timeout
    # (xmin, xmax, ymin, ymax) -- order established by the unpacking in reset
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0)
    boundary: str = "terminate"  # boundary mode forwarded to step_dynamics
    randomize_params: bool = True  # forwarded to sample_boat_params


class SurfaceBoatEnv:
    """Boat environment that advances ``step_dynamics`` inside a sampled flow field.

    The state vector has ``6 + action_dim`` float32 entries. ``reset`` fills
    ``state[:6]`` with ``[x, y, theta, vel0, vel1, omega]`` and zero-initialises the
    trailing ``action_dim`` entries, which only ``step_dynamics`` updates afterwards
    (presumably actuator state -- TODO confirm against driftwm.sim.dynamics).

    ``step`` always returns a reward of ``0.0``.
    """

    def __init__(
        self,
        boat: str = "twin",
        flow_type: str = "noflow",
        dt: float = 0.05,
        episode_steps: int = 200,
        workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
        boundary: str = "terminate",
        randomize_params: bool = True,
        seed: int | None = None,
    ):
        """Build the environment. Call ``reset`` before the first ``step``.

        Args:
            boat, flow_type, dt, episode_steps, workspace, boundary, randomize_params:
                Stored in ``self.config`` (see ``EnvConfig``).
            seed: Seed for the ``numpy.random.Generator`` used for all sampling.

        The constructor samples boat parameters and a flow with ``flow_id=1``, but
        leaves the state all zeros until ``reset`` is called.
        """
        # Keywords rather than positions so reordering EnvConfig fields cannot
        # silently mis-assign values.
        self.config = EnvConfig(
            boat=boat,
            flow_type=flow_type,
            dt=dt,
            episode_steps=episode_steps,
            workspace=workspace,
            boundary=boundary,
            randomize_params=randomize_params,
        )
        self.rng = np.random.default_rng(seed)
        self.spec: BoatSpec = get_boat_spec(boat)
        self.params: dict[str, float] = sample_boat_params(boat, self.rng, randomize_params)
        self.flow: Flow = sample_flow(flow_type, self.rng, flow_id=1, workspace=workspace)
        self.state = np.zeros(6 + self.spec.action_dim, dtype=np.float32)
        self.t = 0  # steps taken in the current episode
        self.time = 0.0  # simulated time, advanced by dt per step
        self.last_flow_velocity = np.zeros(2, dtype=np.float32)

    @property
    def action_dim(self) -> int:
        """Number of action components, taken from the current boat spec."""
        return self.spec.action_dim

    @property
    def workspace(self) -> tuple[float, float, float, float]:
        """Workspace bounds ``(xmin, xmax, ymin, ymax)`` from the config."""
        return self.config.workspace

    def reset(
        self,
        *,
        boat: str | None = None,
        flow_type: str | None = None,
        flow: Flow | None = None,
        flow_id: int | None = None,
        random_velocity: bool = True,
        initial_state: np.ndarray | None = None,
        randomize_params: bool | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode and return ``(observation, info)``.

        Boat parameters are re-sampled on every reset.

        Args:
            boat: If given, switch boat type (persisted in ``self.config``).
            flow_type: If given, switch flow type (persisted in ``self.config``).
            flow: Ready-made flow to use. Takes precedence over ``flow_type``
                and ``flow_id`` for this episode.
            flow_id: Id forwarded to ``sample_flow``. A random id in
                ``[1, 2_000_000)`` is drawn when omitted.
            random_velocity: If true, draw each linear-velocity component from
                ``[-0.12, 0.12)`` and the yaw rate from ``[-0.15, 0.15)``.
                Otherwise start at rest. Ignored when ``initial_state`` is given.
            initial_state: Explicit start state. Either the full ``6 + action_dim``
                vector, or only its first six entries (the remainder is zero-filled).
            randomize_params: If given, persisted and forwarded to
                ``sample_boat_params``.

        Raises:
            ValueError: If ``initial_state`` has neither 6 nor ``6 + action_dim``
                elements.
        """
        if boat is not None:
            self.config.boat = boat
        if flow_type is not None:
            self.config.flow_type = flow_type
        if randomize_params is not None:
            self.config.randomize_params = randomize_params

        self.spec = get_boat_spec(self.config.boat)
        self.params = sample_boat_params(self.config.boat, self.rng, self.config.randomize_params)

        if flow is not None:
            self.flow = flow
        else:
            fid = int(flow_id if flow_id is not None else self.rng.integers(1, 2_000_000))
            self.flow = sample_flow(self.config.flow_type, self.rng, fid, self.config.workspace)

        if initial_state is not None:
            self.state = self._coerce_initial_state(initial_state)
        else:
            self.state = self._sample_initial_state(random_velocity)

        self.t = 0
        self.time = 0.0
        # Same float32 copy that step() stores, so info() sees one consistent type.
        self.last_flow_velocity = np.array(
            self.flow.velocity(self.state[:2], self.time), dtype=np.float32
        )
        return self.observation(), self.info()

    def _coerce_initial_state(self, initial_state: np.ndarray) -> np.ndarray:
        """Return a flat float32 copy of *initial_state* of length ``6 + action_dim``."""
        full_dim = 6 + self.spec.action_dim
        state = np.asarray(initial_state, dtype=np.float32).reshape(-1).copy()
        if state.size == full_dim:
            return state
        if state.size == 6:
            # Only pose/velocity were given: zero the trailing action_dim slots,
            # matching what a sampled reset produces.
            padded = np.zeros(full_dim, dtype=np.float32)
            padded[:6] = state
            return padded
        raise ValueError(
            f"initial_state must have 6 or {full_dim} elements, got {state.size}"
        )

    def _sample_initial_state(self, random_velocity: bool) -> np.ndarray:
        """Draw a random pose inside the workspace, plus a velocity if requested."""
        xmin, xmax, ymin, ymax = self.config.workspace
        # Keep spawns 1.0 away from the walls, but never more than half the extent,
        # so that low <= high for uniform() on small workspaces.
        margin = min(1.0, 0.5 * (xmax - xmin), 0.5 * (ymax - ymin))
        # Keep this RNG draw order fixed (x, y, theta, velocity, yaw rate):
        # changing it changes every seeded episode.
        x = self.rng.uniform(xmin + margin, xmax - margin)
        y = self.rng.uniform(ymin + margin, ymax - margin)
        theta = self.rng.uniform(-np.pi, np.pi)
        if random_velocity:
            vel = self.rng.uniform(-0.12, 0.12, size=2).astype(np.float32)
            omega = float(self.rng.uniform(-0.15, 0.15))
        else:
            vel = np.zeros(2, dtype=np.float32)
            omega = 0.0
        state = np.zeros(6 + self.spec.action_dim, dtype=np.float32)
        state[:6] = np.array([x, y, theta, vel[0], vel[1], omega], dtype=np.float32)
        return state

    def observation(self) -> np.ndarray:
        """Return ``obs_from_state`` applied to the first six state entries."""
        return obs_from_state(self.state[:6])

    def full_state(self) -> np.ndarray:
        """Return a copy of the complete ``6 + action_dim`` state vector."""
        return self.state.copy()

    def flow_at(self, pos: np.ndarray) -> np.ndarray:
        """Evaluate the current flow at *pos* (cast to float32) at the env's current time."""
        return self.flow.velocity(np.asarray(pos, dtype=np.float32), self.time)

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, dict[str, Any]]:
        """Advance the simulation by one ``dt``.

        Args:
            action: Control input. It is flattened (so scalars and ``(1, n)``
                arrays are accepted) and truncated to ``action_dim`` entries.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is always ``0.0``.
            ``done`` is true when ``step_dynamics`` reports a boundary event or
            ``episode_steps`` steps have been taken.
        """
        action = np.asarray(action, dtype=np.float32).reshape(-1)[: self.action_dim]
        # Flow is sampled at the pre-step position and time.
        flow_velocity = self.flow.velocity(self.state[:2], self.time)
        self.last_flow_velocity = np.array(flow_velocity, dtype=np.float32)
        self.state, boundary_done = step_dynamics(
            self.state,
            action,
            self.spec,
            self.params,
            flow_velocity,
            self.config.dt,
            self.config.workspace,
            self.config.boundary,
        )
        self.t += 1
        self.time += self.config.dt
        timeout = self.t >= self.config.episode_steps
        done = boundary_done or timeout
        reward = 0.0
        return self.observation(), reward, done, self.info()

    def info(self) -> dict[str, Any]:
        """Return episode metadata merged with ``self.flow.metadata()``.

        Keys returned by ``flow.metadata()`` override the env's own keys on collision.
        """
        meta = {
            "t": self.t,
            "time": self.time,
            "boat_type": self.spec.name,
            "action_dim": self.action_dim,
            "flow_velocity": self.last_flow_velocity.astype(float).tolist(),
            "params": {k: float(v) for k, v in self.params.items()},
        }
        meta.update(self.flow.metadata())
        return meta