FlowMo-WM / driftwm /sim /env.py
cccat6's picture
Update FlowMo-WM code and static flow protocol
ccf9f1b verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import numpy as np
from driftwm.sim.boat import BoatSpec, get_boat_spec, sample_boat_params
from driftwm.sim.dynamics import step_dynamics
from driftwm.sim.flow import Flow, sample_flow
from driftwm.utils import obs_from_state
@dataclass
class EnvConfig:
    """Mutable bundle of the settings a ``SurfaceBoatEnv`` is built from.

    The instance is intentionally not frozen: ``SurfaceBoatEnv.reset`` writes
    ``boat``, ``flow_type`` and ``randomize_params`` back into it when the
    caller overrides them.
    """

    # Boat identifier handed to get_boat_spec / sample_boat_params.
    boat: str = "twin"
    # Flow family identifier handed to sample_flow.
    flow_type: str = "noflow"
    # Integration step passed to step_dynamics; also the per-step increment
    # of the environment clock.
    dt: float = 0.05
    # Number of steps after which an episode is reported as done (timeout).
    episode_steps: int = 200
    # Workspace bounds, unpacked by the environment as (xmin, xmax, ymin, ymax).
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0)
    # Boundary-handling mode forwarded verbatim to step_dynamics.
    boundary: str = "terminate"
    # Forwarded to sample_boat_params as its randomization switch.
    randomize_params: bool = True
