FlowMo-WM / experiments /export_paper_artifacts.py
cccat6's picture
Clean public repository for reproducibility
8e384df verified
#!/usr/bin/env python3
"""Export paper figures, tables, raw data, and provenance files.
All values are read from experiments/reports/*.json and GIF frames are
extracted from experiments/reports/paper_planning/gifs/*.gif.
"""
from __future__ import annotations
import csv
import json
import math
import re
import shutil
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Iterable
from PIL import Image, ImageDraw, ImageFont, ImageSequence
# Parent of the directory that contains this script; every path below hangs off it.
ROOT = Path(__file__).resolve().parents[1]
# Directory holding the JSON reports that the exported values are read from.
REPORTS = ROOT / "experiments" / "reports"
# Planning reports (``*.json``), globbed by ``load_planning``.
PLANNING_DIR = REPORTS / "paper_planning"
# Rollout GIFs from which the Figure 3 frames are extracted.
GIF_DIR = PLANNING_DIR / "gifs"
# Root directory for every generated artifact (figures, tables, provenance).
OUT = REPORTS / "paper_artifacts"
# Prediction-error report read by ``load_prediction_rows``.
PREDICTION_JSON = REPORTS / "paper_prediction.json"
# Latent-probe report path (not referenced in the visible part of this file).
PROBE_JSON = REPORTS / "paper_flowmo_latent_probes.json"
# Canonical orderings. The ``*_sort_key`` helpers sort by position in these
# lists, and the drawing helpers iterate them to lay out bars/lines/rows.
TASK_ORDER = ["reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag"]
BOAT_ORDER = ["twin", "triangle"]
FLOW_ORDER = [
    "noflow",
    "uniform",
    "vortex_center",
    "double_gyre",
    "source_sink",
    "source_sink_pair",
    "gradient",
    "shear",
    "turbulent_patch",
    "random_fourier",
]
# All methods: the four learned world models first, then the four LOS controllers.
METHOD_ORDER = [
    "flowmo",
    "leworldmodel",
    "planet",
    "tdmpc2",
    "pid_los_controller",
    "no_flow_los_controller",
    "current_estimator_los_controller",
    "oracle_flow_los_controller",
]
# Subset of METHOD_ORDER treated as learned world models.
LEARNED_METHODS = ["flowmo", "leworldmodel", "planet", "tdmpc2"]
# Remaining entries of METHOD_ORDER: line-of-sight controller baselines.
TRADITIONAL_METHODS = [
    "pid_los_controller",
    "no_flow_los_controller",
    "current_estimator_los_controller",
    "oracle_flow_los_controller",
]
# Full display names for methods (used for legends and ``method_label`` columns).
METHOD_LABEL = {
    "flowmo": "FlowMo-WM",
    "leworldmodel": "LeWorldModel",
    "planet": "PlaNet/RSSM",
    "tdmpc2": "TD-MPC2",
    "pid_los_controller": "PID/LOS",
    "no_flow_los_controller": "No-Flow LOS",
    "current_estimator_los_controller": "Current-Estimator LOS",
    "oracle_flow_los_controller": "Oracle-Flow LOS",
}
# One-sentence descriptions; only the traditional controllers have an entry.
METHOD_DESCRIPTION = {
    "pid_los_controller": "Line-of-sight waypoint tracking baseline using the clean-image pose estimate.",
    "no_flow_los_controller": "Line-of-sight tracking that ignores ambient flow; measures the cost of no current compensation.",
    "current_estimator_los_controller": "Line-of-sight tracking with an online drift estimate from recent pose history.",
    "oracle_flow_los_controller": "Line-of-sight tracking with privileged true local simulator flow feed-forward; a reference controller, not a world-model baseline.",
}
# Abbreviated method names used for in-panel legends and bar-group labels.
METHOD_SHORT = {
    "flowmo": "FlowMo",
    "leworldmodel": "LeWM",
    "planet": "RSSM",
    "tdmpc2": "TD2",
    "pid_los_controller": "PID/LOS",
    "no_flow_los_controller": "NF-LOS",
    "current_estimator_los_controller": "CE-LOS",
    "oracle_flow_los_controller": "OF-LOS",
}
# Short display names for tasks, boats and flow types.
TASK_LABEL = {
    "reach_target": "Reach",
    "station_keeping": "Station",
    "waypoint_square": "Square",
    "waypoint_zigzag": "Zigzag",
}
BOAT_LABEL = {"twin": "Twin", "triangle": "Triangle"}
FLOW_LABEL = {
    "noflow": "No flow",
    "uniform": "Uniform",
    "vortex_center": "Vortex",
    "double_gyre": "Double gyre",
    "source_sink": "Source/sink",
    "source_sink_pair": "Src/sink pair",
    "gradient": "Gradient",
    "shear": "Shear",
    "turbulent_patch": "Turbulent",
    "random_fourier": "Fourier",
}
# Rollout horizons; ``load_prediction_rows`` reads ``pos{h}`` / ``heading{h}``
# for each value and ``draw_line_panel`` uses them as x-axis positions.
HORIZONS = [1, 3, 6, 8, 10, 20, 30, 40, 60]
# Flow type whose rollout GIF is shown for each task in Figure 3.
FIG3_TASK_FLOW = {
    "reach_target": "uniform",
    "station_keeping": "vortex_center",
    "waypoint_square": "gradient",
    "waypoint_zigzag": "random_fourier",
}
# Episode number used to pick the Figure 3 GIFs (formatted as ``ep000``).
FIG3_EPISODE = 0
# RGB colour per method, shared by every plot so methods keep one colour.
METHOD_COLORS = {
    "flowmo": (31, 119, 180),
    "leworldmodel": (255, 127, 14),
    "planet": (44, 160, 44),
    "tdmpc2": (148, 103, 189),
    "pid_los_controller": (127, 127, 127),
    "no_flow_los_controller": (214, 39, 40),
    "current_estimator_los_controller": (23, 190, 207),
    "oracle_flow_los_controller": (140, 86, 75),
}
@dataclass(frozen=True)
class SummaryRecord:
    """Aggregated planning metrics for one report item and one context mode.

    Built by ``load_planning`` from ``item["by_context"][context_mode]``.
    """

    # Report file the record came from, relative to ROOT.
    source_file: str
    # Index of the item inside the report's top-level JSON list.
    item_index: int
    method: str
    task: str
    boat: str
    flow_type: str
    # Key of the ``by_context`` mapping this record was read from.
    context_mode: str
    episodes: int
    successes: int
    success_rate: float
    final_distance_mean: float
    mean_min_goal_distance: float
    # The three ``*_success_mean`` values are None when the report omits them
    # or stores null/NaN (see ``safe_float``).
    path_length_success_mean: float | None
    energy_success_mean: float | None
    steps_success_mean: float | None
@dataclass(frozen=True)
class EpisodeRecord:
    """Metrics of a single planning episode.

    Built by ``load_planning`` from one entry of ``item["results"]``.
    """

    # Report file the record came from, relative to ROOT.
    source_file: str
    # Index of the item inside the report's top-level JSON list.
    item_index: int
    # Index of this episode inside the item's ``results`` list.
    result_index: int
    method: str
    task: str
    boat: str
    flow_type: str
    context_mode: str
    # Episode number as stored in the report (``result["episode"]``).
    episode: int
    success: bool
    final_distance: float
    mean_min_goal_distance: float
    # Optional metrics: None when missing or null in the report
    # (energy/path_length also map NaN to None via ``safe_float``).
    energy: float | None
    path_length: float | None
    steps: int | None
def rel(path: Path) -> str:
    """Return *path* as a string relative to the repository ``ROOT``.

    Raises:
        ValueError: propagated from ``Path.relative_to`` when *path* is not
            located under ``ROOT``.
    """
    relative = path.relative_to(ROOT)
    return str(relative)
def ensure_dir(path: Path) -> None:
    """Create directory *path* together with any missing parents.

    Calling it on a directory that already exists is a no-op.
    """
    path.mkdir(exist_ok=True, parents=True)
def trim_whitespace(img: Image.Image, pad_x: int = 8, pad_y: int = 0, threshold: int = 250) -> Image.Image:
    """Crop near-white border while preserving a small horizontal margin.

    A pixel counts as content when any of its R, G or B channels is below
    *threshold*. The image is cropped to the bounding box of all content
    pixels, widened by *pad_x* pixels left/right and *pad_y* pixels
    top/bottom (clamped to the image bounds).

    Args:
        img: Source image; it is converted to RGB before inspection.
        pad_x: Horizontal margin kept around the content, in pixels.
        pad_y: Vertical margin kept around the content, in pixels.
        threshold: Channel value below which a pixel is considered content.

    Returns:
        The cropped RGB image, or the uncropped RGB conversion when no pixel
        is darker than the threshold.
    """
    rgb = img.convert("RGB")
    w, h = rgb.size
    # Build a mask in native code instead of a per-pixel Python loop: every
    # channel value below the threshold becomes 255, everything else 0, so
    # getbbox() (bounding box of non-zero pixels) is exactly the content box.
    content_mask = rgb.point(lambda value: 255 if value < threshold else 0)
    bbox = content_mask.getbbox()
    if bbox is None:
        return rgb
    left, upper, right, lower = bbox  # right/lower are exclusive bounds
    min_x = max(0, left - pad_x)
    max_x = min(w - 1, (right - 1) + pad_x)
    min_y = max(0, upper - pad_y)
    max_y = min(h - 1, (lower - 1) + pad_y)
    return rgb.crop((min_x, min_y, max_x + 1, max_y + 1))
def read_json(path: Path) -> Any:
    """Parse the UTF-8 encoded JSON document stored at *path* and return it."""
    return json.loads(path.read_text(encoding="utf-8"))
def write_text(path: Path, text: str) -> None:
    """Write *text* to *path* as UTF-8, creating the parent directory first."""
    parent_dir = path.parent
    ensure_dir(parent_dir)
    path.write_text(text, encoding="utf-8")
def safe_float(value: Any) -> float | None:
    """Convert *value* to ``float``, mapping ``None`` and float NaN to ``None``.

    Only genuine ``float`` NaN values are treated as missing; any other input
    is passed to ``float()`` unchanged, so conversion errors propagate.
    """
    is_missing = value is None or (isinstance(value, float) and math.isnan(value))
    return None if is_missing else float(value)
