FlowMo-WM / experiments /shared /src /vision /clean_renderer.py
cccat6's picture
Initial FlowMo-WM public code release
604e535 verified
raw
history blame
7.98 kB
"""Clean image observations for the boat benchmark."""
from __future__ import annotations
import numpy as np
import torch
from PIL import Image, ImageDraw
from driftwm.sim.boat import BoatSpec, get_boat_spec
from driftwm.sim.dynamics import rot_body_to_world
# Memo of per-pixel world-coordinate grids used by render_clean_boat_tensor.
# Key: (str(device), str(dtype), image_size, workspace as floats, pad);
# value: the (full_x, full_y) grids. The cached tensors are shared between
# calls, so callers must not modify them in place. Nothing here evicts entries.
_GRID_CACHE: dict[tuple[str, str, int, tuple[float, float, float, float], int], tuple[torch.Tensor, torch.Tensor]] = {}
# Memo of hull-vertex tensors already multiplied by visual_scale.
# Key: (boat name, str(device), dtype, visual_scale).
_HULL_CACHE: dict[tuple[str, str, torch.dtype, float], torch.Tensor] = {}
def world_to_pixel(point: np.ndarray, workspace: tuple[float, float, float, float], image_size: int, pad: int) -> tuple[int, int]:
    """Map a world-space ``(x, y)`` point to integer ``(column, row)`` pixel coordinates.

    The workspace rectangle ``(xmin, xmax, ymin, ymax)`` is stretched onto the
    pixel range ``[pad, image_size - pad]`` on both axes. The vertical axis is
    flipped so that ``ymax`` lands on row ``pad`` (top) and ``ymin`` on row
    ``image_size - pad`` (bottom). Results use Python's ``round``.
    """
    left, right, bottom, top = workspace
    span = image_size - 2 * pad
    u = (float(point[0]) - left) / (right - left)
    v = (float(point[1]) - bottom) / (top - bottom)
    column = pad + u * span
    row = image_size - pad - v * span
    return int(round(column)), int(round(row))
