# FlowMo-WM / experiments/plot_task_schematics.py
# cccat6's picture
# Rename traditional baseline artifacts as LOS controllers
# cc396fd verified
from __future__ import annotations
from pathlib import Path
from PIL import Image, ImageDraw
from driftwm.utils import ensure_dir
# Root output folder: save() writes individual figures to OUT_DIR / "clean",
# and make_contact_sheet() writes the combined sheet directly into OUT_DIR.
OUT_DIR = Path("experiments/reports/figures/task_schematics")
# World-coordinate extent, unpacked by w2p() as (xmin, xmax, ymin, ymax).
WORKSPACE = (0.0, 10.0, 0.0, 10.0)

# RGB colour palette shared by the drawing helpers below.
BG = (247, 250, 252)  # canvas background used by base()
BORDER = (45, 62, 80)  # outline of the frame drawn by base()
START = (35, 91, 140)  # filled marker at the start of a path
GOAL = (203, 62, 74)  # goal / waypoint markers
ACTIVE = (30, 139, 88)  # path arrows drawn by the task functions
PASSIVE = (145, 153, 163)  # not referenced by the code visible in this file
FLOW = (68, 135, 205)  # arrows drawn by draw_flow_arrows()
SWITCH = (238, 184, 72)  # not referenced by the code visible in this file
def w2p(x: float, y: float, size: int = 640, pad: int = 58) -> tuple[int, int]:
    """Convert a workspace coordinate into an integer pixel position.

    The WORKSPACE extent is stretched over the square that remains after
    removing ``pad`` pixels from every side of a ``size`` x ``size`` canvas.
    The vertical axis is flipped, so larger ``y`` values end up closer to
    the top of the image.

    Args:
        x: Horizontal workspace coordinate.
        y: Vertical workspace coordinate.
        size: Side length of the square canvas in pixels.
        pad: Margin in pixels between the canvas edge and the workspace.

    Returns:
        ``(px, py)`` pixel coordinates, truncated with ``int()``.
    """
    x_lo, x_hi, y_lo, y_hi = WORKSPACE
    drawable = size - 2 * pad
    frac_x = (x - x_lo) / (x_hi - x_lo)
    frac_y = (y - y_lo) / (y_hi - y_lo)
    col = int(pad + frac_x * drawable)
    # Image rows grow downwards, so measure the offset from the bottom margin.
    row = int(size - pad - frac_y * drawable)
    return col, row
def draw_arrow(
    draw: ImageDraw.ImageDraw,
    p0: tuple[int, int],
    p1: tuple[int, int],
    color: tuple[int, int, int],
    width: int = 6,
    head: int = 18,
) -> None:
    """Draw a straight arrow from ``p0`` to ``p1`` with a two-stroke head.

    Args:
        draw: Drawing context that receives the ``line`` calls.
        p0: Pixel position of the arrow tail.
        p1: Pixel position of the arrow tip.
        color: RGB colour used for the shaft and both head strokes.
        width: Stroke width in pixels for every segment.
        head: Length in pixels of each head stroke.
    """
    import math

    # Shaft first, then the two barbs that start at the tip.
    draw.line((p0, p1), fill=color, width=width)
    heading = math.atan2(p1[1] - p0[1], p1[0] - p0[0])
    # +/-2.55 rad (about 146 degrees) turns each barb back towards the tail,
    # roughly 34 degrees to either side of the shaft.
    for sweep in (2.55, -2.55):
        barb_angle = heading + sweep
        barb_end = (
            int(p1[0] + head * math.cos(barb_angle)),
            int(p1[1] + head * math.sin(barb_angle)),
        )
        draw.line((p1, barb_end), fill=color, width=width)
def draw_polyline_arrow(
    draw: ImageDraw.ImageDraw,
    pts: list[tuple[int, int]],
    color: tuple[int, int, int],
    width: int = 6,
) -> None:
    """Draw a connected path through ``pts`` that ends in an arrow head.

    Nothing is drawn when fewer than two points are supplied.

    Args:
        draw: Drawing context that receives the strokes.
        pts: Pixel positions visited by the path, in order.
        color: RGB colour of the path.
        width: Stroke width in pixels.
    """
    if len(pts) >= 2:
        draw.line(pts, fill=color, width=width, joint="curve")
        # The final segment is redrawn by draw_arrow, which adds the head.
        second_last, last = pts[-2], pts[-1]
        draw_arrow(draw, second_last, last, color, width=width)