def fmt(value: float | int | None, digits: int = 3) -> str:
    """Format *value* with *digits* decimals; ``None`` and NaN become ``"--"``."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return "--"
    return format(float(value), f".{digits}f")
def pct(value: float | None, digits: int = 1) -> str:
    """Format a fraction as a percentage number (without the ``%`` sign).

    Args:
        value: Fraction to format (``0.5`` -> ``"50.0"``), or ``None``.
        digits: Number of decimals in the output.

    Returns:
        The formatted percentage, or ``"--"`` when *value* is ``None`` or NaN.
        NaN is handled like ``fmt`` does, because ``aggregate_success``
        produces a NaN success rate for groups with zero episodes.
    """
    if value is None:
        return "--"
    number = float(value)
    if math.isnan(number):
        return "--"
    return f"{100.0 * number:.{digits}f}"
def latex_escape(text: str) -> str:
    """Return *text* with LaTeX special characters replaced by escapes.

    Each character is mapped independently, so the braces inside inserted
    macros such as ``\\textbackslash{}`` are not escaped again.
    """
    table = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )
    return text.translate(table)
def write_table(path: Path, header: list[str], rows: list[list[str]], caption: str, label: str) -> None:
    """Write a booktabs-style LaTeX ``table*`` to *path*.

    Header and body cells are passed through ``latex_escape``; *caption* and
    *label* are inserted verbatim. Every column is left-aligned.
    """

    def tex_row(cells: list[str]) -> str:
        # One table line: escaped cells joined by " & " and ended with "\\".
        return " & ".join(latex_escape(cell) for cell in cells) + r" \\"

    colspec = "l" * len(header)
    opening = [
        r"\begin{table*}[t]",
        r"\centering",
        rf"\caption{{{caption}}}",
        rf"\label{{{label}}}",
        r"\scriptsize",
        rf"\begin{{tabular}}{{{colspec}}}",
        r"\toprule",
    ]
    body = [tex_row(row) for row in rows]
    # The trailing empty string makes the file end with a newline.
    closing = [r"\bottomrule", r"\end{tabular}", r"\end{table*}", ""]
    lines = opening + [tex_row(header), r"\midrule"] + body + closing
    write_text(path, "\n".join(lines))
def write_rows(path_base: Path, rows: list[dict[str, Any]], fieldnames: list[str]) -> None:
    """Write *rows* twice: as ``<path_base>.tsv`` and as ``<path_base>.csv``.

    Only the columns listed in *fieldnames* are written, in that order; keys
    missing from a row become empty cells and extra keys are dropped.
    """
    ensure_dir(path_base.parent)
    formats = ((".tsv", "excel-tab"), (".csv", "excel"))
    for suffix, dialect in formats:
        target = path_base.with_suffix(suffix)
        with target.open("w", encoding="utf-8", newline="") as handle:
            # restval fills missing keys with "", extrasaction drops unknown ones.
            writer = csv.DictWriter(
                handle,
                fieldnames=fieldnames,
                restval="",
                extrasaction="ignore",
                dialect=dialect,
            )
            writer.writeheader()
            writer.writerows(rows)
@lru_cache(maxsize=2)
def arial_font_path(bold: bool) -> str | None:
    """Ask fontconfig (``fc-match``) for the file it matches to Arial.

    Returns the reported path, or ``None`` when ``fc-match`` cannot be run,
    exits with an error, or prints nothing. Results are cached per *bold*.
    """
    pattern = "Arial:style=Bold" if bold else "Arial:style=Regular"
    command = ["fc-match", "-f", "%{file}", pattern]
    try:
        completed = subprocess.run(command, check=True, stdout=subprocess.PIPE, text=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return completed.stdout.strip() or None
def font(size: int, bold: bool = False) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Load a sans-serif font at *size*, trying progressively weaker fallbacks.

    Order: the fontconfig match for Arial, Liberation Sans, then two DejaVu
    Sans locations. If none can be opened, Pillow's built-in default font is
    returned.
    """
    if bold:
        fallbacks = [
            "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            "/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf",
        ]
    else:
        fallbacks = [
            "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/usr/share/fonts/dejavu/DejaVuSans.ttf",
        ]
    for candidate in [arial_font_path(bold), *fallbacks]:
        if not candidate:
            # fc-match produced no path; move on to the hard-coded fallbacks.
            continue
        try:
            return ImageFont.truetype(candidate, size=size)
        except OSError:
            continue
    return ImageFont.load_default()
def load_planning() -> tuple[list[SummaryRecord], list[EpisodeRecord]]:
    """Load every planning report in ``PLANNING_DIR``.

    Each ``*.json`` file (processed in sorted order) must contain a list of
    items. An item yields one ``SummaryRecord`` per ``by_context`` entry and
    one ``EpisodeRecord`` per ``results`` entry.

    Returns:
        ``(summaries, episodes)`` in file/item order.

    Raises:
        ValueError: if a report's top-level JSON value is not a list.
    """
    summaries: list[SummaryRecord] = []
    episodes: list[EpisodeRecord] = []

    def make_summary(shared: dict[str, Any], context_mode: str, metrics: dict[str, Any]) -> SummaryRecord:
        # Required metrics are indexed directly; optional ones go through safe_float.
        return SummaryRecord(
            **shared,
            context_mode=context_mode,
            episodes=int(metrics["episodes"]),
            successes=int(metrics["successes"]),
            success_rate=float(metrics["success_rate"]),
            final_distance_mean=float(metrics["final_distance_mean"]),
            mean_min_goal_distance=float(metrics["mean_min_goal_distance"]),
            path_length_success_mean=safe_float(metrics.get("path_length_success_mean")),
            energy_success_mean=safe_float(metrics.get("energy_success_mean")),
            steps_success_mean=safe_float(metrics.get("steps_success_mean")),
        )

    def make_episode(shared: dict[str, Any], result_index: int, result: dict[str, Any]) -> EpisodeRecord:
        raw_steps = result.get("steps")
        return EpisodeRecord(
            **shared,
            result_index=result_index,
            context_mode=result["context_mode"],
            episode=int(result["episode"]),
            success=bool(result["success"]),
            final_distance=float(result["final_distance"]),
            mean_min_goal_distance=float(result["mean_min_goal_distance"]),
            energy=safe_float(result.get("energy")),
            path_length=safe_float(result.get("path_length")),
            steps=None if raw_steps is None else int(raw_steps),
        )

    for path in sorted(PLANNING_DIR.glob("*.json")):
        data = read_json(path)
        if not isinstance(data, list):
            raise ValueError(f"Expected list in {path}")
        for item_index, item in enumerate(data):
            # Fields common to every record derived from this item.
            shared = {
                "source_file": rel(path),
                "item_index": item_index,
                "method": item["method"],
                "task": item["task"],
                "boat": item["boat"],
                "flow_type": item["flow_type"],
            }
            for context_mode, metrics in item["by_context"].items():
                summaries.append(make_summary(shared, context_mode, metrics))
            for result_index, result in enumerate(item["results"]):
                episodes.append(make_episode(shared, result_index, result))
    return summaries, episodes
def inferred_summaries(records: Iterable[SummaryRecord]) -> list[SummaryRecord]:
    """Keep summaries whose context is ``"inferred"`` and whose method is in ``METHOD_ORDER``."""
    kept: list[SummaryRecord] = []
    for record in records:
        if record.context_mode != "inferred":
            continue
        if record.method not in METHOD_ORDER:
            continue
        kept.append(record)
    return kept
def inferred_episodes(records: Iterable[EpisodeRecord]) -> list[EpisodeRecord]:
    """Keep episodes whose context is ``"inferred"`` and whose method is in ``METHOD_ORDER``."""
    kept: list[EpisodeRecord] = []
    for record in records:
        if record.context_mode != "inferred":
            continue
        if record.method not in METHOD_ORDER:
            continue
        kept.append(record)
    return kept
def aggregate_success(records: Iterable[SummaryRecord], group_keys: tuple[str, ...]) -> list[dict[str, Any]]:
    """Sum successes and episodes over records sharing the same *group_keys*.

    Returns one dict per group, in first-seen order, holding the group key
    values, the summed ``successes`` and ``episodes``, the pooled
    ``success_rate`` (NaN when the group has zero episodes) and
    ``source_files`` (sorted, ``;``-joined contributing report files).
    """
    totals: dict[tuple[Any, ...], dict[str, Any]] = {}
    contributing: dict[tuple[Any, ...], set[str]] = defaultdict(set)
    for record in records:
        key = tuple(getattr(record, name) for name in group_keys)
        bucket = totals.get(key)
        if bucket is None:
            bucket = dict(zip(group_keys, key))
            bucket["successes"] = 0
            bucket["episodes"] = 0
            totals[key] = bucket
        bucket["successes"] += record.successes
        bucket["episodes"] += record.episodes
        contributing[key].add(record.source_file)

    result = []
    for key, bucket in totals.items():
        n_episodes = bucket["episodes"]
        if n_episodes:
            rate = bucket["successes"] / n_episodes
        else:
            rate = math.nan
        files = ";".join(sorted(contributing[key]))
        result.append({**bucket, "success_rate": rate, "source_files": files})
    return result
def task_sort_key(task: str) -> int:
    """Return the position of *task* in ``TASK_ORDER``; unknown tasks sort last."""
    try:
        return TASK_ORDER.index(task)
    except ValueError:
        return len(TASK_ORDER)
def boat_sort_key(boat: str) -> int:
    """Return the position of *boat* in ``BOAT_ORDER``; unknown boats sort last."""
    try:
        return BOAT_ORDER.index(boat)
    except ValueError:
        return len(BOAT_ORDER)
def flow_sort_key(flow: str) -> int:
    """Return the position of *flow* in ``FLOW_ORDER``; unknown flows sort last."""
    try:
        return FLOW_ORDER.index(flow)
    except ValueError:
        return len(FLOW_ORDER)
def method_sort_key(method: str) -> int:
    """Return the position of *method* in ``METHOD_ORDER``; unknown methods sort last."""
    try:
        return METHOD_ORDER.index(method)
    except ValueError:
        return len(METHOD_ORDER)
