# FlowMo-WM / driftwm/sim/render.py
# cccat6's picture
# Initial FlowMo-WM public code release
# 604e535 verified
from __future__ import annotations
import math
from pathlib import Path
from typing import Iterable
import numpy as np
from PIL import Image, ImageDraw
from driftwm.sim.boat import BoatSpec, get_boat_spec
from driftwm.sim.flow import Flow
from driftwm.sim.dynamics import rot_body_to_world
from driftwm.utils import ensure_dir
# Colour given as a triple of integer channels; used for the ``fill``/``outline``
# and arrow colour arguments of the drawing helpers in this module.
Color = tuple[int, int, int]
def _world_to_px(point: np.ndarray, workspace: tuple[float, float, float, float], size: int, pad: int) -> tuple[int, int]:
    """Map a world-frame point to integer pixel coordinates on a square canvas.

    The workspace rectangle is stretched over the drawable area that remains
    after reserving ``pad`` pixels on every side of a ``size`` x ``size``
    image. The vertical axis is flipped, so larger world ``y`` values land
    nearer the top of the image.

    Args:
        point: Indexable whose first two entries are the world x and y.
        workspace: World bounds as ``(xmin, xmax, ymin, ymax)``.
        size: Side length of the canvas in pixels.
        pad: Border width in pixels.

    Returns:
        ``(px, py)``, each truncated with ``int()``.
    """
    xmin, xmax, ymin, ymax = workspace
    inner = size - 2 * pad
    # Extents are floored at 1e-6 so a degenerate workspace cannot divide by zero.
    frac_x = (point[0] - xmin) / max(1e-6, xmax - xmin)
    frac_y = (point[1] - ymin) / max(1e-6, ymax - ymin)
    return int(pad + frac_x * inner), int(size - pad - frac_y * inner)
def _draw_arrow(draw: ImageDraw.ImageDraw, p0: tuple[int, int], p1: tuple[int, int], color: Color, width: int = 2) -> None:
    """Draw an arrow from ``p0`` to ``p1`` as a shaft plus two head strokes.

    Args:
        draw: Drawing context whose ``line`` method is used for every stroke.
        p0: Tail of the arrow in pixel coordinates.
        p1: Tip of the arrow in pixel coordinates.
        color: Stroke colour, passed as ``fill``.
        width: Stroke width in pixels, used for the shaft and both barbs.
    """
    barb_len = 8  # pixel length of each head stroke
    barb_spread = 2.55  # radians between the shaft direction and each barb

    draw.line([p0, p1], fill=color, width=width)

    heading = math.atan2(p1[1] - p0[1], p1[0] - p0[0])
    tip_x, tip_y = p1[0], p1[1]
    for side in (-1, 1):
        barb_angle = heading + side * barb_spread
        barb_end = (
            int(tip_x + barb_len * math.cos(barb_angle)),
            int(tip_y + barb_len * math.sin(barb_angle)),
        )
        draw.line([p1, barb_end], fill=color, width=width)
def draw_flow_field(
    draw: ImageDraw.ImageDraw,
    flow: Flow,
    workspace: tuple[float, float, float, float],
    size: int,
    pad: int,
    t: float = 0.0,
    grid: int = 9,
) -> None:
    """Draw a ``grid`` x ``grid`` lattice of arrows sampling ``flow`` at time ``t``.

    Sample points are evenly spaced inside the workspace, inset by 0.7 world
    units from each bound. At each point ``flow.velocity(point, t)`` is
    evaluated; points where its norm is below 1e-4 are skipped.

    Args:
        draw: Drawing context the arrows are painted on.
        flow: Object exposing ``velocity(point, t)``.
        workspace: World bounds as ``(xmin, xmax, ymin, ymax)``.
        size: Side length of the canvas in pixels.
        pad: Border width in pixels.
        t: Time forwarded to ``flow.velocity``.
        grid: Number of samples along each axis.
    """
    xmin, xmax, ymin, ymax = workspace
    arrow_color = (160, 190, 218)
    sample_xs = np.linspace(xmin + 0.7, xmax - 0.7, grid)
    sample_ys = np.linspace(ymin + 0.7, ymax - 0.7, grid)

    lattice = ((wx, wy) for wx in sample_xs for wy in sample_ys)
    for wx, wy in lattice:
        origin = np.array([wx, wy], dtype=np.float32)
        velocity = flow.velocity(origin, t)
        speed = float(np.linalg.norm(velocity))
        if speed < 1e-4:
            continue
        # Dividing by max(0.15, speed) gives every arrow with speed >= 0.15 the
        # same world length (0.75); slower flow yields proportionally shorter arrows.
        tip = origin + 0.75 * velocity / max(0.15, speed)
        tail_px = _world_to_px(origin, workspace, size, pad)
        tip_px = _world_to_px(tip, workspace, size, pad)
        _draw_arrow(draw, tail_px, tip_px, arrow_color, width=1)
