# FlowMo-WM / driftwm/sim/dynamics.py
# Initial FlowMo-WM public code release (commit 604e535, by cccat6).
from __future__ import annotations
import numpy as np
from driftwm.sim.boat import BoatSpec, thruster_wrench_body
from driftwm.utils import wrap_angle
def rot_body_to_world(theta: float) -> np.ndarray:
    """Return the 2x2 float32 matrix that rotates body-frame vectors into the world frame.

    Args:
        theta: Rotation angle in radians (in this module, the boat heading).

    Returns:
        ``[[cos, -sin], [sin, cos]]`` as a float32 array. Its transpose maps
        world-frame vectors back into the body frame.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rows = ((cos_t, -sin_t), (sin_t, cos_t))
    return np.array(rows, dtype=np.float32)
def _apply_workspace_boundary(
    pos: np.ndarray,
    vel: np.ndarray,
    workspace: tuple[float, float, float, float],
    boundary: str,
    restitution: float,
) -> bool:
    """Resolve a workspace-boundary crossing in place and report termination.

    ``pos`` and ``vel`` are 2-element world-frame arrays. They are modified in
    place for the ``"bounce"`` and ``"clip"`` modes.

    Returns:
        True only when ``pos`` lies outside ``workspace`` and ``boundary`` is
        neither ``"bounce"`` nor ``"clip"``; False otherwise.
    """
    xmin, xmax, ymin, ymax = workspace
    out_x_low = pos[0] < xmin
    out_x_high = pos[0] > xmax
    out_y_low = pos[1] < ymin
    out_y_high = pos[1] > ymax
    if not (out_x_low or out_x_high or out_y_low or out_y_high):
        return False
    if boundary == "bounce":
        # Point the reflected component toward the interior instead of blindly
        # negating it. For a step that started inside the workspace this equals
        # ``v *= -restitution``. It also keeps a state that already starts
        # outside the workspace from being pushed further out.
        if out_x_low:
            vel[0] = restitution * abs(vel[0])
        elif out_x_high:
            vel[0] = -restitution * abs(vel[0])
        if out_y_low:
            vel[1] = restitution * abs(vel[1])
        elif out_y_high:
            vel[1] = -restitution * abs(vel[1])
    elif boundary != "clip":
        # Any other mode (e.g. "terminate") leaves the state as integrated.
        return True
    pos[0] = np.clip(pos[0], xmin, xmax)
    pos[1] = np.clip(pos[1], ymin, ymax)
    return False


def step_dynamics(
    state: np.ndarray,
    action: np.ndarray,
    spec: BoatSpec,
    params: dict[str, float],
    flow_velocity: np.ndarray,
    dt: float,
    workspace: tuple[float, float, float, float],
    boundary: str = "terminate",
    *,
    restitution: float = 0.45,
) -> tuple[np.ndarray, bool]:
    """Advance the boat by one semi-implicit Euler step of length ``dt``.

    State layout (as unpacked below):
      * ``state[:6]`` is ``(x, y, theta, vx, vy, omega)``, with ``(vx, vy)``
        in the world frame.
      * The next ``spec.action_dim`` entries are actuator levels. They track
        the commanded action through a first-order lag.
      * Any further trailing entries are copied through unchanged.

    The caller's ``state`` array is never modified.

    Args:
        state: Current state vector (converted to float32).
        action: Commanded actuator levels; clipped to ``[-1, 1]``.
        spec: Boat description. ``spec.action_dim`` sets how many actuator
            entries follow the 6 rigid-body entries. It is also forwarded to
            ``thruster_wrench_body``.
        params: Physical parameters. Keys read here are ``actuator_tau``,
            ``drag_linear_x``, ``drag_linear_y``, ``drag_quad_x``,
            ``drag_quad_y``, ``drag_angular``, ``drag_angular_quad``, ``mass``
            and ``inertia``. The dict is also forwarded to
            ``thruster_wrench_body``.
        flow_velocity: World-frame ambient flow velocity. Drag acts on the
            boat velocity relative to it.
        dt: Time step.
        workspace: ``(xmin, xmax, ymin, ymax)`` bounds on the position.
        boundary: Behaviour when the new position leaves ``workspace``.
            ``"bounce"`` reflects the offending velocity component(s) scaled
            by ``restitution`` and clamps the position. ``"clip"`` clamps the
            position only. Any other value (default ``"terminate"``) keeps the
            integrated state and reports ``done=True``.
        restitution: Speed scale applied to a reflected velocity component in
            ``"bounce"`` mode. The default 0.45 matches the former hard-coded
            value.

    Returns:
        ``(next_state, done)``. ``next_state`` is a new float32 array with the
        same length as ``state``.
    """
    action = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
    # Private copy: the caller's array is never written to.
    state = np.asarray(state, dtype=np.float32).copy()
    x, y, theta, vx, vy, omega = state[:6]
    u = state[6 : 6 + spec.action_dim]

    # Actuators follow a first-order lag toward the clipped command. tau is
    # floored to avoid division by zero. alpha is capped at 1 so that
    # dt >= tau snaps straight to the command instead of overshooting.
    tau_a = max(params["actuator_tau"], 1e-3)
    alpha = min(1.0, dt / tau_a)
    u_next = u + alpha * (action - u)

    # Drag acts on the velocity relative to the world-frame flow. It is
    # resolved in the body frame, where the per-axis coefficients apply.
    rot = rot_body_to_world(float(theta))
    vel_world = np.array([vx, vy], dtype=np.float32)
    rel_world = vel_world - np.asarray(flow_velocity, dtype=np.float32)
    rel_body = rot.T @ rel_world  # R is orthonormal, so R^T == R^-1
    d1 = np.array([params["drag_linear_x"], params["drag_linear_y"]], dtype=np.float32)
    d2 = np.array([params["drag_quad_x"], params["drag_quad_y"]], dtype=np.float32)
    drag_body = -d1 * rel_body - d2 * np.abs(rel_body) * rel_body

    # Thrust is evaluated with the *updated* actuator levels.
    thrust_body, tau_thr = thruster_wrench_body(spec, u_next, params)
    force_world = rot @ (thrust_body + drag_body)
    tau_drag = (
        -params["drag_angular"] * omega
        - params["drag_angular_quad"] * abs(float(omega)) * omega
    )
    tau_total = tau_thr + tau_drag

    # Semi-implicit Euler: update velocities first, then integrate the pose
    # with the new velocities.
    vel_next = vel_world + dt * force_world / params["mass"]
    pos_next = np.array([x, y], dtype=np.float32) + dt * vel_next
    omega_next = omega + dt * tau_total / params["inertia"]
    theta_next = wrap_angle(theta + dt * omega_next)

    done = _apply_workspace_boundary(pos_next, vel_next, workspace, boundary, restitution)

    # ``state`` is already a private float32 copy and ``u`` has been consumed,
    # so the result is written into it directly without another copy.
    state[:6] = np.array(
        [pos_next[0], pos_next[1], theta_next, vel_next[0], vel_next[1], omega_next],
        dtype=np.float32,
    )
    state[6 : 6 + spec.action_dim] = u_next
    return state, done