def extract_fig3() -> None:
    """Export the Figure 3 rollout frames, contact sheet, manifest and provenance.

    For every task (with its flow from ``FIG3_TASK_FLOW``) and boat, the
    FlowMo "inferred" GIF of episode ``FIG3_EPISODE`` is opened and its
    first, middle and last frames are saved as PNGs. The frames are then
    tiled into one contact sheet (rows = tasks, two column groups = boats),
    and a TSV/CSV manifest plus a Markdown provenance note are written.

    Raises:
        FileNotFoundError: if one of the expected source GIFs is missing.
    """
    fig_dir = OUT / "fig3"
    frames_dir = fig_dir / "frames"
    ensure_dir(frames_dir)
    rows: list[dict[str, Any]] = []
    cell_images: dict[tuple[str, str, str], Path] = {}
    frame_names = ["first", "middle", "last"]
    for task in TASK_ORDER:
        flow = FIG3_TASK_FLOW[task]
        for boat in BOAT_ORDER:
            gif_path = GIF_DIR / f"image_planning_flowmo_inferred_{boat}_{task}_{flow}_ep{FIG3_EPISODE:03d}.gif"
            if not gif_path.exists():
                raise FileNotFoundError(f"Missing GIF for Fig. 3: {gif_path}")
            with Image.open(gif_path) as im:
                # Non-animated images have no n_frames attribute; treat them as one frame.
                n_frames = getattr(im, "n_frames", 1)
                frame_indices = [0, n_frames // 2, n_frames - 1]
                for frame_name, frame_index in zip(frame_names, frame_indices):
                    im.seek(frame_index)
                    frame = im.convert("RGBA")
                    out_name = f"fig3_flowmo_inferred_{boat}_{task}_{flow}_ep{FIG3_EPISODE:03d}_{frame_name}_frame{frame_index:03d}.png"
                    out_path = frames_dir / out_name
                    frame.save(out_path)
                    cell_images[(task, boat, frame_name)] = out_path
                    rows.append(
                        {
                            "task": task,
                            "boat": boat,
                            "method": "flowmo",
                            "context_mode": "inferred",
                            "flow_type": flow,
                            "episode": FIG3_EPISODE,
                            "frame_name": frame_name,
                            "frame_index": frame_index,
                            "gif_frames": n_frames,
                            "source_gif": rel(gif_path),
                            "output_png": rel(out_path),
                        }
                    )
    write_rows(
        fig_dir / "figure3_frame_manifest",
        rows,
        [
            "task",
            "boat",
            "method",
            "context_mode",
            "flow_type",
            "episode",
            "frame_name",
            "frame_index",
            "gif_frames",
            "source_gif",
            "output_png",
        ],
    )

    # Contact-sheet geometry: a label gutter on the left, then two groups of
    # three thumbnails (twin, triangle) separated by group_gap.
    thumb_w, thumb_h = 210, 210
    left = 138
    top = 58
    row_gap = 16
    group_gap = 30
    width = left + 6 * thumb_w + group_gap + 14
    height = top + len(TASK_ORDER) * thumb_h + (len(TASK_ORDER) - 1) * row_gap + 14
    canvas = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(canvas)
    head_font = font(30, True)
    flow_font = font(25, False)
    group_font = font(30, True)
    twin_x = left
    tri_x = left + 3 * thumb_w + group_gap
    # Centre each boat heading over its three-thumbnail group.
    draw.text((twin_x + 1.5 * thumb_w - draw.textlength("Twin", font=group_font) / 2, 12), "Twin", fill=(20, 25, 30), font=group_font)
    draw.text((tri_x + 1.5 * thumb_w - draw.textlength("Triangle", font=group_font) / 2, 12), "Triangle", fill=(20, 25, 30), font=group_font)
    for row, task in enumerate(TASK_ORDER):
        y = top + row * (thumb_h + row_gap)
        draw.text((12, y + 72), TASK_LABEL[task], fill=(20, 25, 30), font=head_font)
        draw.text((12, y + 104), FLOW_LABEL[FIG3_TASK_FLOW[task]], fill=(80, 80, 80), font=flow_font)
        for boat in BOAT_ORDER:
            base_x = twin_x if boat == "twin" else tri_x
            for frame_idx, frame_name in enumerate(frame_names):
                x = base_x + frame_idx * thumb_w
                src = cell_images[(task, boat, frame_name)]
                # Close the PNG file handle as soon as the pixels are copied;
                # convert() returns an independent in-memory image.
                with Image.open(src) as tile:
                    img = tile.convert("RGB").resize((thumb_w, thumb_h), Image.Resampling.LANCZOS)
                canvas.paste(img, (x, y))
                draw.rectangle([x, y, x + thumb_w, y + thumb_h], outline=(220, 220, 220), width=1)
    contact = fig_dir / "figure3_rollout_contact_sheet.png"
    trim_whitespace(canvas, pad_x=8, pad_y=0).save(contact)

    md_lines = [
        "# Figure 3 Provenance",
        "",
        "Purpose: qualitative task rollouts extracted from experiment GIFs.",
        "",
        "Selected method/context: `flowmo` / `inferred`.",
        f"Selected episode: `{FIG3_EPISODE}`.",
        "Layout: for each task and boat, the three adjacent frames are first, middle, and last; spacing appears only between task/boat groups.",
        "",
        "Selected flows by task:",
    ]
    for task in TASK_ORDER:
        md_lines.append(f"- `{task}`: `{FIG3_TASK_FLOW[task]}`")
    md_lines += [
        "",
        "Frame rule: for each source GIF, extracted `first = 0`, `middle = n_frames // 2`, and `last = n_frames - 1`.",
        "",
        "Generated outputs:",
        f"- `{rel(contact)}`",
        f"- `{rel(frames_dir)}/`",
        f"- `{rel(fig_dir / 'figure3_frame_manifest.tsv')}`",
        f"- `{rel(fig_dir / 'figure3_frame_manifest.csv')}`",
        "",
        "Source GIFs:",
    ]
    for source in sorted({row["source_gif"] for row in rows}):
        md_lines.append(f"- `{source}`")
    write_text(fig_dir / "figure3_provenance.md", "\n".join(md_lines) + "\n")
def load_prediction_rows() -> list[dict[str, Any]]:
    """Flatten ``PREDICTION_JSON`` into one row per learned method and horizon.

    Items whose method is not in ``LEARNED_METHODS`` are skipped. Each row
    carries the position/heading error read from the item's ``"inferred"``
    block plus the JSON paths those values came from.
    """
    data = read_json(PREDICTION_JSON)
    rows: list[dict[str, Any]] = []
    for item_index, item in enumerate(data):
        method = item["method"]
        if method in LEARNED_METHODS:
            metrics = item["inferred"]
            prefix = f"$[{item_index}].inferred"
            rows.extend(
                {
                    "method": method,
                    "method_label": METHOD_LABEL[method],
                    "context_mode": "inferred",
                    "horizon": horizon,
                    "position_error": metrics[f"pos{horizon}"],
                    "heading_error": metrics[f"heading{horizon}"],
                    "source_file": rel(PREDICTION_JSON),
                    "json_path_position": f"{prefix}.pos{horizon}",
                    "json_path_heading": f"{prefix}.heading{horizon}",
                }
                for horizon in HORIZONS
            )
    return rows
def draw_line_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    title: str,
    colors: dict[str, tuple[int, int, int]],
    compact: bool = False,
    show_legend: bool = True,
    title_y: int | None = None,
) -> None:
    """Draw position error versus rollout horizon, one line per learned method.

    Args:
        draw: Drawing context of the target canvas.
        box: Panel rectangle as ``(x0, y0, x1, y1)`` in canvas pixels.
        rows: Rows with ``method``, ``horizon`` and ``position_error`` keys.
        title: Panel title, drawn above the box.
        colors: RGB colour per method.
        compact: Use the larger fonts/paddings and the two-column legend.
        show_legend: Draw the method legend inside the plot area.
        title_y: Absolute y of the title; defaults to just above the box.

    NOTE(review): ``max()`` raises ValueError for empty *rows*, and an
    all-zero ``position_error`` column makes ``max_y`` zero, which leads to
    ZeroDivisionError when the points are scaled.
    """
    x0, y0, x1, y1 = box
    axis_font = font(20 if compact else 15)
    title_font = font(26 if compact else 22, True)
    title_offset = 48 if compact else 42
    draw.text((x0, title_y if title_y is not None else y0 - title_offset), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)
    if compact:
        pad_l, pad_b, pad_t, pad_r = 74, 68, 28, 24
    else:
        pad_l, pad_b, pad_t, pad_r = 64, 52, 24, 20
    # Plot area: the panel box shrunk by the paddings.
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b
    # y-axis top is 12% above the largest position error in the data.
    max_y = max(float(r["position_error"]) for r in rows) * 1.12
    min_h, max_h = min(HORIZONS), max(HORIZONS)
    # Horizontal grid lines with value labels at 0/25/50/75/100% of max_y.
    for tick in [0.0, 0.25, 0.50, 0.75, 1.0]:
        y = py1 - tick * (py1 - py0)
        val = tick * max_y
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 8, y - 11), f"{val:.2f}", fill=(70, 70, 70), font=axis_font)
    # Axis lines (x then y).
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    for h in HORIZONS:
        # Horizons are placed linearly between min_h and max_h.
        x = px0 + (h - min_h) / (max_h - min_h) * (px1 - px0)
        draw.line([x, py1, x, py1 + 5], fill=(60, 60, 60), width=1)
        # In compact or narrow (< 700 px) panels the labels for 3, 6, 8 and 30
        # are skipped; their tick marks are still drawn.
        if (compact or (x1 - x0) < 700) and h in [3, 6, 8, 30]:
            continue
        draw.text((x - 12, py1 + 12), str(h), fill=(70, 70, 70), font=axis_font)
    draw.text(((px0 + px1) // 2 - (72 if compact else 60), y1 - (36 if compact else 28)), "rollout step", fill=(60, 60, 60), font=axis_font)
    by_method: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for row in rows:
        by_method[row["method"]].append(row)
    for method in LEARNED_METHODS:
        pts = []
        for row in sorted(by_method[method], key=lambda x: int(x["horizon"])):
            h = int(row["horizon"])
            x = px0 + (h - min_h) / (max_h - min_h) * (px1 - px0)
            y = py1 - float(row["position_error"]) / max_y * (py1 - py0)
            pts.append((x, y))
        # A polyline needs at least two points; markers are drawn regardless.
        if len(pts) >= 2:
            draw.line(pts, fill=colors[method], width=4)
        for x, y in pts:
            rr = 5 if compact else 4
            draw.ellipse([x - rr, y - rr, x + rr, y + rr], fill=colors[method])
    if show_legend:
        lx, ly = (px0 + 20, py0 + 14) if compact else (px1 - 185, py0 + 10)
        for i, method in enumerate(LEARNED_METHODS):
            if compact:
                # Two-column legend in the top-left corner.
                col = i % 2
                row = i // 2
                xx = lx + col * 250
                yy = ly + row * 30
            else:
                # Single-column legend in the top-right corner.
                xx = lx
                yy = ly + i * 24
            draw.line([xx, yy + 10, xx + 32, yy + 10], fill=colors[method], width=5 if compact else 4)
            draw.text((xx + 42, yy - 2), METHOD_SHORT[method], fill=(40, 40, 40), font=axis_font)
def draw_success_bar_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    title: str,
) -> None:
    """Draw success rate per method as paired bars, one bar per boat.

    Args:
        draw: Drawing context of the target canvas.
        box: Panel rectangle as ``(x0, y0, x1, y1)`` in canvas pixels.
        rows: Rows with ``method``, ``boat`` and ``success_rate`` (0..1) keys.
        title: Panel title, drawn above the box.

    A (method, boat) pair missing from *rows* is drawn as a zero-height bar.
    """
    x0, y0, x1, y1 = box
    axis_font = font(15)
    small_font = font(13)
    title_font = font(22, True)
    draw.text((x0, y0 - 42), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)
    pad_l, pad_b, pad_t, pad_r = 58, 70, 24, 18
    # Plot area: the panel box shrunk by the paddings.
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b
    # Grid lines labelled 0, 25, 50, 75, 100 (percent).
    for tick in [0, 0.25, 0.50, 0.75, 1.0]:
        y = py1 - tick * (py1 - py0)
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 10, y - 8), f"{int(tick * 100)}", fill=(70, 70, 70), font=axis_font)
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    rates = {(r["method"], r["boat"]): float(r["success_rate"]) for r in rows}
    # One equal-width slot per method; each bar takes 30% of the slot.
    group_w = (px1 - px0) / len(METHOD_ORDER)
    bar_w = group_w * 0.30
    boat_colors = {"twin": (54, 119, 191), "triangle": (218, 119, 54)}
    for idx, method in enumerate(METHOD_ORDER):
        cx = px0 + idx * group_w + group_w * 0.5
        for j, boat in enumerate(BOAT_ORDER):
            rate = rates.get((method, boat), 0.0)
            # j = 0 places the bar left of the slot centre, j = 1 right of it.
            x_left = cx + (j - 0.5) * bar_w - bar_w * 0.5
            x_right = x_left + bar_w
            y_top = py1 - rate * (py1 - py0)
            draw.rectangle([x_left, y_top, x_right, py1], fill=boat_colors[boat], outline=(255, 255, 255))
        draw.text((cx - 24, py1 + 10), METHOD_SHORT[method], fill=(55, 55, 55), font=small_font)
        # Vertical divider right after the "tdmpc2" slot.
        if method == "tdmpc2":
            split_x = px0 + (idx + 1) * group_w
            draw.line([split_x, py0, split_x, py1 + 30], fill=(80, 80, 80), width=2)
    # Boat colour legend in the top-right corner of the plot area.
    lx, ly = px1 - 180, py0 + 12
    for i, boat in enumerate(BOAT_ORDER):
        yy = ly + i * 24
        draw.rectangle([lx, yy, lx + 18, yy + 14], fill=boat_colors[boat])
        draw.text((lx + 26, yy - 2), BOAT_LABEL[boat], fill=(40, 40, 40), font=axis_font)
    draw.text((x0 + 8, y0 + 8), "success rate (%)", fill=(70, 70, 70), font=axis_font)