def draw_boat(
    draw: ImageDraw.ImageDraw,
    state: np.ndarray,
    spec: BoatSpec,
    workspace: tuple[float, float, float, float],
    size: int,
    pad: int,
    fill: Color = (38, 89, 133),
    outline: Color = (15, 37, 61),
) -> None:
    """Draw the boat hull, a heading arrow and one short arrow per thruster.

    Args:
        draw: Drawing context to paint on.
        state: Array whose first two entries are the world position and whose
            third entry is the heading angle passed to ``rot_body_to_world``.
        spec: Boat description providing ``hull_vertices``,
            ``thruster_positions`` and ``thruster_dirs``
            (presumably all expressed in the body frame -- confirm against
            ``BoatSpec``).
        workspace: World bounds as ``(xmin, xmax, ymin, ymax)``.
        size: Side length of the canvas in pixels.
        pad: Border width in pixels.
        fill: Hull fill colour.
        outline: Hull outline colour.
    """
    pos = state[:2]
    theta = float(state[2])
    rot = rot_body_to_world(theta)

    def to_px(world_pt: np.ndarray) -> tuple[int, int]:
        return _world_to_px(world_pt, workspace, size, pad)

    # Rotate the hull vertices (one per row) with ``rot`` and translate to ``pos``.
    hull = (spec.hull_vertices @ rot.T) + pos
    draw.polygon([to_px(vertex) for vertex in hull], fill=fill, outline=outline)

    # Heading indicator: arrow from the boat position 0.56 units along body +x.
    nose = pos + rot @ np.array([0.56, 0.0], dtype=np.float32)
    _draw_arrow(draw, to_px(pos), to_px(nose), (230, 242, 255), width=2)

    for r, d in zip(spec.thruster_positions, spec.thruster_dirs):
        # Arrow head sits on the thruster mount point.
        p = pos + rot @ r
        # Tail is 0.16 units back along the thruster direction from the mount.
        # It must be built from ``pos``: ``p`` already contains ``rot @ r``,
        # so adding ``rot @ r`` to ``p`` again would count the offset twice.
        q = pos + rot @ (r - 0.16 * d)
        _draw_arrow(draw, to_px(q), to_px(p), (215, 108, 71), width=2)
