File size: 7,977 Bytes
604e535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | """Clean image observations for the boat benchmark."""
from __future__ import annotations
import numpy as np
import torch
from PIL import Image, ImageDraw
from driftwm.sim.boat import BoatSpec, get_boat_spec
from driftwm.sim.dynamics import rot_body_to_world
_GRID_CACHE: dict[tuple[str, str, int, tuple[float, float, float, float], int], tuple[torch.Tensor, torch.Tensor]] = {}
_HULL_CACHE: dict[tuple[str, str, torch.dtype, float], torch.Tensor] = {}
def world_to_pixel(point: np.ndarray, workspace: tuple[float, float, float, float], image_size: int, pad: int) -> tuple[int, int]:
    """Map a world-frame ``(x, y)`` point to integer ``(column, row)`` pixel coordinates.

    The workspace rectangle ``(xmin, xmax, ymin, ymax)`` is stretched over pixels
    ``pad .. image_size - pad`` on both axes. The vertical axis is flipped, so
    ``ymax`` lands on row ``pad`` and ``ymin`` on row ``image_size - pad``.
    Results are rounded with Python's ``round`` and are not clipped to the canvas.
    """
    x_lo, x_hi, y_lo, y_hi = workspace
    span = image_size - 2 * pad  # usable pixels between the two margins
    u = (float(point[0]) - x_lo) / (x_hi - x_lo)
    v = (float(point[1]) - y_lo) / (y_hi - y_lo)
    column = round(pad + u * span)
    row = round(image_size - pad - v * span)
    return int(column), int(row)