def draw_success_by_task_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    boat: str,
    title: str,
    show_legend: bool = False,
) -> None:
    """Draw success rate per task for one boat, with a bar for every method.

    Args:
        draw: Drawing context of the target canvas.
        box: Panel rectangle as ``(x0, y0, x1, y1)`` in canvas pixels.
        rows: Rows with ``task``, ``boat``, ``method`` and ``success_rate``.
        boat: Only rows for this boat are plotted.
        title: Panel title, drawn above the box.
        show_legend: Draw a method legend (rows of four entries) and reserve
            extra top padding for it.

    A (task, method) pair missing from *rows* is drawn as a zero-height bar.
    """
    x0, y0, x1, y1 = box
    axis_font = font(15)
    small_font = font(12)
    title_font = font(21, True)
    draw.text((x0, y0 - 36), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)
    # Top padding grows from 26 to 76 px when the legend is shown.
    pad_l, pad_b, pad_t, pad_r = 58, 58, (76 if show_legend else 26), 16
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b
    # Grid lines labelled 0, 25, 50, 75, 100 (percent).
    for tick in [0, 0.25, 0.50, 0.75, 1.0]:
        y = py1 - tick * (py1 - py0)
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 10, y - 8), f"{int(tick * 100)}", fill=(70, 70, 70), font=axis_font)
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    draw.text((x0 + 8, y0 + 8), "success (%)", fill=(70, 70, 70), font=axis_font)
    rates = {(r["task"], r["method"]): float(r["success_rate"]) for r in rows if r["boat"] == boat}
    # Each task slot uses its middle 70% for bars (15% margin on either side).
    group_w = (px1 - px0) / len(TASK_ORDER)
    bar_w = group_w * 0.70 / len(METHOD_ORDER)
    for task_idx, task in enumerate(TASK_ORDER):
        group_left = px0 + task_idx * group_w + group_w * 0.15
        for method_idx, method in enumerate(METHOD_ORDER):
            rate = rates.get((task, method), 0.0)
            x_left = group_left + method_idx * bar_w
            # Bars fill 88% of their slot, leaving a thin gap between neighbours.
            x_right = x_left + bar_w * 0.88
            y_top = py1 - rate * (py1 - py0)
            draw.rectangle([x_left, y_top, x_right, py1], fill=METHOD_COLORS[method], outline=(255, 255, 255))
        # Task label centred under its slot.
        label = TASK_LABEL[task]
        tw = draw.textlength(label, font=small_font)
        draw.text((px0 + task_idx * group_w + (group_w - tw) / 2, py1 + 12), label, fill=(45, 45, 45), font=small_font)
        # Light divider between adjacent task slots.
        if task_idx > 0:
            split_x = px0 + task_idx * group_w
            draw.line([split_x, py0, split_x, py1 + 24], fill=(218, 218, 218), width=1)
    if show_legend:
        lx, ly = x0 + 78, y0 + 36
        for i, method in enumerate(METHOD_ORDER):
            row = i // 4
            col = i % 4
            xx = lx + col * 145
            yy = ly + row * 22
            draw.rectangle([xx, yy, xx + 16, yy + 12], fill=METHOD_COLORS[method])
            draw.text((xx + 22, yy - 3), METHOD_SHORT[method], fill=(40, 40, 40), font=small_font)
def draw_learned_success_by_task_boat_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    title: str,
) -> None:
    """Draw learned-method success rates grouped by task, then by boat.

    Args:
        draw: Drawing context of the target canvas.
        box: Panel rectangle as ``(x0, y0, x1, y1)`` in canvas pixels.
        rows: Rows with ``task``, ``boat``, ``method`` and ``success_rate``.
        title: Panel title, drawn above the box.

    Each task slot is split into one sub-slot per boat, and each sub-slot
    holds one bar per entry of ``LEARNED_METHODS``. Missing combinations are
    drawn as zero-height bars. The legend is always drawn.
    """
    x0, y0, x1, y1 = box
    axis_font = font(15)
    small_font = font(12)
    title_font = font(22, True)
    draw.text((x0, y0 - 42), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)
    pad_l, pad_b, pad_t, pad_r = 58, 86, 66, 18
    # Plot area: the panel box shrunk by the paddings.
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b
    # Grid lines labelled 0, 25, 50, 75, 100 (percent).
    for tick in [0, 0.25, 0.50, 0.75, 1.0]:
        y = py1 - tick * (py1 - py0)
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 10, y - 8), f"{int(tick * 100)}", fill=(70, 70, 70), font=axis_font)
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    draw.text((x0 + 8, y0 + 8), "success (%)", fill=(70, 70, 70), font=axis_font)
    # Single-row method legend across the top of the panel.
    lx, ly = x0 + 82, y0 + 36
    for i, method in enumerate(LEARNED_METHODS):
        xx = lx + i * 190
        draw.rectangle([xx, ly, xx + 18, ly + 12], fill=METHOD_COLORS[method])
        draw.text((xx + 24, ly - 4), METHOD_SHORT[method], fill=(40, 40, 40), font=small_font)
    rates = {(r["task"], r["boat"], r["method"]): float(r["success_rate"]) for r in rows}
    task_w = (px1 - px0) / len(TASK_ORDER)
    boat_w = task_w / len(BOAT_ORDER)
    # Bars share the middle 72% of each boat sub-slot.
    bar_w = boat_w * 0.72 / len(LEARNED_METHODS)
    for task_idx, task in enumerate(TASK_ORDER):
        task_left = px0 + task_idx * task_w
        # Light divider between adjacent task slots.
        if task_idx > 0:
            draw.line([task_left, py0, task_left, py1 + 44], fill=(218, 218, 218), width=1)
        # Task label on the lower label row, centred under the task slot.
        task_label = TASK_LABEL[task]
        tw = draw.textlength(task_label, font=axis_font)
        draw.text((task_left + (task_w - tw) / 2, py1 + 42), task_label, fill=(35, 35, 35), font=axis_font)
        for boat_idx, boat in enumerate(BOAT_ORDER):
            boat_left = task_left + boat_idx * boat_w
            group_left = boat_left + boat_w * 0.14
            for method_idx, method in enumerate(LEARNED_METHODS):
                rate = rates.get((task, boat, method), 0.0)
                x_left = group_left + method_idx * bar_w
                x_right = x_left + bar_w * 0.86
                y_top = py1 - rate * (py1 - py0)
                draw.rectangle([x_left, y_top, x_right, py1], fill=METHOD_COLORS[method], outline=(255, 255, 255))
            # Boat label on the upper label row, centred under its sub-slot.
            boat_label = BOAT_LABEL[boat]
            bw = draw.textlength(boat_label, font=small_font)
            draw.text((boat_left + (boat_w - bw) / 2, py1 + 14), boat_label, fill=(55, 55, 55), font=small_font)
