"""Clean image observations for the boat benchmark.""" from __future__ import annotations import numpy as np import torch from PIL import Image, ImageDraw from driftwm.sim.boat import BoatSpec, get_boat_spec from driftwm.sim.dynamics import rot_body_to_world _GRID_CACHE: dict[tuple[str, str, int, tuple[float, float, float, float], int], tuple[torch.Tensor, torch.Tensor]] = {} _HULL_CACHE: dict[tuple[str, str, torch.dtype, float], torch.Tensor] = {} def world_to_pixel(point: np.ndarray, workspace: tuple[float, float, float, float], image_size: int, pad: int) -> tuple[int, int]: xmin, xmax, ymin, ymax = workspace x = (float(point[0]) - xmin) / (xmax - xmin) y = (float(point[1]) - ymin) / (ymax - ymin) px = int(round(pad + x * (image_size - 2 * pad))) py = int(round(image_size - pad - y * (image_size - 2 * pad))) return px, py def render_clean_boat_image( state: np.ndarray, boat: str | BoatSpec, image_size: int = 64, workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0), pad: int = 4, visual_scale: float = 2.5, ) -> Image.Image: spec = get_boat_spec(boat) if isinstance(boat, str) else boat img = Image.new("RGB", (image_size, image_size), (246, 249, 251)) draw = ImageDraw.Draw(img, "RGBA") draw.rectangle([pad, pad, image_size - pad, image_size - pad], outline=(70, 82, 94, 255), width=1) pos = np.asarray(state[:2], dtype=np.float32) rot = rot_body_to_world(float(state[2])) hull = ((spec.hull_vertices * float(visual_scale)) @ rot.T) + pos pts = [world_to_pixel(p, workspace, image_size, pad) for p in hull] draw.polygon(pts, fill=(35, 91, 140, 255), outline=(18, 45, 76, 255)) bow_marker = ((np.array([0.22, 0.0], dtype=np.float32) * float(visual_scale)) @ rot.T) + pos mx, my = world_to_pixel(bow_marker, workspace, image_size, pad) radius = max(2, image_size // 40) draw.ellipse( [mx - radius, my - radius, mx + radius, my + radius], fill=(245, 204, 80, 255), outline=(94, 65, 12, 255), ) return img def render_clean_boat_array( state: np.ndarray, boat: str | BoatSpec, image_size: int = 64, workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0), visual_scale: float = 2.5, ) -> np.ndarray: return np.asarray( render_clean_boat_image(state, boat, image_size=image_size, workspace=workspace, visual_scale=visual_scale), dtype=np.uint8, ) def _polygon_mask(body_x: torch.Tensor, body_y: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor: inside = torch.zeros_like(body_x, dtype=torch.bool) count = int(vertices.shape[0]) for i in range(count): j = (i + 1) % count xi, yi = vertices[i, 0], vertices[i, 1] xj, yj = vertices[j, 0], vertices[j, 1] crosses = (yi > body_y) != (yj > body_y) x_at_y = (xj - xi) * (body_y - yi) / (yj - yi + 1.0e-6) + xi inside = torch.logical_xor(inside, crosses & (body_x < x_at_y)) return inside def render_clean_boat_tensor( states: torch.Tensor, boat_ids: torch.Tensor, image_size: int = 160, workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0), pad: int = 4, visual_scale: float = 2.5, ) -> torch.Tensor: """Render a batch of clean boat observations on the tensor device. Args: states: tensor with shape ``(N, 6)`` containing ``x, y, theta, ...``. boat_ids: tensor with shape ``(N,)`` where 0 is twin and 1 is triangle. Returns: ``uint8`` tensor with shape ``(N, 3, H, W)``. """ if states.ndim != 2 or states.shape[-1] < 3: raise ValueError("states must have shape (N, >=3)") if boat_ids.ndim != 1 or boat_ids.shape[0] != states.shape[0]: raise ValueError("boat_ids must have shape (N,)") device = states.device n = int(states.shape[0]) h = int(image_size) w = int(image_size) dtype = states.dtype image = torch.empty((n, 3, h, w), dtype=torch.uint8, device=device) background = torch.tensor([246, 249, 251], dtype=torch.uint8, device=device).view(1, 3, 1, 1) image.copy_(background.expand_as(image)) border = torch.tensor([70, 82, 94], dtype=torch.uint8, device=device).view(1, 3, 1) image[:, :, pad, pad : w - pad + 1] = border image[:, :, h - pad, pad : w - pad + 1] = border image[:, :, pad : h - pad + 1, pad] = border image[:, :, pad : h - pad + 1, w - pad] = border xmin, xmax, ymin, ymax = workspace grid_key = (str(device), str(dtype), h, tuple(float(v) for v in workspace), int(pad)) cached_grid = _GRID_CACHE.get(grid_key) if cached_grid is None: xs = torch.linspace(xmin, xmax, w - 2 * pad + 1, device=device, dtype=dtype) ys = torch.linspace(ymax, ymin, h - 2 * pad + 1, device=device, dtype=dtype) full_x = torch.empty((h, w), device=device, dtype=dtype) full_y = torch.empty((h, w), device=device, dtype=dtype) full_x[:] = xmin - 1.0 full_y[:] = ymin - 1.0 full_x[pad : h - pad + 1, pad : w - pad + 1] = xs.view(1, -1) full_y[pad : h - pad + 1, pad : w - pad + 1] = ys.view(-1, 1) _GRID_CACHE[grid_key] = (full_x, full_y) else: full_x, full_y = cached_grid x = states[:, 0].view(n, 1, 1) y = states[:, 1].view(n, 1, 1) theta = states[:, 2].view(n, 1, 1) cos_t = torch.cos(theta) sin_t = torch.sin(theta) dx = full_x.view(1, h, w) - x dy = full_y.view(1, h, w) - y body_x = cos_t * dx + sin_t * dy body_y = -sin_t * dx + cos_t * dy hull_color = torch.tensor([35, 91, 140], dtype=torch.uint8, device=device).view(1, 3, 1, 1) marker_color = torch.tensor([245, 204, 80], dtype=torch.uint8, device=device).view(1, 3, 1, 1) radius_world = float(max(2, h // 40)) * float(xmax - xmin) / float(h - 2 * pad) marker_x = 0.22 * float(visual_scale) marker = (body_x - marker_x).square() + body_y.square() <= radius_world * radius_world hull_key_twin = ("twin", str(device), dtype, float(visual_scale)) hull_key_triangle = ("triangle", str(device), dtype, float(visual_scale)) twin_vertices = _HULL_CACHE.get(hull_key_twin) if twin_vertices is None: twin_vertices = torch.as_tensor(get_boat_spec("twin").hull_vertices, dtype=dtype, device=device) * float(visual_scale) _HULL_CACHE[hull_key_twin] = twin_vertices triangle_vertices = _HULL_CACHE.get(hull_key_triangle) if triangle_vertices is None: triangle_vertices = torch.as_tensor(get_boat_spec("triangle").hull_vertices, dtype=dtype, device=device) * float(visual_scale) _HULL_CACHE[hull_key_triangle] = triangle_vertices for boat_id, vertices in [(0, twin_vertices), (1, triangle_vertices)]: index = torch.nonzero(boat_ids == boat_id, as_tuple=False).flatten() if index.numel() == 0: continue mask = _polygon_mask(body_x[index], body_y[index], vertices) image[index] = torch.where(mask[:, None], hull_color.expand(index.numel(), -1, h, w), image[index]) image[index] = torch.where(marker[index, None], marker_color.expand(index.numel(), -1, h, w), image[index]) return image def render_clean_boat_history_tensor( states: torch.Tensor, boat_ids: torch.Tensor, image_size: int = 160, workspace: tuple[float, float, float, float] = (0.0, 10.0, 0.0, 10.0), visual_scale: float = 2.5, ) -> torch.Tensor: """Render state histories with shape ``(B, T, 6)`` to ``(B, T, 3, H, W)``.""" if states.ndim != 3: raise ValueError("states must have shape (B, T, 6)") b, t, d = states.shape expanded_boats = boat_ids.view(b, 1).expand(b, t).reshape(b * t) rendered = render_clean_boat_tensor( states.reshape(b * t, d), expanded_boats, image_size=image_size, workspace=workspace, visual_scale=visual_scale, ) return rendered.reshape(b, t, 3, image_size, image_size)