def render_clean_boat_image(
    state: np.ndarray,
    boat: str | BoatSpec,
    image_size: int = 64,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    pad: int = 4,
    visual_scale: float = 2.5,
) -> Image.Image:
    """Draw a single boat state as a PIL RGB image.

    Args:
        state: state vector; ``state[0:2]`` is the world position and ``state[2]``
            the heading angle handed to ``rot_body_to_world``.
        boat: boat name (resolved with ``get_boat_spec``) or a ``BoatSpec``.
        image_size: side length of the square output image in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world rectangle mapped onto the canvas.
        pad: margin in pixels between the image edge and the workspace frame.
        visual_scale: factor applied to the hull vertices and the bow-marker offset.

    Returns:
        An ``image_size`` x ``image_size`` RGB image containing a light background,
        a one-pixel workspace frame, the filled hull polygon and a bow-marker disc.
    """
    spec = boat if not isinstance(boat, str) else get_boat_spec(boat)
    scale = float(visual_scale)

    canvas = Image.new("RGB", (image_size, image_size), (246, 249, 251))
    pen = ImageDraw.Draw(canvas, "RGBA")
    pen.rectangle(
        [pad, pad, image_size - pad, image_size - pad],
        outline=(70, 82, 94, 255),
        width=1,
    )

    origin = np.asarray(state[:2], dtype=np.float32)
    to_world_t = rot_body_to_world(float(state[2])).T

    # Hull: scale in the body frame, rotate into the world frame, then translate.
    hull_world = (spec.hull_vertices * scale) @ to_world_t + origin
    outline_px = [world_to_pixel(vertex, workspace, image_size, pad) for vertex in hull_world]
    pen.polygon(outline_px, fill=(35, 91, 140, 255), outline=(18, 45, 76, 255))

    # Bow marker: a disc centred at body-frame (0.22, 0) scaled by visual_scale.
    bow_body = np.array([0.22, 0.0], dtype=np.float32) * scale
    cx, cy = world_to_pixel(bow_body @ to_world_t + origin, workspace, image_size, pad)
    r = max(2, image_size // 40)
    pen.ellipse(
        [cx - r, cy - r, cx + r, cy + r],
        fill=(245, 204, 80, 255),
        outline=(94, 65, 12, 255),
    )
    return canvas
def render_clean_boat_array(
    state: np.ndarray,
    boat: str | BoatSpec,
    image_size: int = 64,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    visual_scale: float = 2.5,
    pad: int = 4,
) -> np.ndarray:
    """Render one boat state with ``render_clean_boat_image`` and return it as a numpy array.

    Args:
        state: state vector forwarded to ``render_clean_boat_image``.
        boat: boat name or ``BoatSpec``.
        image_size: side length of the square image in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world rectangle mapped onto the canvas.
        visual_scale: factor applied to the hull and bow-marker geometry.
        pad: margin in pixels around the workspace frame (appended last so
            existing positional callers are unaffected).

    Returns:
        ``uint8`` array of the RGB image (``np.asarray`` of a PIL RGB image, i.e.
        ``(image_size, image_size, 3)``).
    """
    image = render_clean_boat_image(
        state,
        boat,
        image_size=image_size,
        workspace=workspace,
        pad=pad,
        visual_scale=visual_scale,
    )
    return np.asarray(image, dtype=np.uint8)
def _polygon_mask(body_x: torch.Tensor, body_y: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor:
inside = torch.zeros_like(body_x, dtype=torch.bool)
count = int(vertices.shape[0])
for i in range(count):
j = (i + 1) % count
xi, yi = vertices[i, 0], vertices[i, 1]
xj, yj = vertices[j, 0], vertices[j, 1]
crosses = (yi > body_y) != (yj > body_y)
x_at_y = (xj - xi) * (body_y - yi) / (yj - yi + 1.0e-6) + xi
inside = torch.logical_xor(inside, crosses & (body_x < x_at_y))
return inside
def render_clean_boat_tensor(
    states: torch.Tensor,
    boat_ids: torch.Tensor,
    image_size: int = 160,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    pad: int = 4,
    visual_scale: float = 2.5,
) -> torch.Tensor:
    """Render a batch of clean boat observations on the tensor device.

    Every pixel of the framed area is assigned a world coordinate, rotated into
    each boat's body frame, and tested against the scaled hull polygon and the
    circular bow marker. Pixels in the margin outside the frame always keep the
    background colour.

    Args:
        states: tensor with shape ``(N, 6)`` containing ``x, y, theta, ...``.
            Only the first three columns are read.
        boat_ids: tensor with shape ``(N,)`` where 0 is twin and 1 is triangle.
            Any other id produces an image with only the background and frame.
        image_size: side length ``H == W`` of the square output in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world rectangle mapped onto
            pixels ``pad .. image_size - pad`` (inclusive) on both axes.
        pad: margin in pixels; must satisfy ``1 <= pad`` and
            ``2 * pad < image_size`` so the frame fits inside the canvas.
        visual_scale: factor applied to hull vertices and the bow-marker offset.

    Returns:
        ``uint8`` tensor with shape ``(N, 3, H, W)`` on ``states.device``.

    Raises:
        ValueError: if ``states`` or ``boat_ids`` have the wrong shape, or if
            ``pad`` does not leave a valid frame inside the image.
    """
    if states.ndim != 2 or states.shape[-1] < 3:
        raise ValueError("states must have shape (N, >=3)")
    if boat_ids.ndim != 1 or boat_ids.shape[0] != states.shape[0]:
        raise ValueError("boat_ids must have shape (N,)")
    h = int(image_size)
    w = int(image_size)
    # The frame is drawn on rows/columns `pad` and `size - pad`. pad == 0 would
    # index one past the canvas, and 2 * pad >= size leaves no interior (and a
    # zero pixel pitch in `radius_world` below).
    if pad < 1 or h - 2 * pad < 1:
        raise ValueError("pad must satisfy 1 <= pad and 2 * pad < image_size")
    device = states.device
    n = int(states.shape[0])
    dtype = states.dtype

    # Background fill plus a one-pixel frame around the workspace.
    image = torch.empty((n, 3, h, w), dtype=torch.uint8, device=device)
    background = torch.tensor([246, 249, 251], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    image.copy_(background.expand_as(image))
    border = torch.tensor([70, 82, 94], dtype=torch.uint8, device=device).view(1, 3, 1)
    image[:, :, pad, pad : w - pad + 1] = border
    image[:, :, h - pad, pad : w - pad + 1] = border
    image[:, :, pad : h - pad + 1, pad] = border
    image[:, :, pad : h - pad + 1, w - pad] = border

    xmin, xmax, ymin, ymax = workspace
    grid_key = (str(device), str(dtype), h, tuple(float(v) for v in workspace), int(pad))
    cached_grid = _GRID_CACHE.get(grid_key)
    if cached_grid is None:
        # Framed pixels pad..size-pad map linearly onto the workspace. Rows run
        # from ymax (top) down to ymin, so +y points up in the image.
        xs = torch.linspace(xmin, xmax, w - 2 * pad + 1, device=device, dtype=dtype)
        ys = torch.linspace(ymax, ymin, h - 2 * pad + 1, device=device, dtype=dtype)
        full_x = torch.empty((h, w), device=device, dtype=dtype)
        full_y = torch.empty((h, w), device=device, dtype=dtype)
        # Margin pixels all share one placeholder point outside the workspace.
        full_x[:] = xmin - 1.0
        full_y[:] = ymin - 1.0
        full_x[pad : h - pad + 1, pad : w - pad + 1] = xs.view(1, -1)
        full_y[pad : h - pad + 1, pad : w - pad + 1] = ys.view(-1, 1)
        _GRID_CACHE[grid_key] = (full_x, full_y)
    else:
        full_x, full_y = cached_grid

    # Every margin pixel shares the placeholder point, so a boat covering that
    # point would otherwise paint the whole margin. Only draw inside the frame.
    interior = torch.zeros((h, w), dtype=torch.bool, device=device)
    interior[pad : h - pad + 1, pad : w - pad + 1] = True

    # Pixel offsets from each boat's position, rotated by -theta into its body frame.
    x = states[:, 0].view(n, 1, 1)
    y = states[:, 1].view(n, 1, 1)
    theta = states[:, 2].view(n, 1, 1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    dx = full_x.view(1, h, w) - x
    dy = full_y.view(1, h, w) - y
    body_x = cos_t * dx + sin_t * dy
    body_y = -sin_t * dx + cos_t * dy

    hull_color = torch.tensor([35, 91, 140], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    marker_color = torch.tensor([245, 204, 80], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    # Marker radius of max(2, H // 40) pixels, converted to world units via the x pixel pitch.
    radius_world = float(max(2, h // 40)) * float(xmax - xmin) / float(h - 2 * pad)
    marker_x = 0.22 * float(visual_scale)
    marker = ((body_x - marker_x).square() + body_y.square() <= radius_world * radius_world) & interior

    for boat_id, name in enumerate(("twin", "triangle")):
        hull_key = (name, str(device), dtype, float(visual_scale))
        vertices = _HULL_CACHE.get(hull_key)
        if vertices is None:
            vertices = torch.as_tensor(get_boat_spec(name).hull_vertices, dtype=dtype, device=device) * float(visual_scale)
            _HULL_CACHE[hull_key] = vertices
        index = torch.nonzero(boat_ids == boat_id, as_tuple=False).flatten()
        if index.numel() == 0:
            continue
        hull = _polygon_mask(body_x[index], body_y[index], vertices) & interior
        image[index] = torch.where(hull[:, None], hull_color.expand(index.numel(), -1, h, w), image[index])
        image[index] = torch.where(marker[index, None], marker_color.expand(index.numel(), -1, h, w), image[index])
    return image
def render_clean_boat_history_tensor(
    states: torch.Tensor,
    boat_ids: torch.Tensor,
    image_size: int = 160,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    visual_scale: float = 2.5,
    pad: int = 4,
) -> torch.Tensor:
    """Render state histories with shape ``(B, T, 6)`` to ``(B, T, 3, H, W)``.

    Every time step of a sequence is drawn with that sequence's boat id.

    Args:
        states: ``(B, T, D)`` state histories; each step is rendered by
            ``render_clean_boat_tensor``.
        boat_ids: one boat id per sequence; reshaped with ``view(B, 1)``, so it
            must hold exactly ``B`` elements.
        image_size: side length of the square frames in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world rectangle mapped onto each frame.
        visual_scale: factor applied to the hull and bow-marker geometry.
        pad: margin in pixels around the workspace frame (appended last so
            existing positional callers are unaffected).

    Returns:
        ``uint8`` tensor with shape ``(B, T, 3, image_size, image_size)``.

    Raises:
        ValueError: if ``states`` is not 3-dimensional.
    """
    if states.ndim != 3:
        raise ValueError("states must have shape (B, T, 6)")
    b, t, d = states.shape
    # One boat id per sequence, repeated for each of its T frames.
    per_frame_ids = boat_ids.view(b, 1).expand(b, t).reshape(b * t)
    frames = render_clean_boat_tensor(
        states.reshape(b * t, d),
        per_frame_ids,
        image_size=image_size,
        workspace=workspace,
        pad=pad,
        visual_scale=visual_scale,
    )
    return frames.reshape(b, t, 3, image_size, image_size)
|