def draw_single_column_success_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    boat: str,
    title: str,
    show_legend: bool = False,
) -> None:
    """Draw learned-method success per task for one boat, using larger fonts.

    Args:
        draw: Drawing context of the target canvas.
        box: Panel rectangle as ``(x0, y0, x1, y1)`` in canvas pixels.
        rows: Rows with ``task``, ``boat``, ``method`` and ``success_rate``.
        boat: Only rows for this boat are plotted.
        title: Panel title, drawn above the box.
        show_legend: Draw a single-row legend of ``LEARNED_METHODS``.

    A (task, method) pair missing from *rows* is drawn as a zero-height bar.
    """
    x0, y0, x1, y1 = box
    axis_font = font(20)
    small_font = font(18)
    title_font = font(25, True)
    draw.text((x0, y0 - 44), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)
    # The conditional applies to the third element only: pad_t is 70 with a
    # legend and 64 without.
    pad_l, pad_b, pad_t, pad_r = 76, 76, 70 if show_legend else 64, 24
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b
    # Grid lines labelled 0, 25, 50, 75, 100 (percent).
    for tick in [0, 0.25, 0.50, 0.75, 1.0]:
        y = py1 - tick * (py1 - py0)
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 12, y - 11), f"{int(tick * 100)}", fill=(70, 70, 70), font=axis_font)
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    draw.text((x0 + 10, y0 + 10), "success (%)", fill=(70, 70, 70), font=axis_font)
    if show_legend:
        lx, ly = x0 + 100, y0 + 38
        for i, method in enumerate(LEARNED_METHODS):
            xx = lx + i * 170
            draw.rectangle([xx, ly, xx + 22, ly + 16], fill=METHOD_COLORS[method])
            draw.text((xx + 30, ly - 5), METHOD_SHORT[method], fill=(40, 40, 40), font=small_font)
    rates = {(r["task"], r["method"]): float(r["success_rate"]) for r in rows if r["boat"] == boat}
    # Bars share the middle 72% of each task slot (14% margin on either side).
    group_w = (px1 - px0) / len(TASK_ORDER)
    bar_w = group_w * 0.72 / len(LEARNED_METHODS)
    for task_idx, task in enumerate(TASK_ORDER):
        group_left = px0 + task_idx * group_w + group_w * 0.14
        # Light divider between adjacent task slots.
        if task_idx > 0:
            split_x = px0 + task_idx * group_w
            draw.line([split_x, py0, split_x, py1 + 38], fill=(218, 218, 218), width=1)
        for method_idx, method in enumerate(LEARNED_METHODS):
            rate = rates.get((task, method), 0.0)
            x_left = group_left + method_idx * bar_w
            x_right = x_left + bar_w * 0.86
            y_top = py1 - rate * (py1 - py0)
            draw.rectangle([x_left, y_top, x_right, py1], fill=METHOD_COLORS[method], outline=(255, 255, 255))
        # Task label centred under its slot.
        label = TASK_LABEL[task]
        tw = draw.textlength(label, font=small_font)
        draw.text((px0 + task_idx * group_w + (group_w - tw) / 2, py1 + 18), label, fill=(45, 45, 45), font=small_font)
def draw_compact_success_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    boat: str,
    title: str,
    title_y: int | None = None,
    show_legend: bool = False,
) -> None:
    """Draw learned-method success per task for one boat on a 50-100% axis.

    Args:
        draw: Drawing context of the target canvas.
        box: Panel rectangle as ``(x0, y0, x1, y1)`` in canvas pixels.
        rows: Rows with ``task``, ``boat``, ``method`` and ``success_rate``.
        boat: Only rows for this boat are plotted.
        title: Panel title.
        title_y: Absolute y of the title; defaults to 32 px above the box.
        show_legend: Accepted but not used by this function; no legend is
            drawn here.

    The y-axis is truncated: it starts at 50%, so rates at or below 0.5 (and
    missing combinations, which default to 0.0) produce zero-height bars.
    """
    x0, y0, x1, y1 = box
    axis_font = font(14)
    small_font = font(12)
    title_font = font(20, True)
    draw.text((x0, title_y if title_y is not None else y0 - 32), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)
    pad_l, pad_b, pad_t, pad_r = 44, 48, 52, 12
    # Plot area: the panel box shrunk by the paddings.
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b
    # Visible success-rate range of the y-axis.
    y_min, y_max = 0.50, 1.00
    for tick in [0.50, 0.60, 0.70, 0.80, 0.90, 1.00]:
        y = py1 - ((tick - y_min) / (y_max - y_min)) * (py1 - py0)
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 8, y - 8), f"{int(tick * 100)}", fill=(70, 70, 70), font=axis_font)
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    draw.text((x0 + 8, y0 + 7), "success (%)", fill=(70, 70, 70), font=axis_font)
    rates = {(r["task"], r["method"]): float(r["success_rate"]) for r in rows if r["boat"] == boat}
    # Bars share the middle 74% of each task slot (13% margin on either side).
    group_w = (px1 - px0) / len(TASK_ORDER)
    bar_w = group_w * 0.74 / len(LEARNED_METHODS)
    for task_idx, task in enumerate(TASK_ORDER):
        group_left = px0 + task_idx * group_w + group_w * 0.13
        # Light divider between adjacent task slots.
        if task_idx > 0:
            split_x = px0 + task_idx * group_w
            draw.line([split_x, py0, split_x, py1 + 26], fill=(218, 218, 218), width=1)
        for method_idx, method in enumerate(LEARNED_METHODS):
            rate = rates.get((task, method), 0.0)
            x_left = group_left + method_idx * bar_w
            x_right = x_left + bar_w * 0.86
            # Map the rate into the truncated axis and clamp it to [0, 1].
            scaled = max(0.0, min(1.0, (rate - y_min) / (y_max - y_min)))
            y_top = py1 - scaled * (py1 - py0)
            draw.rectangle([x_left, y_top, x_right, py1], fill=METHOD_COLORS[method], outline=(255, 255, 255))
        # Task label centred under its slot.
        label = TASK_LABEL[task]
        tw = draw.textlength(label, font=small_font)
        draw.text((px0 + task_idx * group_w + (group_w - tw) / 2, py1 + 12), label, fill=(45, 45, 45), font=small_font)
def make_fig4(summaries: list[SummaryRecord]) -> None:
    """Export Figure 4: prediction-error curves plus planning success panels.

    Writes, under ``OUT / "fig4"``: the prediction-error table, the success
    table aggregated per task/boat/learned method, the per-report source rows
    behind that aggregation (each as TSV and CSV), the three-panel PNG, and a
    Markdown provenance note.

    Args:
        summaries: Planning summaries; only "inferred" records of learned
            methods are used.
    """
    fig_dir = OUT / "fig4"
    ensure_dir(fig_dir)

    # Panel (A) data: prediction error per method and horizon.
    prediction_rows = load_prediction_rows()
    prediction_fields = [
        "method",
        "method_label",
        "context_mode",
        "horizon",
        "position_error",
        "heading_error",
        "source_file",
        "json_path_position",
        "json_path_heading",
    ]
    write_rows(fig_dir / "figure4_prediction_error", prediction_rows, prediction_fields)

    # Panels (B)/(C) data: success pooled over all flow types.
    learned_summaries = [r for r in inferred_summaries(summaries) if r.method in LEARNED_METHODS]
    success_rows = aggregate_success(learned_summaries, ("task", "boat", "method"))
    success_rows.sort(key=lambda r: (task_sort_key(r["task"]), boat_sort_key(r["boat"]), method_sort_key(r["method"])))
    for row in success_rows:
        row.update(
            method_label=METHOD_LABEL[row["method"]],
            task_label=TASK_LABEL[row["task"]],
            boat_label=BOAT_LABEL[row["boat"]],
        )
    success_fields = ["task", "task_label", "boat", "boat_label", "method", "method_label", "successes", "episodes", "success_rate", "source_files"]
    write_rows(fig_dir / "figure4_success_by_task_boat", success_rows, success_fields)

    # Un-aggregated records, so every pooled number can be traced to its report.
    ordered_summaries = sorted(
        learned_summaries,
        key=lambda x: (task_sort_key(x.task), boat_sort_key(x.boat), method_sort_key(x.method), flow_sort_key(x.flow_type)),
    )
    source_rows = [
        {
            "task": r.task,
            "boat": r.boat,
            "method": r.method,
            "flow_type": r.flow_type,
            "successes": r.successes,
            "episodes": r.episodes,
            "success_rate": r.success_rate,
            "source_file": r.source_file,
            "item_index": r.item_index,
            "json_path_successes": f"$[{r.item_index}].by_context.inferred.successes",
            "json_path_episodes": f"$[{r.item_index}].by_context.inferred.episodes",
        }
        for r in ordered_summaries
    ]
    source_fields = [
        "task",
        "boat",
        "method",
        "flow_type",
        "successes",
        "episodes",
        "success_rate",
        "source_file",
        "item_index",
        "json_path_successes",
        "json_path_episodes",
    ]
    write_rows(fig_dir / "figure4_success_by_task_boat_source_rows", source_rows, source_fields)

    # Render the three panels side by side, then one shared legend below them.
    canvas = Image.new("RGB", (1800, 505), "white")
    draw = ImageDraw.Draw(canvas)
    panel_title_y = 36
    draw_line_panel(draw, (34, 68, 620, 410), prediction_rows, "(A) Prediction error", METHOD_COLORS, show_legend=False, title_y=panel_title_y)
    draw_compact_success_panel(draw, (656, 68, 1230, 410), success_rows, "twin", "(B) Twin planning success", show_legend=False, title_y=panel_title_y)
    draw_compact_success_panel(draw, (1268, 68, 1766, 410), success_rows, "triangle", "(C) Triangle planning success", show_legend=False, title_y=panel_title_y)
    legend_font = font(18)
    legend_x, legend_y = 485, 462
    for position, method in enumerate(LEARNED_METHODS):
        swatch_x = legend_x + position * 215
        draw.line([swatch_x, legend_y + 8, swatch_x + 34, legend_y + 8], fill=METHOD_COLORS[method], width=5)
        draw.text((swatch_x + 44, legend_y - 4), METHOD_LABEL[method], fill=(40, 40, 40), font=legend_font)
    out = fig_dir / "figure4_prediction_and_planning.png"
    trim_whitespace(canvas, pad_x=8, pad_y=0).save(out)

    # Provenance note listing outputs, sources and aggregation rules.
    table_files = [
        "figure4_prediction_error.tsv",
        "figure4_prediction_error.csv",
        "figure4_success_by_task_boat.tsv",
        "figure4_success_by_task_boat.csv",
        "figure4_success_by_task_boat_source_rows.tsv",
        "figure4_success_by_task_boat_source_rows.csv",
    ]
    md = [
        "# Paper Figure 4 Provenance",
        "",
        "Purpose: quantitative paper Figure 4 with learned-world-model prediction error curves and planning success grouped by experiment/task.",
        "",
        "Generated outputs:",
        f"- `{rel(out)}`",
    ]
    md += [f"- `{rel(fig_dir / name)}`" for name in table_files]
    md += [
        "",
        "Panel (A) source:",
        f"- `{rel(PREDICTION_JSON)}`",
        "- JSON selectors: `$[method_index].inferred.pos{horizon}` and `$[method_index].inferred.heading{horizon}` for horizons 1, 3, 6, 8, 10, 20, 30, 40, 60.",
        "- Included methods: `flowmo`, `leworldmodel`, `planet`, `tdmpc2`.",
        "",
        "Panels (B) and (C) source:",
        f"- `{rel(PLANNING_DIR)}/*.json`",
        "- JSON selectors: `$[item_index].by_context.inferred.successes` and `$[item_index].by_context.inferred.episodes`.",
        "- Included methods: `flowmo`, `leworldmodel`, `planet`, `tdmpc2`.",
        "- Aggregation: sum successes and episodes over all flow types for each task, learned method, and boat.",
        "- Row-level source entries are recorded in `figure4_success_by_task_boat_source_rows.tsv/csv` with `source_file`, `item_index`, and JSON path columns.",
        "- Excluded diagnostic FlowMo contexts: `zero`, `shuffled`.",
    ]
    write_text(fig_dir / "figure4_provenance.md", "\n".join(md) + "\n")
