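# FlowMo-WM/experiments/plot_flow_atlas.py
# Render the FlowMo-WM paper flow-family atlas and, optionally, per-family flow panels.
# Last commit: ccf9f1b (verified) "Update FlowMo-WM code and static flow protocol" (author: cccat6).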
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from driftwm.sim.flow import sample_flow
from driftwm.utils import ensure_dir
# Flow families rendered in the paper figures. The order here is the panel
# order: make_flow_atlas lays panels out row-major, five per row.
PAPER_FLOW_FAMILIES: list[str] = [
    "noflow", "uniform",
    "vortex_center", "double_gyre",
    "source_sink", "source_sink_pair",
    "gradient", "shear",
    "turbulent_patch", "random_fourier",
]
def color_map(values: np.ndarray, vmax: float = 0.38) -> np.ndarray:
    """Map scalar values to RGB colours using a five-stop linear ramp.

    Values are divided by ``vmax`` and clipped to ``[0, 1]``. The ramp runs
    from dark blue at 0 through teal and green to yellow at ``vmax``; values
    at or above ``vmax`` take the brightest stop, and negative values take
    the darkest. NaN entries are rendered with the darkest stop.

    Args:
        values: Array (or array-like) of any shape, e.g. per-pixel speed.
        vmax: Value mapped to the brightest stop. Must be positive and finite.

    Returns:
        ``uint8`` array of shape ``values.shape + (3,)``.

    Raises:
        ValueError: If ``vmax`` is not a positive finite number.
    """
    if not (np.isfinite(vmax) and vmax > 0.0):
        raise ValueError(f"vmax must be a positive finite number, got {vmax!r}")
    stops = np.array(
        [
            [18, 35, 61],
            [31, 91, 107],
            [64, 145, 108],
            [164, 190, 105],
            [244, 220, 102],
        ],
        dtype=np.float32,
    )
    x = np.clip(np.asarray(values) / vmax, 0.0, 1.0)
    # NaN passes through np.clip unchanged and would become a garbage integer
    # index after floor/astype below; render it as the lowest colour instead.
    x = np.nan_to_num(x, nan=0.0)
    scaled = x * (len(stops) - 1)
    lo = np.floor(scaled).astype(np.int32)
    # lo is already >= 0, so only the upper end needs clamping (x == 1.0 -> last stop).
    hi = np.minimum(lo + 1, len(stops) - 1)
    frac = (scaled - lo)[..., None]
    # astype(uint8) truncates toward zero, keeping the original quantisation.
    return ((1.0 - frac) * stops[lo] + frac * stops[hi]).astype(np.uint8)
def draw_arrow(draw: ImageDraw.ImageDraw, start: tuple[float, float], vec: tuple[float, float], scale: float) -> None:
    """Draw a white arrow from ``start`` along ``vec`` stretched by ``scale``.

    ``vec`` uses +y-up orientation, so its y-component is negated to move into
    image coordinates (+y down). Vectors whose Euclidean norm is below 1e-4
    are skipped. The shaft is 2 px wide; the head is a filled triangle 6 px
    long whose half-width is 0.55 of that length.
    """
    vx, vy = vec
    if float(np.hypot(vx, vy)) < 1.0e-4:
        return

    sx, sy = start
    tip_x = sx + scale * vx
    tip_y = sy - scale * vy  # image y grows downward
    white = (255, 255, 255)
    draw.line((sx, sy, tip_x, tip_y), fill=white, width=2)

    dx, dy = tip_x - sx, tip_y - sy
    length = max(float(np.hypot(dx, dy)), 1.0e-6)  # guard against a zero-length shaft
    ux, uy = dx / length, dy / length  # unit direction in image space
    nx, ny = -uy, ux  # perpendicular direction, used for the head's wings

    head = 6.0
    wing = 0.55 * head
    base_x = tip_x - head * ux
    base_y = tip_y - head * uy
    left = (base_x + wing * nx, base_y + wing * ny)
    right = (base_x - wing * nx, base_y - wing * ny)
    draw.polygon(((tip_x, tip_y), left, right), fill=white)
def make_panel(
    family: str,
    flow_seed: int,
    panel_size: int = 320,
    *,
    show_header: bool = True,
    show_seed: bool = True,
) -> Image.Image:
    """Render one flow as a speed heat map overlaid with white velocity arrows.

    The flow comes from ``sample_flow(family, rng, flow_id=flow_seed)``, with
    ``rng`` seeded from ``flow_seed``. It is evaluated at ``t=0`` over the
    domain ``[0, 10] x [0, 10]``. The background colour encodes speed (see
    ``color_map``), and arrows are drawn on a 13 x 13 grid.

    Args:
        family: Flow family name passed to ``sample_flow``.
        flow_seed: Seed for the sampling RNG; also passed as ``flow_id``.
        panel_size: Side length in pixels of the square plot area.
        show_header: If true, reserve a 44 px strip above the plot and write
            the family name in it.
        show_seed: If true and ``show_header`` is true, also write the seed in
            the header. Ignored when ``show_header`` is false.

    Returns:
        RGB image of size ``(panel_size, panel_size + header)``, where
        ``header`` is 44 or 0.
    """
    rng = np.random.default_rng(flow_seed)
    flow = sample_flow(family, rng, flow_id=flow_seed)

    # Speed background sampled on a dense 96 x 96 grid at t=0.
    n_bg = 96
    xs = np.linspace(0.0, 10.0, n_bg, dtype=np.float32)
    ys = np.linspace(0.0, 10.0, n_bg, dtype=np.float32)
    bg_grid = np.stack(np.meshgrid(xs, ys), axis=-1).astype(np.float32)
    bg_velocity = flow.velocity(bg_grid, t=0.0)
    # Assumes the last axis of the velocity holds (vx, vy) — TODO confirm in driftwm.sim.flow.
    speed = np.linalg.norm(bg_velocity, axis=-1)
    # meshgrid(xs, ys) puts increasing y along rows; flip so +y points up.
    # color_map returns uint8 with a trailing RGB axis, which Pillow infers
    # as "RGB", so the deprecated ``mode`` argument of fromarray is not passed.
    bg = Image.fromarray(color_map(np.flipud(speed))).resize(
        (panel_size, panel_size), Image.Resampling.BICUBIC
    )

    header = 44 if show_header else 0
    panel = Image.new("RGB", (panel_size, panel_size + header), (245, 245, 240))
    panel.paste(bg, (0, header))
    draw = ImageDraw.Draw(panel)
    font = ImageFont.load_default()
    draw.rectangle((0, 0, panel_size - 1, panel_size + header - 1), outline=(22, 28, 33), width=2)
    if show_header:
        draw.text((10, 10), f"{family}", fill=(20, 24, 28), font=font)
        if show_seed:
            draw.text((10, 25), f"seed={flow_seed}", fill=(55, 61, 68), font=font)

    # Arrow grid, inset 0.8 units from the domain edges. The world -> pixel
    # mapping flips y to match the background orientation above.
    n_arrow = 13
    ax = np.linspace(0.8, 9.2, n_arrow, dtype=np.float32)
    ay = np.linspace(0.8, 9.2, n_arrow, dtype=np.float32)
    arrow_grid = np.stack(np.meshgrid(ax, ay), axis=-1).astype(np.float32)
    arrow_velocity = flow.velocity(arrow_grid, t=0.0)
    for pos, vel in zip(arrow_grid.reshape(-1, 2), arrow_velocity.reshape(-1, 2), strict=True):
        px = 0.5 + pos[0] / 10.0 * (panel_size - 1)
        py = header + 0.5 + (1.0 - pos[1] / 10.0) * (panel_size - 1)
        draw_arrow(draw, (float(px), float(py)), (float(vel[0]), float(vel[1])), scale=72.0)
    return panel
def _family_seeds(seed: int) -> list[int]:
    """Derive one deterministic flow seed per entry of PAPER_FLOW_FAMILIES.

    Draws integers from ``numpy.random.default_rng(seed)``, one scalar draw
    per family in list order, each in ``[1, 2**31 - 1)``. The same ``seed``
    therefore always yields the same list.
    """
    generator = np.random.default_rng(seed)
    upper = 2**31 - 1  # exclusive bound
    return [int(generator.integers(low=1, high=upper)) for _family in PAPER_FLOW_FAMILIES]
def make_flow_atlas(seed: int, out: str | Path, panel_size: int = 320) -> None:
    """Compose every paper flow family into one grid image and save it to ``out``.

    Panels come from ``make_panel`` with seeds derived from ``seed`` via
    ``_family_seeds``. They are laid out row-major, five per row, below a
    58 px title strip, with a 22 px gap around and between panels. After
    saving, the output path and each family's seed are printed.
    """
    family_seeds = _family_seeds(seed)
    pairs = list(zip(PAPER_FLOW_FAMILIES, family_seeds, strict=True))
    panels = [make_panel(name, flow_seed, panel_size=panel_size) for name, flow_seed in pairs]

    pane_w, pane_h = panels[0].size
    gap = 22
    title_h = 58
    n_cols = 5
    n_rows = -(-len(panels) // n_cols)  # ceiling division
    canvas_size = (
        n_cols * pane_w + (n_cols + 1) * gap,
        n_rows * pane_h + (n_rows + 1) * gap + title_h,
    )
    atlas = Image.new("RGB", canvas_size, (250, 250, 246))
    ImageDraw.Draw(atlas).text(
        (gap, 18),
        f"FlowMo-WM paper flow atlas, deterministic atlas seed={seed}",
        fill=(15, 19, 24),
        font=ImageFont.load_default(),
    )
    for index, panel in enumerate(panels):
        row, col = divmod(index, n_cols)
        origin = (gap + col * (pane_w + gap), title_h + gap + row * (pane_h + gap))
        atlas.paste(panel, origin)

    out_path = Path(out)
    ensure_dir(out_path.parent)
    atlas.save(out_path)
    print(f"wrote {out_path}")
    for name, flow_seed in pairs:
        print(f"{name}: {flow_seed}")
def export_flow_panels(seed: int, out_dir: str | Path, panel_size: int = 640) -> None:
    """Write one clean and one labelled PNG per paper flow family, plus a manifest.

    Seeds come from ``_family_seeds(seed)``, the same derivation that
    ``make_flow_atlas`` uses. For each family this writes:

    * ``<out_dir>/clean/<family>.png`` — no header strip and no text.
    * ``<out_dir>/labeled/<family>.png`` — header with the family name, no seed.

    ``<out_dir>/manifest.tsv`` (UTF-8) lists the family, its seed and both
    image paths, one row per family after a header row.

    Args:
        seed: Atlas-level seed from which the per-family seeds are derived.
        out_dir: Output directory; it and its subdirectories are created if needed.
        panel_size: Side length in pixels of each panel's plot area.
    """
    out_dir = Path(out_dir)
    clean_dir = out_dir / "clean"
    labeled_dir = out_dir / "labeled"
    ensure_dir(clean_dir)
    ensure_dir(labeled_dir)
    rows = ["family\tseed\tclean_png\tlabeled_png"]
    for family, flow_seed in zip(PAPER_FLOW_FAMILIES, _family_seeds(seed), strict=True):
        clean = make_panel(family, flow_seed, panel_size=panel_size, show_header=False)
        labeled = make_panel(family, flow_seed, panel_size=panel_size, show_header=True, show_seed=False)
        clean_path = clean_dir / f"{family}.png"
        labeled_path = labeled_dir / f"{family}.png"
        clean.save(clean_path)
        labeled.save(labeled_path)
        rows.append(f"{family}\t{flow_seed}\t{clean_path}\t{labeled_path}")
        print(f"wrote {clean_path}")
        print(f"wrote {labeled_path}")
    manifest = out_dir / "manifest.tsv"
    # Explicit encoding: the default is locale-dependent, and the manifest
    # contains file-system paths that may include non-ASCII characters.
    manifest.write_text("\n".join(rows) + "\n", encoding="utf-8")
    print(f"wrote {manifest}")
def main() -> None:
    """Command-line entry point: render the atlas, optionally export single panels.

    ``--panel-size`` controls only the exported per-family panels
    (``--panel-dir``). ``--atlas-panel-size`` controls the panels inside the
    atlas and defaults to 320, the previous fixed value.
    """
    parser = argparse.ArgumentParser(
        description="Render the FlowMo-WM paper flow-family atlas and optional per-family panels."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=20260525,
        help="Atlas seed; per-family flow seeds are derived from it deterministically.",
    )
    parser.add_argument(
        "--out",
        default="experiments/reports/figures/flow_family_atlas.png",
        help="Output path of the atlas PNG.",
    )
    parser.add_argument(
        "--panel-dir",
        default=None,
        help="If given, also export clean/labeled per-family panels and a manifest.tsv here.",
    )
    parser.add_argument(
        "--panel-size",
        type=int,
        default=640,
        help="Side length in pixels of exported per-family panels (used with --panel-dir).",
    )
    parser.add_argument(
        "--atlas-panel-size",
        type=int,
        default=320,
        help="Side length in pixels of each panel inside the atlas.",
    )
    args = parser.parse_args()
    if args.panel_size <= 0 or args.atlas_panel_size <= 0:
        parser.error("--panel-size and --atlas-panel-size must be positive integers")
    make_flow_atlas(args.seed, args.out, panel_size=args.atlas_panel_size)
    if args.panel_dir is not None:
        export_flow_panels(args.seed, args.panel_dir, panel_size=args.panel_size)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()