def render_frame(
    state: np.ndarray,
    boat: str | BoatSpec,
    flow: Flow,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    trajectory: np.ndarray | None = None,
    goal: np.ndarray | None = None,
    planned: Iterable[np.ndarray] | None = None,
    size: int = 512,
    pad: int = 28,
    t: float = 0.0,
) -> Image.Image:
    """Render one square RGB frame of the scene.

    Layers are painted in this order: workspace border, flow arrows, planned
    rollouts, executed trajectory, goal marker, boat.

    Args:
        state: Boat state forwarded to ``draw_boat``.
        boat: Either a ``BoatSpec`` or a name resolved via ``get_boat_spec``.
        flow: Flow forwarded to ``draw_flow_field``.
        workspace: World bounds as ``(xmin, xmax, ymin, ymax)``.
        trajectory: Optional sequence of points; drawn as a polyline when it
            has more than one entry. Only the first two components of each
            entry are used.
        goal: Optional goal position, drawn as a filled circle.
        planned: Optional iterable of rollouts, each drawn as a faint polyline
            when it has more than one point.
        size: Side length of the image in pixels.
        pad: Border width in pixels.
        t: Time forwarded to ``draw_flow_field``.

    Returns:
        The rendered image.
    """
    boat_spec = boat if not isinstance(boat, str) else get_boat_spec(boat)

    canvas = Image.new("RGB", (size, size), (247, 250, 252))
    # The "RGBA" drawing mode lets the 4-channel colours below carry an alpha value.
    painter = ImageDraw.Draw(canvas, "RGBA")

    def to_px(world_pt: np.ndarray) -> tuple[int, int]:
        return _world_to_px(world_pt, workspace, size, pad)

    far_edge = size - pad
    painter.rectangle([pad, pad, far_edge, far_edge], outline=(35, 54, 72, 255), width=2)
    draw_flow_field(painter, flow, workspace, size, pad, t=t)

    if planned is not None:
        for rollout in planned:
            rollout_px = [to_px(step[:2]) for step in rollout]
            if len(rollout_px) > 1:
                painter.line(rollout_px, fill=(110, 138, 183, 45), width=1)

    if trajectory is not None and len(trajectory) > 1:
        track_px = [to_px(step[:2]) for step in trajectory]
        painter.line(track_px, fill=(22, 131, 105, 230), width=3)

    if goal is not None:
        goal_x, goal_y = to_px(np.asarray(goal, dtype=np.float32))
        radius = 8
        painter.ellipse(
            [goal_x - radius, goal_y - radius, goal_x + radius, goal_y + radius],
            fill=(211, 67, 78, 230),
            outline=(120, 25, 33, 255),
            width=2,
        )

    draw_boat(painter, state, boat_spec, workspace, size, pad)
    return canvas
def save_gif(frames: list[Image.Image], path: str | Path, duration_ms: int = 40) -> None:
    """Write ``frames`` to ``path`` as a looping animated image.

    Args:
        frames: Images to write, in playback order. Any iterable is accepted;
            it is materialised once.
        path: Destination file; its parent directory is created via
            ``ensure_dir`` before saving.
        duration_ms: Value passed as ``duration`` to ``Image.save``.

    Raises:
        ValueError: If ``frames`` is empty. Raised before anything is created
            on disk.
    """
    frames = list(frames)
    # Validate first so a failing call does not leave an empty directory behind.
    if not frames:
        raise ValueError("no frames to save")

    path = Path(path)
    ensure_dir(path.parent)
    first, *rest = frames
    first.save(path, save_all=True, append_images=rest, duration=duration_ms, loop=0)
def save_boat_geometry(boat: str, path: str | Path, size: int = 420) -> None:
    """Render ``boat`` alone in a flow-free scene and save the image to ``path``.

    The boat is placed at world position (5, 5) with zero heading inside the
    default workspace of ``render_frame``.

    Args:
        boat: Boat identifier forwarded to ``render_frame``.
        path: Destination file; its parent directory is created via
            ``ensure_dir`` after rendering.
        size: Side length of the image in pixels.
    """
    # Imported locally, as in the original; only this helper needs NoFlow.
    from driftwm.sim.flow import NoFlow

    rest_state = np.array([5.0, 5.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    picture = render_frame(rest_state, boat, NoFlow(), trajectory=None, size=size)

    destination = Path(path)
    ensure_dir(destination.parent)
    picture.save(destination)
def save_flow_quiver(flow: Flow, path: str | Path, workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0)) -> None:
    """Render the arrow field of ``flow`` over ``workspace`` and save it to ``path``.

    The frame is produced by ``render_frame`` with the ``"twin"`` boat drawn at
    the centre of the workspace with zero heading.

    Args:
        flow: Flow to visualise.
        path: Destination file; its parent directory is created via
            ``ensure_dir`` after rendering.
        workspace: World bounds as ``(xmin, xmax, ymin, ymax)``.
    """
    xmin, xmax, ymin, ymax = workspace
    # Centre the boat in the requested workspace rather than at a fixed (5, 5),
    # which is only the centre of the default bounds and may fall outside a
    # custom workspace. For the default workspace the result is unchanged.
    center_x = 0.5 * (xmin + xmax)
    center_y = 0.5 * (ymin + ymax)
    state = np.array([center_x, center_y, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)

    img = render_frame(state, "twin", flow, workspace=workspace, trajectory=None)
    path = Path(path)
    ensure_dir(path.parent)
    img.save(path)