| """Clean image observations for the boat benchmark.""" |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| from PIL import Image, ImageDraw |
|
|
| from driftwm.sim.boat import BoatSpec, get_boat_spec |
| from driftwm.sim.dynamics import rot_body_to_world |
|
|
| _GRID_CACHE: dict[tuple[str, str, int, tuple[float, float, float, float], int], tuple[torch.Tensor, torch.Tensor]] = {} |
| _HULL_CACHE: dict[tuple[str, str, torch.dtype, float], torch.Tensor] = {} |
|
|
|
|
def world_to_pixel(point: np.ndarray, workspace: tuple[float, float, float, float], image_size: int, pad: int) -> tuple[int, int]:
    """Map a world-space ``(x, y)`` point to integer ``(column, row)`` pixel coordinates.

    ``workspace`` is ``(xmin, xmax, ymin, ymax)``. It is mapped linearly onto the
    pixel span ``[pad, image_size - pad]`` on both axes. The y axis is flipped, so
    larger world y lands on smaller row indices. Coordinates are rounded with
    Python's built-in ``round`` (round-half-to-even).
    """
    x_lo, x_hi, y_lo, y_hi = workspace
    usable = image_size - 2 * pad
    u = (float(point[0]) - x_lo) / (x_hi - x_lo)
    v = (float(point[1]) - y_lo) / (y_hi - y_lo)
    column = int(round(pad + u * usable))
    row = int(round((image_size - pad) - v * usable))
    return column, row