def draw_failure_line_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    boat: str,
    title: str,
    show_legend: bool = False,
) -> None:
    """Draw a failure-rate line chart with one polyline per method.

    Args:
        draw: Drawing context the panel is rendered onto.
        box: Outer panel rectangle as ``(x0, y0, x1, y1)``.
        rows: Aggregated rows; each is read for ``boat``, ``method``,
            ``flow_type`` and ``failure_percent``.
        boat: Only rows whose ``boat`` equals this value are plotted.
        title: Heading drawn above the panel box.
        show_legend: When true, a method legend (four entries per line) is
            drawn inside the box.
    """
    x0, y0, x1, y1 = box
    axis_font = font(14)
    small_font = font(12)
    title_font = font(21, True)
    draw.text((x0, y0 - 36), title, fill=(20, 25, 30), font=title_font)
    draw.rectangle([x0, y0, x1, y1], outline=(210, 210, 210), width=1)

    # Inner plot rectangle, inset from the panel box by fixed margins.
    pad_l, pad_b, pad_t, pad_r = 62, 88, 26, 22
    px0, py0, px1, py1 = x0 + pad_l, y0 + pad_t, x1 - pad_r, y1 - pad_b

    boat_rows = [r for r in rows if r["boat"] == boat]
    # With no rows for this boat, fall back to 0 instead of letting max() raise;
    # this mirrors the 0.0 fallback used below for missing (method, flow) pairs.
    max_fail = max((float(r["failure_percent"]) for r in boat_rows), default=0.0)
    # Round the axis top up to a multiple of 5 with headroom, never below 5.
    y_max = max(5.0, math.ceil((max_fail + 3.0) / 5.0) * 5.0)

    # Horizontal grid lines and y tick values at quarter steps of the axis.
    for tick in (0.0, 0.25, 0.50, 0.75, 1.0):
        y = py1 - tick * (py1 - py0)
        val = tick * y_max
        draw.line([px0, y, px1, y], fill=(232, 232, 232), width=1)
        draw.text((x0 + 8, y - 8), f"{val:.0f}", fill=(70, 70, 70), font=axis_font)
    draw.line([px0, py1, px1, py1], fill=(60, 60, 60), width=2)
    draw.line([px0, py0, px0, py1], fill=(60, 60, 60), width=2)
    draw.text((x0 + 8, y0 + 8), "failure (%)", fill=(70, 70, 70), font=axis_font)

    # Evenly spaced x positions, one per flow family. A single flow would make
    # the interval count zero, so it is clamped to avoid dividing by zero.
    intervals = max(1, len(FLOW_ORDER) - 1)
    flow_x = {}
    for i, flow in enumerate(FLOW_ORDER):
        x = px0 + i / intervals * (px1 - px0)
        flow_x[flow] = x
        draw.line([x, py1, x, py1 + 5], fill=(60, 60, 60), width=1)
        draw.text((x - 28, py1 + 10), FLOW_LABEL[flow], fill=(55, 55, 55), font=small_font)

    failures = {(r["method"], r["flow_type"]): float(r["failure_percent"]) for r in boat_rows}
    for method in METHOD_ORDER:
        # Missing (method, flow) combinations are plotted as 0.
        pts = [
            (flow_x[flow], py1 - (failures.get((method, flow), 0.0) / y_max) * (py1 - py0))
            for flow in FLOW_ORDER
        ]
        draw.line(pts, fill=METHOD_COLORS[method], width=3)
        for x, y in pts:
            draw.ellipse([x - 3, y - 3, x + 3, y + 3], fill=METHOD_COLORS[method])

    if show_legend:
        lx, ly = x0 + 80, y0 + 34
        for i, method in enumerate(METHOD_ORDER):
            row, col = divmod(i, 4)
            xx = lx + col * 145
            yy = ly + row * 22
            draw.line([xx, yy + 7, xx + 22, yy + 7], fill=METHOD_COLORS[method], width=3)
            draw.text((xx + 28, yy - 2), METHOD_SHORT[method], fill=(40, 40, 40), font=small_font)
def failure_color(value: float, max_value: float = 60.0) -> tuple[int, int, int]:
    """Map a failure value onto the white-to-dark-red colour ramp.

    ``value / max_value`` is clamped to ``[0, 1]`` and linearly interpolated
    between fixed colour stops; each channel is truncated to an int.

    Args:
        value: Quantity to colour (the callers pass failure percentages).
        max_value: Value that maps to the last stop of the ramp.

    Returns:
        The interpolated colour as a 3-tuple of ints.
    """
    capped = min(1.0, value / max_value)
    fraction = max(0.0, capped)
    ramp = (
        (0.00, (255, 255, 255)),
        (0.15, (255, 238, 210)),
        (0.35, (249, 177, 107)),
        (0.65, (220, 72, 55)),
        (1.00, (120, 28, 45)),
    )
    # Walk adjacent stop pairs and interpolate inside the first segment whose
    # upper bound covers the clamped fraction.
    for idx, (hi_t, hi_rgb) in enumerate(ramp[1:]):
        lo_t, lo_rgb = ramp[idx]
        if fraction <= hi_t:
            if hi_t > lo_t:
                weight = (fraction - lo_t) / (hi_t - lo_t)
            else:
                weight = 0.0
            return tuple(int(lo_rgb[ch] + weight * (hi_rgb[ch] - lo_rgb[ch])) for ch in range(3))
    return ramp[-1][1]
def draw_failure_heatmap_panel(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    rows: list[dict[str, Any]],
    boat: str,
    title: str,
) -> None:
    """Draw a method-by-flow heatmap of failure percentages for one boat.

    Rows follow ``METHOD_ORDER`` and columns follow ``FLOW_ORDER``. Each cell
    is filled via ``failure_color`` and labelled with the rounded percentage;
    combinations absent from ``rows`` are shown as 0.

    Args:
        draw: Drawing context the panel is rendered onto.
        box: Panel rectangle as ``(x0, y0, x1, y1)``.
        rows: Aggregated rows read for ``boat``, ``method``, ``flow_type`` and
            ``failure_percent``.
        boat: Only rows whose ``boat`` equals this value are used.
        title: Heading drawn at the top-left corner of the box.
    """
    left, top, right, bottom = box
    heading_font = font(30, True)
    emphasis_font = font(22, True)
    plain_font = font(19)
    value_font = font(17, True)
    draw.text((left, top), title, fill=(20, 25, 30), font=heading_font)

    # Heatmap grid rectangle: space is reserved on the left for method names
    # and on top for the title and column captions.
    grid_left = left + 160
    grid_top = top + 58
    grid_right = right - 22
    grid_bottom = bottom - 22
    col_w = (grid_right - grid_left) / len(FLOW_ORDER)
    row_h = (grid_bottom - grid_top) / len(METHOD_ORDER)

    percent_by_cell = {
        (rec["method"], rec["flow_type"]): float(rec["failure_percent"])
        for rec in rows
        if rec["boat"] == boat
    }
    abbreviations = {
        "No flow": "No",
        "Uniform": "Uni",
        "Vortex": "Vort",
        "Double gyre": "Gyre",
        "Source/sink": "Src",
        "Src/sink pair": "Pair",
        "Gradient": "Grad",
        "Shear": "Shear",
        "Turbulent": "Turb",
        "Fourier": "Fourier",
    }

    # Column captions, centred above each column.
    for col, flow in enumerate(FLOW_ORDER):
        full_label = FLOW_LABEL[flow]
        caption = abbreviations.get(full_label, full_label)
        col_left = grid_left + col * col_w
        caption_w = draw.textlength(caption, font=plain_font)
        draw.text(
            (col_left + (col_w - caption_w) / 2, grid_top - 29),
            caption,
            fill=(45, 45, 45),
            font=plain_font,
        )

    for row_idx, method in enumerate(METHOD_ORDER):
        row_top = grid_top + row_idx * row_h
        # The FlowMo row name is set in the larger bold font.
        name_font = emphasis_font if method == "flowmo" else plain_font
        draw.text(
            (left + 6, row_top + row_h * 0.25),
            METHOD_SHORT[method],
            fill=(35, 35, 35),
            font=name_font,
        )
        # Rule drawn along the top edge of the PID/LOS row, before its cells.
        if method == "pid_los_controller":
            draw.line([left, row_top, grid_right, row_top], fill=(65, 65, 65), width=3)
        for col, flow in enumerate(FLOW_ORDER):
            cell_left = grid_left + col * col_w
            percent = percent_by_cell.get((method, flow), 0.0)
            draw.rectangle(
                [cell_left, row_top, cell_left + col_w, row_top + row_h],
                fill=failure_color(percent),
                outline=(238, 238, 238),
                width=1,
            )
            digits = f"{percent:.0f}"
            # White text from 36% upward, dark text below.
            if percent >= 36.0:
                ink = (255, 255, 255)
            else:
                ink = (35, 35, 35)
            digits_w = draw.textlength(digits, font=value_font)
            draw.text(
                (cell_left + (col_w - digits_w) / 2, row_top + row_h * 0.26),
                digits,
                fill=ink,
                font=value_font,
            )

    draw.rectangle([grid_left, grid_top, grid_right, grid_bottom], outline=(120, 120, 120), width=2)
def draw_failure_colorbar(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int]) -> None:
    """Draw a horizontal 0-60 failure-rate colour bar with ticks underneath.

    Args:
        draw: Drawing context to render onto.
        box: Bar rectangle as ``(x0, y0, x1, y1)``; integer x bounds are
            expected because the bar is painted one pixel column at a time.
    """
    left, top, right, bottom = box
    label_font = font(18)
    draw.text((left, top - 28), "failure rate (%)", fill=(40, 40, 40), font=label_font)

    # Paint the gradient column by column, 0 at the left edge.
    for px in range(left, right):
        level = (px - left) / max(1, right - left) * 60.0
        draw.line([px, top, px, bottom], fill=failure_color(level), width=1)
    draw.rectangle([left, top, right, bottom], outline=(120, 120, 120), width=1)

    for mark in (0, 15, 30, 45, 60):
        mark_x = left + (mark / 60.0) * (right - left)
        draw.line([mark_x, bottom, mark_x, bottom + 8], fill=(50, 50, 50), width=1)
        draw.text((mark_x - 12, bottom + 12), str(mark), fill=(50, 50, 50), font=label_font)