def render_clean_boat_image(
    state: np.ndarray,
    boat: str | BoatSpec,
    image_size: int = 64,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    pad: int = 4,
    visual_scale: float = 2.5,
) -> Image.Image:
    """Draw one boat pose as a square RGB PIL image.

    ``state[0:2]`` is the world position and ``state[2]`` the heading angle
    handed to ``rot_body_to_world``. The drawing has three parts:

    * the workspace frame, outlined at ``pad`` pixels from the image edge;
    * the boat's hull polygon, scaled by ``visual_scale`` in the body frame;
    * a bow marker disc at body point ``(0.22 * visual_scale, 0)``.

    ``boat`` may be a boat name (resolved with ``get_boat_spec``) or a
    ``BoatSpec`` instance.
    """
    spec = boat if not isinstance(boat, str) else get_boat_spec(boat)
    scale = float(visual_scale)
    origin = np.asarray(state[:2], dtype=np.float32)
    rotation_t = rot_body_to_world(float(state[2])).T

    def place(body_points: np.ndarray) -> np.ndarray:
        # Scale in the body frame, rotate into the world frame, then translate.
        return ((body_points * scale) @ rotation_t) + origin

    canvas = Image.new("RGB", (image_size, image_size), (246, 249, 251))
    pen = ImageDraw.Draw(canvas, "RGBA")
    far_edge = image_size - pad
    pen.rectangle([pad, pad, far_edge, far_edge], outline=(70, 82, 94, 255), width=1)

    hull_pixels = [world_to_pixel(vertex, workspace, image_size, pad) for vertex in place(spec.hull_vertices)]
    pen.polygon(hull_pixels, fill=(35, 91, 140, 255), outline=(18, 45, 76, 255))

    bow_col, bow_row = world_to_pixel(place(np.array([0.22, 0.0], dtype=np.float32)), workspace, image_size, pad)
    r = max(2, image_size // 40)
    pen.ellipse(
        [bow_col - r, bow_row - r, bow_col + r, bow_row + r],
        fill=(245, 204, 80, 255),
        outline=(94, 65, 12, 255),
    )
    return canvas
def render_clean_boat_array(
    state: np.ndarray,
    boat: str | BoatSpec,
    image_size: int = 64,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    visual_scale: float = 2.5,
    pad: int = 4,
) -> np.ndarray:
    """Render one boat pose and return it as a ``uint8`` NumPy array.

    This is a thin wrapper around ``render_clean_boat_image``. The result is
    that RGB image converted with ``np.asarray``, i.e. an
    ``(image_size, image_size, 3)`` array.

    ``pad`` (frame margin in pixels) is forwarded to the image renderer. The
    wrapper previously dropped it, always using the default margin. It is
    appended after ``visual_scale`` so existing positional calls keep working.
    """
    image = render_clean_boat_image(
        state,
        boat,
        image_size=image_size,
        workspace=workspace,
        pad=pad,
        visual_scale=visual_scale,
    )
    return np.asarray(image, dtype=np.uint8)
def _polygon_mask(body_x: torch.Tensor, body_y: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor:
inside = torch.zeros_like(body_x, dtype=torch.bool)
count = int(vertices.shape[0])
for i in range(count):
j = (i + 1) % count
xi, yi = vertices[i, 0], vertices[i, 1]
xj, yj = vertices[j, 0], vertices[j, 1]
crosses = (yi > body_y) != (yj > body_y)
x_at_y = (xj - xi) * (body_y - yi) / (yj - yi + 1.0e-6) + xi
inside = torch.logical_xor(inside, crosses & (body_x < x_at_y))
return inside
def _cached_hull_vertices(
    boat_name: str,
    device: torch.device,
    dtype: torch.dtype,
    visual_scale: float,
) -> torch.Tensor:
    """Return ``boat_name``'s hull vertices times ``visual_scale`` on ``device``, memoized in ``_HULL_CACHE``."""
    key = (boat_name, str(device), dtype, float(visual_scale))
    vertices = _HULL_CACHE.get(key)
    if vertices is None:
        vertices = torch.as_tensor(get_boat_spec(boat_name).hull_vertices, dtype=dtype, device=device) * float(visual_scale)
        _HULL_CACHE[key] = vertices
    return vertices


def _cached_pixel_grid(
    device: torch.device,
    dtype: torch.dtype,
    size: int,
    workspace: tuple[float, float, float, float],
    pad: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(size, size)`` per-pixel world x/y grids, memoized in ``_GRID_CACHE``.

    The returned tensors are shared between calls and must not be modified.
    """
    key = (str(device), str(dtype), size, tuple(float(v) for v in workspace), int(pad))
    cached = _GRID_CACHE.get(key)
    if cached is not None:
        return cached
    xmin, xmax, ymin, ymax = workspace
    count = size - 2 * pad + 1
    xs = torch.linspace(xmin, xmax, count, device=device, dtype=dtype)
    # Image rows run top to bottom, so y decreases from ymax to ymin.
    ys = torch.linspace(ymax, ymin, count, device=device, dtype=dtype)
    # Margin pixels get a sentinel point just outside the workspace; callers
    # must still clip to the frame because all margin pixels share that point.
    full_x = torch.full((size, size), xmin - 1.0, device=device, dtype=dtype)
    full_y = torch.full((size, size), ymin - 1.0, device=device, dtype=dtype)
    full_x[pad : size - pad + 1, pad : size - pad + 1] = xs.view(1, -1)
    full_y[pad : size - pad + 1, pad : size - pad + 1] = ys.view(-1, 1)
    _GRID_CACHE[key] = (full_x, full_y)
    return full_x, full_y


def render_clean_boat_tensor(
    states: torch.Tensor,
    boat_ids: torch.Tensor,
    image_size: int = 160,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    pad: int = 4,
    visual_scale: float = 2.5,
) -> torch.Tensor:
    """Render a batch of clean boat observations on the tensor device.

    This is the batched, device-side counterpart of ``render_clean_boat_image``.
    It uses the same colors, frame and workspace-to-pixel mapping. The hull and
    bow marker are clipped to the frame (rows/columns ``pad .. image_size - pad``).

    Args:
        states: tensor with shape ``(N, D)``, ``D >= 3``, whose first three
            columns are ``x, y, theta``.
        boat_ids: tensor with shape ``(N,)`` where 0 is twin and 1 is triangle.
            Rows with any other id get only the background and frame.
        image_size: output height and width in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world rectangle mapped onto the frame.
        pad: frame margin in pixels. It must be at least 1 and leave a frame
            interior of at least one pixel (``2 * pad < image_size``).
        visual_scale: multiplier applied to the hull polygon and bow-marker offset.

    Returns:
        ``uint8`` tensor with shape ``(N, 3, H, W)``.

    Raises:
        ValueError: on malformed ``states``/``boat_ids`` shapes or an unusable ``pad``.
    """
    if states.ndim != 2 or states.shape[-1] < 3:
        raise ValueError("states must have shape (N, >=3)")
    if boat_ids.ndim != 1 or boat_ids.shape[0] != states.shape[0]:
        raise ValueError("boat_ids must have shape (N,)")
    size = int(image_size)
    pad = int(pad)
    # The frame occupies rows/columns ``pad`` and ``size - pad``. pad == 0 would
    # index past the last row, and an empty interior divides by zero below.
    if pad < 1 or size - 2 * pad < 1:
        raise ValueError("pad must satisfy 1 <= pad < image_size / 2")

    device = states.device
    dtype = states.dtype
    n = int(states.shape[0])
    h = w = size
    xmin, xmax, _ymin, _ymax = workspace

    image = torch.empty((n, 3, h, w), dtype=torch.uint8, device=device)
    background = torch.tensor([246, 249, 251], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    image.copy_(background.expand_as(image))
    border = torch.tensor([70, 82, 94], dtype=torch.uint8, device=device).view(1, 3, 1)
    image[:, :, pad, pad : w - pad + 1] = border
    image[:, :, h - pad, pad : w - pad + 1] = border
    image[:, :, pad : h - pad + 1, pad] = border
    image[:, :, pad : h - pad + 1, w - pad] = border

    full_x, full_y = _cached_pixel_grid(device, dtype, size, workspace, pad)
    # Only in-frame pixels carry real world coordinates. Every margin pixel
    # maps to the same sentinel point, so without this clip a hull covering
    # that point would paint the entire margin.
    in_frame = torch.zeros((h, w), dtype=torch.bool, device=device)
    in_frame[pad : h - pad + 1, pad : w - pad + 1] = True

    # Express every pixel in each boat's body frame (inverse rotation of the offset).
    x = states[:, 0].view(n, 1, 1)
    y = states[:, 1].view(n, 1, 1)
    theta = states[:, 2].view(n, 1, 1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    dx = full_x.view(1, h, w) - x
    dy = full_y.view(1, h, w) - y
    body_x = cos_t * dx + sin_t * dy
    body_y = -sin_t * dx + cos_t * dy

    # Marker radius in pixels (as in the PIL renderer) converted to world units.
    radius_world = float(max(2, h // 40)) * float(xmax - xmin) / float(h - 2 * pad)
    marker_x = 0.22 * float(visual_scale)
    marker = ((body_x - marker_x).square() + body_y.square() <= radius_world * radius_world) & in_frame

    hull_color = torch.tensor([35, 91, 140], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    marker_color = torch.tensor([245, 204, 80], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    for boat_id, boat_name in ((0, "twin"), (1, "triangle")):
        index = torch.nonzero(boat_ids == boat_id, as_tuple=False).flatten()
        if index.numel() == 0:
            continue
        vertices = _cached_hull_vertices(boat_name, device, dtype, visual_scale)
        hull = _polygon_mask(body_x[index], body_y[index], vertices) & in_frame
        # Gather once, paint hull then marker (marker on top), scatter once.
        patch = image[index]
        patch = torch.where(hull[:, None], hull_color, patch)
        patch = torch.where(marker[index][:, None], marker_color, patch)
        image[index] = patch
    return image
def render_clean_boat_history_tensor(
    states: torch.Tensor,
    boat_ids: torch.Tensor,
    image_size: int = 160,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    visual_scale: float = 2.5,
    pad: int = 4,
) -> torch.Tensor:
    """Render state histories with shape ``(B, T, D)`` to ``(B, T, 3, H, W)``.

    Each of the ``B`` trajectories uses one boat id for all ``T`` steps. The
    ``B * T`` states are rendered in a single ``render_clean_boat_tensor`` call.

    Args:
        states: tensor with shape ``(B, T, D)``; ``D >= 3`` (``x, y, theta, ...``).
        boat_ids: one boat id per trajectory, ``B`` elements (e.g. shape ``(B,)``).
        image_size, workspace, visual_scale: forwarded to ``render_clean_boat_tensor``.
        pad: frame margin in pixels, forwarded to ``render_clean_boat_tensor``.
            It is appended last so existing positional calls keep working.

    Returns:
        ``uint8`` tensor with shape ``(B, T, 3, image_size, image_size)``.

    Raises:
        ValueError: if ``states`` is not 3-D or ``boat_ids`` does not hold exactly
            ``B`` ids. Per-state validation errors come from ``render_clean_boat_tensor``.
    """
    if states.ndim != 3:
        raise ValueError("states must have shape (B, T, 6)")
    b, t, d = states.shape
    # Checked up front: a mismatched count would otherwise surface as an opaque
    # RuntimeError from view/reshape.
    if boat_ids.numel() != b:
        raise ValueError("boat_ids must contain exactly one id per trajectory (shape (B,))")
    # Repeat each trajectory's id across its T steps to match the flattened states.
    expanded_boats = boat_ids.reshape(b, 1).expand(b, t).reshape(b * t)
    rendered = render_clean_boat_tensor(
        states.reshape(b * t, d),
        expanded_boats,
        image_size=image_size,
        workspace=workspace,
        pad=pad,
        visual_scale=visual_scale,
    )
    return rendered.reshape(b, t, *rendered.shape[1:])