def circle(draw: ImageDraw.ImageDraw, center: tuple[int, int], radius: int, color: tuple[int, int, int], width: int = 0) -> None:
    """Draw a circle that is filled by default, or a ring when ``width`` is set.

    Args:
        draw: Drawing context that receives the ``ellipse`` call.
        center: Pixel position of the circle centre.
        radius: Circle radius in pixels.
        color: RGB colour of the fill or of the outline.
        width: Outline thickness; a falsy value (the default 0) selects a
            solid filled disc instead of an outline.
    """
    cx, cy = center[0], center[1]
    bounds = (cx - radius, cy - radius, cx + radius, cy + radius)
    if not width:
        draw.ellipse(bounds, fill=color)
        return
    draw.ellipse(bounds, outline=color, width=width)
def base(size: int = 640, pad: int = 58) -> tuple[Image.Image, ImageDraw.ImageDraw]:
    """Create a blank square canvas with the workspace frame drawn on it.

    The frame inset is a parameter instead of a hard-coded 58 so that it can
    be kept in step with the ``pad`` argument of ``w2p``. With matching
    values the corners of WORKSPACE land exactly on the frame corners.

    Args:
        size: Side length of the square image in pixels.
        pad: Distance in pixels between the image edge and the frame. The
            default matches the default ``pad`` of ``w2p``.

    Returns:
        ``(img, draw)``: the new RGB image filled with BG and a drawing
        context bound to it.
    """
    img = Image.new("RGB", (size, size), BG)
    draw = ImageDraw.Draw(img)
    draw.rectangle((pad, pad, size - pad, size - pad), outline=BORDER, width=3)
    return img, draw
def draw_flow_arrows(draw: ImageDraw.ImageDraw, arrows: list[tuple[tuple[float, float], tuple[float, float]]]) -> None:
    """Draw one FLOW-coloured arrow per ``(start, end)`` pair.

    Args:
        draw: Drawing context that receives the arrows.
        arrows: Pairs of workspace coordinates; each pair is converted to
            pixels with ``w2p`` before drawing.
    """
    for tail_xy, tip_xy in arrows:
        tail_px = w2p(*tail_xy)
        tip_px = w2p(*tip_xy)
        # Slightly thinner and shorter-headed than the default arrow style.
        draw_arrow(draw, tail_px, tip_px, FLOW, width=5, head=15)
def save(img: Image.Image, name: str) -> None:
    """Write ``img`` to the ``clean`` sub-folder of OUT_DIR and print its path.

    Args:
        img: Image to save.
        name: File name, including extension, to use inside the folder.
    """
    clean_dir = OUT_DIR / "clean"
    ensure_dir(clean_dir)
    target = clean_dir / name
    img.save(target)
    print(target)
def reach_target() -> Image.Image:
    """Render the reach-target schematic.

    A single ACTIVE arrow runs from a filled START marker at workspace
    (2, 2) to a GOAL ring at workspace (8, 8).

    Returns:
        The rendered image.
    """
    img, draw = base()
    origin_px = w2p(2.0, 2.0)
    target_px = w2p(8.0, 8.0)
    # The arrow goes down first so that both markers are painted over it.
    draw_arrow(draw, origin_px, target_px, ACTIVE)
    circle(draw, origin_px, 13, START)
    circle(draw, target_px, 18, GOAL, width=6)
    return img
def station_keeping() -> Image.Image:
    """Render the station-keeping schematic.

    Draws a GOAL ring with a cross-hair at the workspace centre, three FLOW
    arrows pointing in towards that region, and a short ACTIVE path that
    starts at a START marker and ends next to the centre.

    Returns:
        The rendered image.
    """
    img, draw = base()

    hub = w2p(5.0, 5.0)
    hub_x, hub_y = hub[0], hub[1]
    arm = 22  # half-length of each cross-hair stroke, in pixels
    circle(draw, hub, 72, GOAL, width=6)
    draw.line((hub_x - arm, hub_y, hub_x + arm, hub_y), fill=GOAL, width=4)
    draw.line((hub_x, hub_y - arm, hub_x, hub_y + arm), fill=GOAL, width=4)

    inward_flows = [
        ((6.8, 6.8), (5.85, 5.95)),
        ((6.6, 3.3), (5.8, 4.2)),
        ((3.5, 6.8), (4.25, 5.9)),
    ]
    draw_flow_arrows(draw, inward_flows)

    track_world = (
        (4.55, 4.25),
        (5.2, 4.45),
        (5.45, 5.15),
        (4.85, 5.45),
        (4.65, 4.8),
        (5.0, 5.02),
    )
    track = [w2p(wx, wy) for wx, wy in track_world]
    draw_polyline_arrow(draw, track, ACTIVE, width=6)
    # The start marker is painted last so it sits on top of the path.
    circle(draw, track[0], 12, START)
    return img