class SurfaceBoatEnv:
    """Step-based simulation environment for a surface boat in a flow field.

    The internal state vector has ``6 + action_dim`` float32 entries. The
    first six are ``[x, y, theta, vx, vy, omega]``; the trailing
    ``action_dim`` entries start at zero and are only touched by
    ``step_dynamics`` (presumably actuator state -- confirm against
    ``driftwm.sim.dynamics``).

    Observations are ``obs_from_state`` applied to the first six entries.
    The reward returned by ``step`` is always ``0.0``.
    """

    def __init__(
        self,
        boat: str = "twin",
        flow_type: str = "noflow",
        dt: float = 0.05,
        episode_steps: int = 200,
        workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
        boundary: str = "terminate",
        randomize_params: bool = True,
        seed: int | None = None,
    ):
        """Build the environment and draw an initial boat/flow pair.

        Args:
            boat: Boat identifier for ``get_boat_spec`` / ``sample_boat_params``.
            flow_type: Flow identifier for ``sample_flow``.
            dt: Integration step handed to ``step_dynamics``.
            episode_steps: Step count after which ``step`` reports ``done``.
            workspace: Bounds as ``(xmin, xmax, ymin, ymax)``.
            boundary: Boundary mode forwarded to ``step_dynamics``.
            randomize_params: Forwarded to ``sample_boat_params``.
            seed: Seed for the environment's private NumPy generator.
        """
        # Keyword construction keeps this call correct even if EnvConfig's
        # field order changes.
        self.config = EnvConfig(
            boat=boat,
            flow_type=flow_type,
            dt=dt,
            episode_steps=episode_steps,
            workspace=workspace,
            boundary=boundary,
            randomize_params=randomize_params,
        )
        self.rng = np.random.default_rng(seed)
        self.spec: BoatSpec = get_boat_spec(boat)
        self.params: dict[str, float] = sample_boat_params(boat, self.rng, randomize_params)
        self.flow: Flow = sample_flow(flow_type, self.rng, flow_id=1, workspace=workspace)
        self.state = np.zeros(6 + self.spec.action_dim, dtype=np.float32)
        self.t = 0
        self.time = 0.0
        self.last_flow_velocity = np.zeros(2, dtype=np.float32)

    @property
    def action_dim(self) -> int:
        """Number of action components expected by the current boat spec."""
        return self.spec.action_dim

    @property
    def workspace(self) -> tuple[float, float, float, float]:
        """Workspace bounds ``(xmin, xmax, ymin, ymax)`` from the config."""
        return self.config.workspace

    def reset(
        self,
        *,
        boat: str | None = None,
        flow_type: str | None = None,
        flow: Flow | None = None,
        flow_id: int | None = None,
        random_velocity: bool = True,
        initial_state: np.ndarray | None = None,
        randomize_params: bool | None = None,
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode and return ``(observation, info)``.

        Args:
            boat: If given, replaces ``config.boat`` (the change persists).
            flow_type: If given, replaces ``config.flow_type`` (persists).
            flow: Ready-made flow to use; when given, no flow is sampled and
                ``flow_id`` is ignored.
            flow_id: Id passed to ``sample_flow``; drawn from the generator
                in ``[1, 2_000_000)`` when omitted.
            random_velocity: Whether sampled initial states get small random
                linear/angular velocities instead of zeros. Ignored when
                ``initial_state`` is given.
            initial_state: Full state vector to start from; it is copied and
                cast to float32 but its length is not checked here.
            randomize_params: If given, replaces ``config.randomize_params``
                (persists).

        Returns:
            The initial observation and the ``info()`` dictionary.
        """
        if boat is not None:
            self.config.boat = boat
        if flow_type is not None:
            self.config.flow_type = flow_type
        if randomize_params is not None:
            self.config.randomize_params = randomize_params

        # Boat parameters are always re-sampled, and always before any flow
        # or state draw, so the generator is consumed in a fixed order.
        self.spec = get_boat_spec(self.config.boat)
        self.params = sample_boat_params(self.config.boat, self.rng, self.config.randomize_params)

        if flow is not None:
            self.flow = flow
        else:
            fid = int(flow_id if flow_id is not None else self.rng.integers(1, 2_000_000))
            self.flow = sample_flow(self.config.flow_type, self.rng, fid, self.config.workspace)

        if initial_state is not None:
            # np.array always copies, so the caller's buffer is never aliased.
            self.state = np.array(initial_state, dtype=np.float32)
        else:
            self.state = self._sample_initial_state(random_velocity)

        self.t = 0
        self.time = 0.0
        # Stored as an owned float32 array, matching __init__ and step().
        self.last_flow_velocity = np.array(
            self.flow.velocity(self.state[:2], self.time), dtype=np.float32
        )
        return self.observation(), self.info()

    def _sample_initial_state(self, random_velocity: bool) -> np.ndarray:
        """Draw a random start state inside the workspace (1.0 margin per side)."""
        xmin, xmax, ymin, ymax = self.config.workspace
        margin = 1.0
        # Draw order (x, y, theta, velocity pair, omega) is part of the
        # seeded behaviour; do not reorder these calls.
        x = self.rng.uniform(xmin + margin, xmax - margin)
        y = self.rng.uniform(ymin + margin, ymax - margin)
        theta = self.rng.uniform(-np.pi, np.pi)
        if random_velocity:
            vx, vy = self.rng.uniform(-0.12, 0.12, size=2)
            omega = self.rng.uniform(-0.15, 0.15)
        else:
            vx = vy = omega = 0.0
        state = np.zeros(6 + self.spec.action_dim, dtype=np.float32)
        state[:6] = [x, y, theta, vx, vy, omega]
        return state

    def observation(self) -> np.ndarray:
        """Return ``obs_from_state`` of the first six state entries."""
        return obs_from_state(self.state[:6])

    def full_state(self) -> np.ndarray:
        """Return a copy of the complete internal state vector."""
        return self.state.copy()

    def flow_at(self, pos: np.ndarray) -> np.ndarray:
        """Query the current flow at ``pos`` for the current simulation time."""
        return self.flow.velocity(np.asarray(pos, dtype=np.float32), self.time)

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, dict[str, Any]]:
        """Advance the simulation by one ``dt``.

        Args:
            action: Control input; cast to float32 and truncated to the first
                ``action_dim`` entries.

        Returns:
            ``(observation, reward, done, info)`` where ``reward`` is always
            ``0.0`` and ``done`` is true when ``step_dynamics`` reports a
            boundary termination or the step budget is exhausted.
        """
        action = np.asarray(action, dtype=np.float32)[: self.action_dim]
        # The flow is sampled at the pre-step position and time.
        flow_velocity = self.flow.velocity(self.state[:2], self.time)
        self.last_flow_velocity = np.array(flow_velocity, dtype=np.float32)
        self.state, boundary_done = step_dynamics(
            self.state,
            action,
            self.spec,
            self.params,
            flow_velocity,
            self.config.dt,
            self.config.workspace,
            self.config.boundary,
        )
        self.t += 1
        self.time += self.config.dt
        timeout = self.t >= self.config.episode_steps
        # Normalise to a builtin bool so callers never receive a NumPy scalar.
        done = bool(boundary_done or timeout)
        reward = 0.0
        return self.observation(), reward, done, self.info()

    def info(self) -> dict[str, Any]:
        """Return episode metadata merged with ``flow.metadata()``.

        Keys supplied by the flow metadata override the ones built here when
        they collide.
        """
        meta = {
            "t": self.t,
            "time": self.time,
            "boat_type": self.spec.name,
            "action_dim": self.action_dim,
            # asarray accepts any array-like, not only ndarrays.
            "flow_velocity": np.asarray(self.last_flow_velocity, dtype=float).tolist(),
            "params": {k: float(v) for k, v in self.params.items()},
        }
        meta.update(self.flow.metadata())
        return meta