def draw_failure_colorbar_vertical(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int]) -> None:
    """Draw a vertical 0-60 failure-rate colour bar, highest value at the top.

    Args:
        draw: Drawing context to render onto.
        box: Bar rectangle as ``(x0, y0, x1, y1)``; integer y bounds are
            expected because the bar is painted one pixel row at a time.
    """
    left, top, right, bottom = box
    number_font = font(18)
    caption_font = font(19)

    # Paint the gradient row by row; the top row maps to the maximum (60).
    for py in range(top, bottom):
        level = (bottom - py) / max(1, bottom - top) * 60.0
        draw.line([left, py, right, py], fill=failure_color(level), width=1)
    draw.rectangle([left, top, right, bottom], outline=(120, 120, 120), width=1)

    # Tick marks extend to the left of the bar; the numbers sit to its right.
    for mark in (0, 15, 30, 45, 60):
        mark_y = bottom - (mark / 60.0) * (bottom - top)
        draw.line([left - 8, mark_y, left, mark_y], fill=(50, 50, 50), width=1)
        draw.text((right + 8, mark_y - 10), str(mark), fill=(50, 50, 50), font=number_font)

    draw.text((left - 56, top - 34), "failure rate (%)", fill=(40, 40, 40), font=caption_font)
def make_fig5(summaries: list[SummaryRecord]) -> None:
    """Export Figure 5: per-flow failure-rate heatmaps, data tables, provenance.

    Writes into ``OUT / "fig5"``: the aggregated table, the row-level source
    table, the two-panel PNG (twin and triangle boats plus a colour bar) and
    a markdown provenance note.

    Args:
        summaries: Planning summary records; only those returned by
            ``inferred_summaries`` are used.
    """
    fig_dir = OUT / "fig5"
    ensure_dir(fig_dir)

    # Aggregated table: one row per (boat, flow, method), sorted for display.
    rows = aggregate_success(inferred_summaries(summaries), ("boat", "flow_type", "method"))
    rows.sort(
        key=lambda r: (
            boat_sort_key(r["boat"]),
            flow_sort_key(r["flow_type"]),
            method_sort_key(r["method"]),
        )
    )
    for row in rows:
        row["boat_label"] = BOAT_LABEL[row["boat"]]
        row["flow_label"] = FLOW_LABEL[row["flow_type"]]
        row["method_label"] = METHOD_LABEL[row["method"]]
        failure = 1.0 - float(row["success_rate"])
        row["failure_rate"] = failure
        row["failure_percent"] = 100.0 * failure
        row["success_percent"] = 100.0 * float(row["success_rate"])
    aggregate_columns = [
        "boat",
        "boat_label",
        "flow_type",
        "flow_label",
        "method",
        "method_label",
        "successes",
        "episodes",
        "success_rate",
        "success_percent",
        "failure_rate",
        "failure_percent",
        "source_files",
    ]
    write_rows(fig_dir / "figure5_failure_by_flow", rows, aggregate_columns)

    # Row-level table: one entry per contributing summary record.
    def source_order(rec: SummaryRecord) -> tuple[Any, Any, Any, Any]:
        return (
            boat_sort_key(rec.boat),
            flow_sort_key(rec.flow_type),
            method_sort_key(rec.method),
            task_sort_key(rec.task),
        )

    source_rows = [
        {
            "boat": rec.boat,
            "flow_type": rec.flow_type,
            "method": rec.method,
            "task": rec.task,
            "successes": rec.successes,
            "episodes": rec.episodes,
            "success_rate": rec.success_rate,
            "failure_rate": 1.0 - rec.success_rate,
            "source_file": rec.source_file,
            "item_index": rec.item_index,
            "json_path_successes": f"$[{rec.item_index}].by_context.inferred.successes",
            "json_path_episodes": f"$[{rec.item_index}].by_context.inferred.episodes",
        }
        for rec in sorted(inferred_summaries(summaries), key=source_order)
    ]
    source_columns = [
        "boat",
        "flow_type",
        "method",
        "task",
        "successes",
        "episodes",
        "success_rate",
        "failure_rate",
        "source_file",
        "item_index",
        "json_path_successes",
        "json_path_episodes",
    ]
    write_rows(fig_dir / "figure5_failure_by_flow_source_rows", source_rows, source_columns)

    # Figure: two heatmap panels side by side plus a shared vertical colour bar.
    canvas = Image.new("RGB", (2400, 585), "white")
    draw = ImageDraw.Draw(canvas)
    draw_failure_heatmap_panel(draw, (30, 26, 1132, 540), rows, "twin", "(A) Twin")
    draw_failure_heatmap_panel(draw, (1162, 26, 2264, 540), rows, "triangle", "(B) Triangle")
    draw_failure_colorbar_vertical(draw, (2310, 84, 2336, 518))
    out = fig_dir / "figure5_failure_by_flow.png"
    trim_whitespace(canvas, pad_x=8, pad_y=0).save(out)

    provenance = [
        "# Paper Figure 5 Provenance",
        "",
        "Purpose: paper Figure 5 flow-family breakdown of downstream planning failure rates. Failure rate is used because many success rates are 100% or near 100%.",
        "",
        "Generated outputs:",
        f"- `{rel(out)}`",
        f"- `{rel(fig_dir / 'figure5_failure_by_flow.tsv')}`",
        f"- `{rel(fig_dir / 'figure5_failure_by_flow.csv')}`",
        f"- `{rel(fig_dir / 'figure5_failure_by_flow_source_rows.tsv')}`",
        f"- `{rel(fig_dir / 'figure5_failure_by_flow_source_rows.csv')}`",
        "",
        "Source:",
        f"- `{rel(PLANNING_DIR)}/*.json`",
        "- JSON selectors: `$[item_index].by_context.inferred.successes` and `$[item_index].by_context.inferred.episodes`.",
        "- Aggregation: sum successes and episodes over all tasks for each method, boat, and flow type.",
        "- Failure rate: `1 - successes / episodes`.",
        "- Row-level source entries are recorded in `figure5_failure_by_flow_source_rows.tsv/csv` with `source_file`, `item_index`, and JSON path columns.",
        "- Excluded diagnostic FlowMo contexts: `zero`, `shuffled`.",
        "",
        "Traditional controller naming:",
        "- `No-Flow LOS`: line-of-sight controller with no ambient-flow compensation.",
        "- `Current-Estimator LOS`: line-of-sight controller with an online drift/current estimate from recent pose history.",
        "- `Oracle-Flow LOS`: line-of-sight controller with privileged true local simulator flow feed-forward.",
    ]
    write_text(fig_dir / "figure5_provenance.md", "\n".join(provenance) + "\n")
def make_table1(summaries: list[SummaryRecord]) -> None:
    """Export Table 1: success rate per task, boat, flow family and method.

    Writes the TSV/CSV data table, the LaTeX table and a markdown provenance
    note into ``OUT / "tables"``. A (task, boat, flow, method) combination
    with no record gets blank data columns and ``--`` in the LaTeX table.

    Args:
        summaries: Planning summary records; only those returned by
            ``inferred_summaries`` are used.
    """
    table_dir = OUT / "tables"
    ensure_dir(table_dir)
    records = inferred_summaries(summaries)
    by_key = {(r.task, r.boat, r.flow_type, r.method): r for r in records}

    # Per-method column suffixes in output order. The same tuple drives the
    # blank-fill for missing records and the field list, so every declared
    # column is always present in every row.
    method_columns = (
        "success_rate",
        "success_percent",
        "successes",
        "episodes",
        "source",
        "json_path",
    )

    data_rows: list[dict[str, Any]] = []
    latex_rows: list[list[str]] = []
    for task in TASK_ORDER:
        for boat in BOAT_ORDER:
            for flow in FLOW_ORDER:
                row: dict[str, Any] = {
                    "task": task,
                    "task_label": TASK_LABEL[task],
                    "boat": boat,
                    "boat_label": BOAT_LABEL[boat],
                    "flow_type": flow,
                    "flow_label": FLOW_LABEL[flow],
                }
                latex_row = [TASK_LABEL[task], BOAT_LABEL[boat], FLOW_LABEL[flow]]
                for method in METHOD_ORDER:
                    rec = by_key.get((task, boat, flow, method))
                    if rec is None:
                        # Blank all six columns (previously `_success_percent`
                        # and `_json_path` were left unset here).
                        for suffix in method_columns:
                            row[f"{method}_{suffix}"] = ""
                        latex_row.append("--")
                        continue
                    row[f"{method}_success_rate"] = rec.success_rate
                    row[f"{method}_success_percent"] = 100.0 * rec.success_rate
                    row[f"{method}_successes"] = rec.successes
                    row[f"{method}_episodes"] = rec.episodes
                    row[f"{method}_source"] = rec.source_file
                    row[f"{method}_json_path"] = f"$[{rec.item_index}].by_context.inferred.success_rate"
                    latex_row.append(pct(rec.success_rate, 0))
                data_rows.append(row)
                latex_rows.append(latex_row)

    fields = ["task", "task_label", "boat", "boat_label", "flow_type", "flow_label"]
    for method in METHOD_ORDER:
        fields += [f"{method}_{suffix}" for suffix in method_columns]
    write_rows(table_dir / "table1_success_by_task_boat_flow", data_rows, fields)

    header = ["Task", "Boat", "Flow"] + [METHOD_SHORT[m] for m in METHOD_ORDER]
    write_table(
        table_dir / "table1_success_by_task_boat_flow.tex",
        header,
        latex_rows,
        "Planning success rate by task, boat, and flow family. Values are percentages over 50 episodes per setting; NF-LOS, CE-LOS, and OF-LOS denote No-Flow LOS, Current-Estimator LOS, and Oracle-Flow LOS. FlowMo diagnostic zero/shuffled contexts are excluded.",
        "tab:planning_success_task_boat_flow",
    )

    md = [
        "# Table 1 Provenance",
        "",
        "Purpose: success rate for every task, boat, flow family, and method.",
        "",
        "Generated outputs:",
        f"- `{rel(table_dir / 'table1_success_by_task_boat_flow.tex')}`",
        f"- `{rel(table_dir / 'table1_success_by_task_boat_flow.tsv')}`",
        f"- `{rel(table_dir / 'table1_success_by_task_boat_flow.csv')}`",
        "",
        "Source:",
        f"- `{rel(PLANNING_DIR)}/*.json`",
        "- JSON selector per cell: `$[item_index].by_context.inferred.success_rate`.",
        "- Companion fields in TSV/CSV include `$[item_index].by_context.inferred.successes` and `episodes`.",
        "- Excluded diagnostic FlowMo contexts: `zero`, `shuffled`.",
        "",
        "Traditional controller naming:",
        "- `No-Flow LOS`: line-of-sight controller with no ambient-flow compensation.",
        "- `Current-Estimator LOS`: line-of-sight controller with an online drift/current estimate from recent pose history.",
        "- `Oracle-Flow LOS`: line-of-sight controller with privileged true local simulator flow feed-forward.",
    ]
    write_text(table_dir / "table1_provenance.md", "\n".join(md) + "\n")