def waypoint_square() -> Image.Image:
    """Render the square-waypoint schematic.

    An ACTIVE path leaves the START marker at workspace (2, 2) and visits
    four GOAL markers placed on the corners of a square spanning 2.5 to 7.5
    on both axes.

    Returns:
        The rendered image.
    """
    img, draw = base()
    origin_px = w2p(2.0, 2.0)
    corners_world = ((2.5, 2.5), (7.5, 2.5), (7.5, 7.5), (2.5, 7.5))
    corners_px = [w2p(wx, wy) for wx, wy in corners_world]
    draw_polyline_arrow(draw, [origin_px, *corners_px], ACTIVE, width=6)
    # Markers are painted after the path so they cover the line joints.
    circle(draw, origin_px, 13, START)
    for corner in corners_px:
        circle(draw, corner, 13, GOAL)
    return img
def waypoint_zigzag() -> Image.Image:
    """Render the zigzag-waypoint schematic.

    An ACTIVE path leaves the START marker at workspace (2, 2) and visits
    four GOAL markers that alternate between y = 7 and y = 3 while x
    increases.

    Returns:
        The rendered image.
    """
    img, draw = base()
    origin_px = w2p(2.0, 2.0)
    turns_world = ((2.5, 7.0), (4.2, 3.0), (5.8, 7.0), (7.5, 3.0))
    turns_px = [w2p(wx, wy) for wx, wy in turns_world]
    draw_polyline_arrow(draw, [origin_px, *turns_px], ACTIVE, width=6)
    # Markers are painted after the path so they cover the line joints.
    circle(draw, origin_px, 13, START)
    for turn in turns_px:
        circle(draw, turn, 13, GOAL)
    return img
def make_contact_sheet(names: list[str]) -> None:
    """Combine saved schematics into a labelled two-column contact sheet.

    Each image is read from ``OUT_DIR / "clean" / name``, scaled to a
    240 x 240 thumbnail and pasted into a grid cell with a text label
    underneath. The label is the file name without its ``task_`` prefix and
    ``.png`` suffix. The sheet is saved as
    ``OUT_DIR / "task_schematics_contact_sheet.png"`` and its path printed.

    The sheet height follows the number of images. It was previously fixed
    at two rows, so a fifth or later thumbnail was pasted outside the canvas
    and lost. At least two rows are always allocated, so lists of up to four
    names produce the same 600 x 640 sheet as before.

    Args:
        names: File names of images already written by ``save``.
    """
    columns = 2
    cell_w, cell_h = 280, 300
    margin = 20
    thumb_size = (240, 240)
    label_offset = 250  # vertical distance from thumbnail top to its label
    label_color = (35, 45, 58)

    # Ceiling division, with a floor of two rows for backward compatibility.
    rows = max(2, (len(names) + columns - 1) // columns)

    ensure_dir(OUT_DIR)
    sheet = Image.new(
        "RGB",
        (columns * cell_w + 2 * margin, rows * cell_h + 2 * margin),
        (246, 249, 251),
    )
    draw = ImageDraw.Draw(sheet)
    for i, name in enumerate(names):
        row, col = divmod(i, columns)
        x = margin + col * cell_w
        y = margin + row * cell_h
        # resize() returns a new image, so the source file can be closed
        # as soon as the thumbnail has been produced.
        with Image.open(OUT_DIR / "clean" / name) as source:
            thumb = source.resize(thumb_size, Image.Resampling.LANCZOS)
        sheet.paste(thumb, (x, y))
        label = name.removeprefix("task_").removesuffix(".png")
        draw.text((x, y + label_offset), label, fill=label_color)
    path = OUT_DIR / "task_schematics_contact_sheet.png"
    sheet.save(path)
    print(path)
def main() -> None:
    """Render every task schematic, save each one, then build the contact sheet.

    The contact sheet is built from the same mapping that drives the saves.
    The file names were previously repeated in a second hand-written list,
    which could drift out of step with the mapping.
    """
    tasks = {
        "task_reach_target.png": reach_target(),
        "task_station_keeping.png": station_keeping(),
        "task_waypoint_square.png": waypoint_square(),
        "task_waypoint_zigzag.png": waypoint_zigzag(),
    }
    for name, img in tasks.items():
        save(img, name)
    # dicts preserve insertion order, so the sheet layout matches the order above.
    make_contact_sheet(list(tasks))
# Generate the figures only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()