|
|
|
|
def render_clean_boat_image(
    state: np.ndarray,
    boat: str | BoatSpec,
    image_size: int = 64,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    pad: int = 4,
    visual_scale: float = 2.5,
) -> Image.Image:
    """Draw a single boat state as an ``image_size`` x ``image_size`` RGB picture.

    Painting order:
    1. a light background;
    2. a 1-px frame ``pad`` pixels in from every edge;
    3. the hull polygon;
    4. a round yellow bow marker.

    Args:
        state: array-like whose first three entries are ``x, y, theta``.
        boat: boat name resolved via ``get_boat_spec``, or a ``BoatSpec`` directly.
        image_size: width and height of the canvas in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world extent mapped onto the frame.
        pad: pixel gap between the canvas edge and the frame.
        visual_scale: multiplier applied to the hull vertices and to the bow-marker offset.

    Returns:
        The rendered RGB ``PIL.Image.Image``.
    """
    spec = boat if not isinstance(boat, str) else get_boat_spec(boat)
    canvas = Image.new("RGB", (image_size, image_size), (246, 249, 251))
    pen = ImageDraw.Draw(canvas, "RGBA")
    far_edge = image_size - pad
    pen.rectangle([pad, pad, far_edge, far_edge], outline=(70, 82, 94, 255), width=1)

    scale = float(visual_scale)
    origin = np.asarray(state[:2], dtype=np.float32)
    # Row-vector points: world = body @ R^T + position.
    body_to_world_t = rot_body_to_world(float(state[2])).T

    hull_world = ((spec.hull_vertices * scale) @ body_to_world_t) + origin
    hull_pixels = [world_to_pixel(corner, workspace, image_size, pad) for corner in hull_world]
    pen.polygon(hull_pixels, fill=(35, 91, 140, 255), outline=(18, 45, 76, 255))

    bow_body = np.array([0.22, 0.0], dtype=np.float32) * scale
    bow_world = (bow_body @ body_to_world_t) + origin
    cx, cy = world_to_pixel(bow_world, workspace, image_size, pad)
    r = max(2, image_size // 40)
    pen.ellipse(
        [cx - r, cy - r, cx + r, cy + r],
        fill=(245, 204, 80, 255),
        outline=(94, 65, 12, 255),
    )
    return canvas
|
|
|
|
def render_clean_boat_array(
    state: np.ndarray,
    boat: str | BoatSpec,
    image_size: int = 64,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    visual_scale: float = 2.5,
    pad: int = 4,
) -> np.ndarray:
    """Render one boat state to a ``uint8`` numpy array.

    Thin wrapper around ``render_clean_boat_image``. The result is the RGB image
    converted with ``np.asarray``, i.e. shape ``(image_size, image_size, 3)``.

    Args:
        state: array-like whose first three entries are ``x, y, theta``.
        boat: boat name or ``BoatSpec``.
        image_size: width and height in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world extent mapped onto the frame.
        visual_scale: multiplier applied to the hull and bow-marker geometry.
        pad: pixel gap between the canvas edge and the frame. It is forwarded to
            the image renderer. It is appended last so existing positional
            callers are unaffected.

    Returns:
        ``uint8`` array of the rendered image.
    """
    image = render_clean_boat_image(
        state,
        boat,
        image_size=image_size,
        workspace=workspace,
        pad=pad,
        visual_scale=visual_scale,
    )
    return np.asarray(image, dtype=np.uint8)
|
|
|
|
def _polygon_mask(body_x: torch.Tensor, body_y: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of the points that lie inside a polygon (even-odd rule).

    A ray is cast from each point towards +x. The inside flag toggles every time
    the ray crosses a polygon edge.

    Args:
        body_x: x coordinates of the query points (any shape).
        body_y: y coordinates of the query points; expected to match ``body_x``.
        vertices: ``(V, 2)`` polygon vertices in order. The last vertex connects
            back to the first. An empty polygon yields an all-False mask.

    Returns:
        Bool tensor shaped like ``body_x``.
    """
    inside = torch.zeros_like(body_x, dtype=torch.bool)
    # Pair each vertex with its successor (wrapping around); also fine for V == 0.
    successors = torch.cat([vertices[1:], vertices[:1]], dim=0)
    for (xi, yi), (xj, yj) in zip(vertices, successors):
        crosses = (yi > body_y) != (yj > body_y)
        edge_dy = yj - yi
        # Horizontal edges never satisfy ``crosses``. Give them a dummy non-zero
        # divisor so their unused quotient stays finite, and leave every other edge's
        # intersection exact (an additive epsilon would bias it, and would even hit
        # zero for an edge with dy == -epsilon).
        safe_dy = torch.where(edge_dy == 0, torch.ones_like(edge_dy), edge_dy)
        x_at_y = (xj - xi) * (body_y - yi) / safe_dy + xi
        inside = torch.logical_xor(inside, crosses & (body_x < x_at_y))
    return inside
|
|
|
|
def _cached_frame_grid(
    device: torch.device,
    dtype: torch.dtype,
    size: int,
    workspace: tuple[float, float, float, float],
    pad: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return world ``(x, y)`` coordinates of every on-canvas frame pixel, memoised in ``_GRID_CACHE``.

    The frame covers pixel indices ``pad .. size - pad`` inclusive. This is the
    span ``world_to_pixel`` maps the workspace onto. Indices past the canvas
    edge, which only occur when ``pad == 0``, are dropped. Both tensors have
    shape ``(F, F)`` with ``F = min(size - pad + 1, size) - pad``.
    """
    key = (str(device), str(dtype), size, tuple(float(v) for v in workspace), pad)
    grid = _GRID_CACHE.get(key)
    if grid is None:
        xmin, xmax, ymin, ymax = workspace
        span = size - 2 * pad
        visible = min(size - pad + 1, size) - pad
        # Columns run xmin -> xmax left to right; rows run ymax -> ymin top to bottom.
        xs = torch.linspace(xmin, xmax, span + 1, device=device, dtype=dtype)[:visible]
        ys = torch.linspace(ymax, ymin, span + 1, device=device, dtype=dtype)[:visible]
        grid = (
            xs.view(1, -1).expand(visible, visible).contiguous(),
            ys.view(-1, 1).expand(visible, visible).contiguous(),
        )
        _GRID_CACHE[key] = grid
    return grid


def _cached_hull_vertices(name: str, device: torch.device, dtype: torch.dtype, visual_scale: float) -> torch.Tensor:
    """Return ``get_boat_spec(name).hull_vertices * visual_scale`` as a tensor, memoised in ``_HULL_CACHE``."""
    key = (name, str(device), dtype, float(visual_scale))
    vertices = _HULL_CACHE.get(key)
    if vertices is None:
        vertices = torch.as_tensor(get_boat_spec(name).hull_vertices, dtype=dtype, device=device) * float(visual_scale)
        _HULL_CACHE[key] = vertices
    return vertices


def render_clean_boat_tensor(
    states: torch.Tensor,
    boat_ids: torch.Tensor,
    image_size: int = 160,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    pad: int = 4,
    visual_scale: float = 2.5,
) -> torch.Tensor:
    """Render a batch of clean boat observations on the tensor device.

    Per-pixel counterpart of ``render_clean_boat_image``. It uses the same
    colours, the same 1-px frame border, the hull polygon and a bow marker of
    radius ``max(2, image_size // 40)`` pixels. Hull and marker are evaluated
    only inside the frame, so the padding always stays background.

    Args:
        states: tensor with shape ``(N, >=3)`` whose first columns are ``x, y, theta``.
        boat_ids: tensor with shape ``(N,)`` where 0 is twin and 1 is triangle.
            Rows with any other id get only background and border.
        image_size: output height and width in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world extent mapped onto the frame.
        pad: pixel gap between the canvas edge and the frame;
            must satisfy ``0 <= 2 * pad < image_size``.
        visual_scale: multiplier applied to the hull vertices and the bow-marker offset.

    Returns:
        ``uint8`` tensor with shape ``(N, 3, H, W)`` on ``states.device``.

    Raises:
        ValueError: if ``states``/``boat_ids`` have the wrong shape or ``pad`` is out of range.
    """
    if states.ndim != 2 or states.shape[-1] < 3:
        raise ValueError("states must have shape (N, >=3)")
    if boat_ids.ndim != 1 or boat_ids.shape[0] != states.shape[0]:
        raise ValueError("boat_ids must have shape (N,)")
    size = int(image_size)
    pad = int(pad)
    if pad < 0 or 2 * pad >= size:
        raise ValueError("pad must satisfy 0 <= 2 * pad < image_size")
    device = states.device
    dtype = states.dtype
    n = int(states.shape[0])
    # Group indices must live on the same device as the tensors they index.
    boat_ids = boat_ids.to(device)

    # One past the last on-canvas frame pixel. With pad == 0 the far frame edge
    # sits at index `size`, off the canvas, which is also how PIL clips it.
    stop = min(size - pad + 1, size)
    frame = stop - pad

    image = torch.empty((n, 3, size, size), dtype=torch.uint8, device=device)
    background = torch.tensor([246, 249, 251], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    image.copy_(background.expand_as(image))

    # Border matches the inclusive PIL rectangle [pad, pad, size - pad, size - pad].
    border = torch.tensor([70, 82, 94], dtype=torch.uint8, device=device).view(1, 3, 1)
    image[:, :, pad, pad:stop] = border
    image[:, :, pad:stop, pad] = border
    if pad > 0:
        image[:, :, size - pad, pad:stop] = border
        image[:, :, pad:stop, size - pad] = border

    frame_x, frame_y = _cached_frame_grid(device, dtype, size, workspace, pad)

    # Offsets of every frame pixel from each boat, rotated into the boat body frame.
    x = states[:, 0].view(n, 1, 1)
    y = states[:, 1].view(n, 1, 1)
    theta = states[:, 2].view(n, 1, 1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    dx = frame_x.unsqueeze(0) - x
    dy = frame_y.unsqueeze(0) - y
    body_x = cos_t * dx + sin_t * dy
    body_y = -sin_t * dx + cos_t * dy

    # Bow marker is centred at body point (0.22 * visual_scale, 0). Distances are
    # measured in pixel units so the disc stays round for non-square workspaces.
    xmin, xmax, ymin, ymax = workspace
    span = size - 2 * pad
    radius_px = float(max(2, size // 40))
    offset = 0.22 * float(visual_scale)
    marker_dx = (dx - cos_t * offset) * (span / float(xmax - xmin))
    marker_dy = (dy - sin_t * offset) * (span / float(ymax - ymin))
    marker = marker_dx.square() + marker_dy.square() <= radius_px * radius_px

    hull_color = torch.tensor([35, 91, 140], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    marker_color = torch.tensor([245, 204, 80], dtype=torch.uint8, device=device).view(1, 3, 1, 1)
    for boat_id, boat_name in ((0, "twin"), (1, "triangle")):
        index = torch.nonzero(boat_ids == boat_id, as_tuple=False).flatten()
        if index.numel() == 0:
            continue
        count = index.numel()
        vertices = _cached_hull_vertices(boat_name, device, dtype, visual_scale)
        hull = _polygon_mask(body_x[index], body_y[index], vertices)
        patch = image[index, :, pad:stop, pad:stop]
        patch = torch.where(hull[:, None], hull_color.expand(count, -1, frame, frame), patch)
        patch = torch.where(marker[index, None], marker_color.expand(count, -1, frame, frame), patch)
        image[index, :, pad:stop, pad:stop] = patch
    return image
|
|
|
|
def render_clean_boat_history_tensor(
    states: torch.Tensor,
    boat_ids: torch.Tensor,
    image_size: int = 160,
    workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0),
    visual_scale: float = 2.5,
    pad: int = 4,
) -> torch.Tensor:
    """Render state histories with shape ``(B, T, >=3)`` to ``(B, T, 3, H, W)``.

    All ``T`` states in a row share that row's boat id. The ``B * T`` frames are
    rendered as one flat batch by ``render_clean_boat_tensor``.

    Args:
        states: tensor with shape ``(B, T, D)``. The first three state columns are
            ``x, y, theta``.
        boat_ids: tensor with shape ``(B,)``, one boat id per history.
        image_size: output height and width in pixels.
        workspace: ``(xmin, xmax, ymin, ymax)`` world extent mapped onto the frame.
        visual_scale: multiplier applied to the hull and bow-marker geometry.
        pad: pixel gap between the canvas edge and the frame. It is appended last
            so existing positional callers are unaffected.

    Returns:
        ``uint8`` tensor with shape ``(B, T, 3, image_size, image_size)``.

    Raises:
        ValueError: if ``states`` is not 3-D or ``boat_ids`` is not ``(B,)``.
    """
    if states.ndim != 3:
        raise ValueError("states must have shape (B, T, 6)")
    b, t, d = states.shape
    if boat_ids.ndim != 1 or boat_ids.shape[0] != b:
        raise ValueError("boat_ids must have shape (B,)")
    # reshape (not view) so non-contiguous id tensors are accepted.
    expanded_boats = boat_ids.reshape(b, 1).expand(b, t).reshape(b * t)
    rendered = render_clean_boat_tensor(
        states.reshape(b * t, d),
        expanded_boats,
        image_size=image_size,
        workspace=workspace,
        pad=pad,
        visual_scale=visual_scale,
    )
    size = int(image_size)
    return rendered.reshape(b, t, 3, size, size)
|
|