def make_table2(episodes: list[EpisodeRecord]) -> None:
    """Export Table 2: distance and energy means per task, boat and method.

    Episodes returned by ``inferred_episodes`` are grouped by
    (task, boat, method). Final distance and minimum goal distance are
    averaged over every episode in a group; energy is averaged only over
    successful episodes that carry an energy value. Groups with no episodes
    are omitted. Writes TSV/CSV, LaTeX and a provenance note into
    ``OUT / "tables"``.

    Args:
        episodes: Per-episode planning records.
    """
    table_dir = OUT / "tables"
    ensure_dir(table_dir)

    grouped: dict[tuple[str, str, str], list[EpisodeRecord]] = defaultdict(list)
    sources: dict[tuple[str, str, str], set[str]] = defaultdict(set)
    for episode in inferred_episodes(episodes):
        group_key = (episode.task, episode.boat, episode.method)
        grouped[group_key].append(episode)
        sources[group_key].add(episode.source_file)

    data_rows: list[dict[str, Any]] = []
    latex_rows: list[list[str]] = []
    combos = ((t, b, m) for t in TASK_ORDER for b in BOAT_ORDER for m in METHOD_ORDER)
    for task, boat, method in combos:
        group_key = (task, boat, method)
        members = grouped.get(group_key)
        if not members:
            continue
        count = len(members)
        final_distance_mean = sum(ep.final_distance for ep in members) / count
        min_goal_distance_mean = sum(ep.mean_min_goal_distance for ep in members) / count
        # Energy is only meaningful for episodes that succeeded and logged it.
        with_energy = [ep for ep in members if ep.success and ep.energy is not None]
        if with_energy:
            energy_success_mean = sum(float(ep.energy) for ep in with_energy) / len(with_energy)
        else:
            energy_success_mean = None
        data_rows.append(
            {
                "task": task,
                "task_label": TASK_LABEL[task],
                "boat": boat,
                "boat_label": BOAT_LABEL[boat],
                "method": method,
                "method_label": METHOD_LABEL[method],
                "episodes": count,
                "successful_energy_episodes": len(with_energy),
                "final_distance_mean": final_distance_mean,
                "mean_min_goal_distance": min_goal_distance_mean,
                "energy_success_mean": "" if energy_success_mean is None else energy_success_mean,
                "source_files": ";".join(sorted(sources[group_key])),
                "json_selector": "$[item_index].results[*] filtered by context_mode == inferred",
            }
        )
        latex_rows.append(
            [
                TASK_LABEL[task],
                BOAT_LABEL[boat],
                METHOD_LABEL[method],
                fmt(final_distance_mean),
                fmt(energy_success_mean),
            ]
        )

    columns = [
        "task",
        "task_label",
        "boat",
        "boat_label",
        "method",
        "method_label",
        "episodes",
        "successful_energy_episodes",
        "final_distance_mean",
        "mean_min_goal_distance",
        "energy_success_mean",
        "source_files",
        "json_selector",
    ]
    write_rows(table_dir / "table2_energy_distance_by_task_boat_method", data_rows, columns)
    write_table(
        table_dir / "table2_energy_distance_by_task_boat_method.tex",
        ["Task", "Boat", "Method", "Final dist.", "Energy (succ.)"],
        latex_rows,
        "Planning distance and energy by task, boat, and method, aggregated over all flow families. Final distance is averaged over all episodes; energy is averaged over successful episodes.",
        "tab:planning_energy_distance_task_boat",
    )

    provenance = [
        "# Table 2 Provenance",
        "",
        "Purpose: distance and energy metrics by task, boat, and method.",
        "",
        "Generated outputs:",
        f"- `{rel(table_dir / 'table2_energy_distance_by_task_boat_method.tex')}`",
        f"- `{rel(table_dir / 'table2_energy_distance_by_task_boat_method.tsv')}`",
        f"- `{rel(table_dir / 'table2_energy_distance_by_task_boat_method.csv')}`",
        "",
        "Source:",
        f"- `{rel(PLANNING_DIR)}/*.json`",
        "- JSON selector: `$[item_index].results[*]`, filtered to `context_mode == inferred`.",
        "- Final distance: mean of `final_distance` over all filtered episodes.",
        "- Energy: mean of `energy` over filtered successful episodes only.",
        "- The TSV/CSV also includes `mean_min_goal_distance`, computed from the same filtered episodes.",
        "- Excluded diagnostic FlowMo contexts: `zero`, `shuffled`.",
        "",
        "Traditional controller naming:",
        "- `No-Flow LOS`: line-of-sight controller with no ambient-flow compensation.",
        "- `Current-Estimator LOS`: line-of-sight controller with an online drift/current estimate from recent pose history.",
        "- `Oracle-Flow LOS`: line-of-sight controller with privileged true local simulator flow feed-forward.",
    ]
    write_text(table_dir / "table2_provenance.md", "\n".join(provenance) + "\n")
def make_probe_table() -> None:
    """Export Table 3: probe R2 and RMSE on the test split of ``PROBE_JSON``.

    For each target the LaTeX row lists R2 for features ``z``, ``c``, ``z_c``
    followed by RMSE for the same features. The TSV/CSV has one row per
    (target, feature) with the JSON paths the values were read from. Output
    goes to ``OUT / "tables"`` together with a provenance note.
    """
    table_dir = OUT / "tables"
    ensure_dir(table_dir)
    test_split = read_json(PROBE_JSON)["splits"]["test"]
    features = ["z", "c", "z_c"]
    targets = ["momentum", "local_flow", "episode_drift"]

    rows: list[dict[str, Any]] = []
    latex_rows: list[list[str]] = []
    for target in targets:
        r2_cells: list[str] = []
        for feature in features:
            entry = test_split[target][feature]
            json_prefix = f"$.splits.test.{target}.{feature}"
            rows.append(
                {
                    "split": "test",
                    "target": target,
                    "feature": feature,
                    "r2_mean": entry["r2_mean"],
                    "rmse": entry["rmse"],
                    "source_file": rel(PROBE_JSON),
                    "json_path_r2": f"{json_prefix}.r2_mean",
                    "json_path_rmse": f"{json_prefix}.rmse",
                }
            )
            r2_cells.append(fmt(entry["r2_mean"], 3))
        rmse_cells = [fmt(test_split[target][feature]["rmse"], 3) for feature in features]
        latex_rows.append([target.replace("_", " "), *r2_cells, *rmse_cells])

    write_rows(
        table_dir / "table3_probe_diagnostics",
        rows,
        ["split", "target", "feature", "r2_mean", "rmse", "source_file", "json_path_r2", "json_path_rmse"],
    )
    write_table(
        table_dir / "table3_probe_diagnostics.tex",
        ["Target", "z R2", "c R2", "z+c R2", "z RMSE", "c RMSE", "z+c RMSE"],
        latex_rows,
        "Frozen linear probe diagnostics for FlowMo representations on the test split.",
        "tab:flowmo_probe_diagnostics",
    )

    provenance = [
        "# Table 3 Provenance",
        "",
        "Purpose: frozen linear probe diagnostics for FlowMo latent variables.",
        "",
        "Generated outputs:",
        f"- `{rel(table_dir / 'table3_probe_diagnostics.tex')}`",
        f"- `{rel(table_dir / 'table3_probe_diagnostics.tsv')}`",
        f"- `{rel(table_dir / 'table3_probe_diagnostics.csv')}`",
        "",
        "Source:",
        f"- `{rel(PROBE_JSON)}`",
        "- JSON selectors: `$.splits.test.<target>.<feature>.r2_mean` and `$.splits.test.<target>.<feature>.rmse`.",
        "- Probe model: frozen FlowMo features with ridge regression, as reported by the JSON metadata.",
    ]
    write_text(table_dir / "table3_provenance.md", "\n".join(provenance) + "\n")
def make_overview() -> None:
    """Write ``OUT / "README.md"`` indexing the export.

    The README names the source inputs, describes the traditional
    controllers, and lists every file found under ``OUT`` at call time. The
    listing is collected before the README is written.
    """
    exported = sorted(path for path in OUT.rglob("*") if path.is_file())

    lines = [
        "# Paper Artifact Export",
        "",
        "Generated from local experiment outputs under `experiments/reports/`.",
        "",
        "Important source files:",
        f"- `{rel(PREDICTION_JSON)}`",
        f"- `{rel(PROBE_JSON)}`",
        f"- `{rel(PLANNING_DIR)}/*.json`",
        f"- `{rel(GIF_DIR)}/*.gif`",
        "",
        "Traditional controller names used in these exports:",
    ]
    lines.extend(
        f"- `{METHOD_LABEL[method]}`: {METHOD_DESCRIPTION[method]}"
        for method in TRADITIONAL_METHODS
    )
    lines.extend(["", "Generated files:"])
    lines.extend(f"- `{rel(path)}`" for path in exported)
    write_text(OUT / "README.md", "\n".join(lines) + "\n")
def validate_inputs() -> None:
    """Check that the required experiment outputs exist before exporting.

    Raises:
        FileNotFoundError: If any of ``PREDICTION_JSON``, ``PROBE_JSON``,
            ``PLANNING_DIR`` or ``GIF_DIR`` does not exist.
        RuntimeError: If the number of ``*.json`` files in ``PLANNING_DIR``
            differs from ``len(TASK_ORDER) * len(BOAT_ORDER) * len(FLOW_ORDER)``.
    """
    required = (PREDICTION_JSON, PROBE_JSON, PLANNING_DIR, GIF_DIR)
    absent = [str(path) for path in required if not path.exists()]
    if absent:
        raise FileNotFoundError("Missing required experiment outputs: " + ", ".join(absent))

    # One planning JSON is expected per (task, boat, flow) combination.
    expected = len(TASK_ORDER) * len(BOAT_ORDER) * len(FLOW_ORDER)
    found = len(list(PLANNING_DIR.glob("*.json")))
    if found != expected:
        raise RuntimeError(f"Expected {expected} planning JSON files, found {found}")
def main() -> None:
    """Regenerate every paper artifact under ``OUT`` from scratch.

    Inputs are validated first; an existing ``OUT`` directory is then deleted
    and recreated, so earlier exports are discarded before the new ones are
    written.
    """
    validate_inputs()

    # Start from an empty output directory.
    if OUT.exists():
        shutil.rmtree(OUT)
    ensure_dir(OUT)

    summaries, episodes = load_planning()
    extract_fig3()
    # Builders that consume the per-setting summaries, in output order.
    for build_from_summaries in (make_fig4, make_fig5, make_table1):
        build_from_summaries(summaries)
    make_table2(episodes)
    make_probe_table()
    # Runs last: the README lists the files that exist under OUT at that point.
    make_overview()
    print(f"Wrote paper artifacts to {OUT}")


# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()