| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image, ImageDraw, ImageFont |
|
|
| from driftwm.sim.flow import sample_flow |
| from driftwm.utils import ensure_dir |
|
|
|
|
# Flow families shown in the paper figures. Each name is handed verbatim to
# ``sample_flow``. The order is significant: it fixes where each family's
# panel lands in the atlas and which derived seed each family is paired with.
PAPER_FLOW_FAMILIES = [
    "noflow",
    "uniform",
    "vortex_center",
    "double_gyre",
    "source_sink",
    "source_sink_pair",
    "gradient",
    "shear",
    "turbulent_patch",
    "random_fourier",
]
|
|
|
|
def color_map(values: np.ndarray, vmax: float = 0.38) -> np.ndarray:
    """Map scalar values to RGB colours using a five-stop linear ramp.

    Values are divided by ``vmax``, clipped to ``[0, 1]`` and linearly
    interpolated between five colour stops. The ramp runs from the dark first
    stop (value 0 or below) to the light yellow last stop (``vmax`` or above).

    Args:
        values: Array of scalars of any shape (e.g. a 2-D speed field).
            NaN maps to the first stop and +inf to the last, so a singular
            sample cannot produce an out-of-range palette index.
        vmax: Value that maps to the last colour stop. Must be positive and
            finite.

    Returns:
        ``uint8`` array of shape ``values.shape + (3,)``.

    Raises:
        ValueError: If ``vmax`` is not a positive finite number.
    """
    if not (np.isfinite(vmax) and vmax > 0.0):
        raise ValueError(f"vmax must be a positive finite number, got {vmax!r}")
    stops = np.array(
        [
            [18, 35, 61],
            [31, 91, 107],
            [64, 145, 108],
            [164, 190, 105],
            [244, 220, 102],
        ],
        dtype=np.float32,
    )
    n_segments = len(stops) - 1
    # Without this, NaN survives np.clip and np.floor(nan).astype(int) becomes
    # a huge negative index into ``stops``.
    x = np.nan_to_num(np.asarray(values) / vmax, nan=0.0, posinf=1.0, neginf=0.0)
    x = np.clip(x, 0.0, 1.0)
    scaled = x * n_segments
    lo = np.floor(scaled).astype(np.int32)
    hi = np.minimum(lo + 1, n_segments)  # the top value uses the last stop twice
    frac = (scaled - lo)[..., None]
    rgb = (1.0 - frac) * stops[lo] + frac * stops[hi]
    # Round instead of truncating: float error (e.g. 30.9999) must not drop a level.
    return np.rint(rgb).astype(np.uint8)
|
|
|
|
def draw_arrow(
    draw: ImageDraw.ImageDraw,
    start: tuple[float, float],
    vec: tuple[float, float],
    scale: float,
    *,
    color: tuple[int, int, int] = (255, 255, 255),
    width: int = 2,
    head_size: float = 6.0,
) -> None:
    """Draw a vector as a straight shaft with a filled triangular head.

    ``vec`` is expressed with y pointing up, while image rows grow downward,
    so its y component is negated when converting to pixel offsets.
    Vectors with magnitude below ``1e-4`` are not drawn.

    Args:
        draw: Drawing context to render into.
        start: Tail position in image pixel coordinates.
        vec: Vector ``(vx, vy)``, y pointing up.
        scale: Pixels per unit of vector length.
        color: Fill colour for both shaft and head.
        width: Shaft line width in pixels.
        head_size: Length of the head along the shaft in pixels. The head's
            half-width is ``0.55 * head_size``.
    """
    vx, vy = vec
    if float(np.hypot(vx, vy)) < 1.0e-4:
        return
    sx, sy = start
    ex = sx + scale * vx
    ey = sy - scale * vy  # image y axis points down
    draw.line((sx, sy, ex, ey), fill=color, width=width)
    # Unit direction of the drawn shaft. The length is computed once and
    # floored so a zero-length shaft cannot divide by zero.
    dx, dy = ex - sx, ey - sy
    length = max(float(np.hypot(dx, dy)), 1.0e-6)
    ux, uy = dx / length, dy / length
    px, py = -uy, ux  # perpendicular to the shaft
    half_width = 0.55 * head_size
    base_x = ex - head_size * ux
    base_y = ey - head_size * uy
    tip = (ex, ey)
    left = (base_x + half_width * px, base_y + half_width * py)
    right = (base_x - half_width * px, base_y - half_width * py)
    draw.polygon((tip, left, right), fill=color)
|
|
|
|
def _speed_background(flow, panel_size: int, n_bg: int = 96) -> Image.Image:
    """Render the flow speed at t=0 over [0, 10]^2 as a square RGB image.

    ``flow`` is whatever ``sample_flow`` returns. Only its
    ``velocity(points, t=...)`` method is used, and its output is assumed to
    have a trailing axis of vector components.
    """
    # Sample at cell centres of an n_bg x n_bg grid. Pillow's resize maps
    # source pixel centres onto destination pixel centres, so cell-centred
    # samples keep the heat map registered with the arrow overlay. Node-grid
    # samples (linspace(0, 10, n_bg)) would squeeze it by (n_bg - 1) / n_bg.
    centers = (np.arange(n_bg, dtype=np.float32) + 0.5) * np.float32(10.0 / n_bg)
    grid = np.stack(np.meshgrid(centers, centers), axis=-1)
    speed = np.linalg.norm(flow.velocity(grid, t=0.0), axis=-1)
    # Grid row 0 is y = 0; flip vertically so y points up in the image.
    rgb = color_map(np.flipud(speed))
    # A uint8 (H, W, 3) array is inferred as RGB; the ``mode`` argument of
    # Image.fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(rgb).resize((panel_size, panel_size), Image.Resampling.BICUBIC)


def _draw_flow_arrows(draw: ImageDraw.ImageDraw, flow, panel_size: int, header: int, n_arrow: int = 13) -> None:
    """Overlay an ``n_arrow`` x ``n_arrow`` grid of t=0 velocity arrows on [0.8, 9.2]^2."""
    coords = np.linspace(0.8, 9.2, n_arrow, dtype=np.float32)
    grid = np.stack(np.meshgrid(coords, coords), axis=-1)
    velocity = flow.velocity(grid, t=0.0)
    for pos, vel in zip(grid.reshape(-1, 2), velocity.reshape(-1, 2), strict=True):
        # World x -> pixel column; world y is flipped so it points up, offset by the header strip.
        px = 0.5 + pos[0] / 10.0 * (panel_size - 1)
        py = header + 0.5 + (1.0 - pos[1] / 10.0) * (panel_size - 1)
        draw_arrow(draw, (float(px), float(py)), (float(vel[0]), float(vel[1])), scale=72.0)


def make_panel(
    family: str,
    flow_seed: int,
    panel_size: int = 320,
    *,
    show_header: bool = True,
    show_seed: bool = True,
) -> Image.Image:
    """Render one flow family as a speed heat map overlaid with velocity arrows.

    The flow comes from ``sample_flow(family, rng, flow_id=flow_seed)`` with
    ``rng = np.random.default_rng(flow_seed)``. Identical arguments therefore
    reproduce the same panel, assuming ``sample_flow`` draws its randomness
    only from ``rng``.

    Args:
        family: Flow family name passed to ``sample_flow``.
        flow_seed: Seed for the flow RNG; also passed as ``flow_id``.
        panel_size: Side length in pixels of the square plot area.
        show_header: Add a 44 px strip above the plot containing the family name.
        show_seed: Also write ``seed=<flow_seed>`` in the header. Ignored when
            ``show_header`` is false, so header-less panels contain no text.

    Returns:
        RGB image of size ``(panel_size, panel_size + 44)`` with a header, or
        ``(panel_size, panel_size)`` without one.
    """
    rng = np.random.default_rng(flow_seed)
    flow = sample_flow(family, rng, flow_id=flow_seed)
    background = _speed_background(flow, panel_size)
    header = 44 if show_header else 0
    panel = Image.new("RGB", (panel_size, panel_size + header), (245, 245, 240))
    panel.paste(background, (0, header))
    draw = ImageDraw.Draw(panel)
    draw.rectangle((0, 0, panel_size - 1, panel_size + header - 1), outline=(22, 28, 33), width=2)
    if show_header:
        font = ImageFont.load_default()
        draw.text((10, 10), f"{family}", fill=(20, 24, 28), font=font)
        # Nested on purpose: without a header, the seed line would be drawn over the plot.
        if show_seed:
            draw.text((10, 25), f"seed={flow_seed}", fill=(55, 61, 68), font=font)
    _draw_flow_arrows(draw, flow, panel_size, header)
    return panel
|
|
|
|
def _family_seeds(seed: int) -> list[int]:
    """Derive one reproducible flow seed per entry of ``PAPER_FLOW_FAMILIES``.

    The seeds are drawn one at a time, in family order, from a generator
    seeded with ``seed``. Each lies in ``[1, 2**31 - 2]`` because the upper
    bound of ``Generator.integers`` is exclusive.
    """
    generator = np.random.default_rng(seed)
    upper = 2**31 - 1
    seeds: list[int] = []
    for _family in PAPER_FLOW_FAMILIES:
        # Draw scalar-by-scalar: a single batched ``size=`` draw may consume
        # the bit stream differently and change the derived seeds.
        seeds.append(int(generator.integers(1, upper)))
    return seeds
|
|
|
|
def make_flow_atlas(seed: int, out: str | Path, panel_size: int = 320) -> None:
    """Compose every paper flow family into one atlas image and save it.

    Panels, each showing its family name and seed, are placed row-major,
    five per row. They sit on a 22 px margin grid below a 58 px title strip.
    After saving, the output path and each family's derived seed are printed.

    Args:
        seed: Atlas seed from which the per-family flow seeds are derived.
        out: Destination image path; its parent is passed to ``ensure_dir``.
        panel_size: Side length in pixels of each panel's plot area.
    """
    family_seeds = _family_seeds(seed)
    tiles = []
    for name, fseed in zip(PAPER_FLOW_FAMILIES, family_seeds, strict=True):
        tiles.append(make_panel(name, fseed, panel_size=panel_size))

    tile_w, tile_h = tiles[0].size
    gap = 22
    title_h = 58
    n_cols = 5
    n_rows = -(-len(tiles) // n_cols)  # ceiling division
    canvas_w = n_cols * tile_w + (n_cols + 1) * gap
    canvas_h = n_rows * tile_h + (n_rows + 1) * gap + title_h

    atlas = Image.new("RGB", (canvas_w, canvas_h), (250, 250, 246))
    pen = ImageDraw.Draw(atlas)
    title_font = ImageFont.load_default()
    pen.text(
        (gap, 18),
        f"FlowMo-WM paper flow atlas, deterministic atlas seed={seed}",
        fill=(15, 19, 24),
        font=title_font,
    )
    for index, tile in enumerate(tiles):
        r, c = divmod(index, n_cols)
        atlas.paste(tile, (gap + c * (tile_w + gap), title_h + gap + r * (tile_h + gap)))

    out_path = Path(out)
    ensure_dir(out_path.parent)
    atlas.save(out_path)
    print(f"wrote {out_path}")
    for name, fseed in zip(PAPER_FLOW_FAMILIES, family_seeds, strict=True):
        print(f"{name}: {fseed}")
|
|
|
|
def export_flow_panels(seed: int, out_dir: str | Path, panel_size: int = 640) -> None:
    """Write standalone per-family panel PNGs plus a TSV manifest.

    For each family in ``PAPER_FLOW_FAMILIES`` two images are written:

    - ``<out_dir>/clean/<family>.png``: rendered without a header.
    - ``<out_dir>/labeled/<family>.png``: family-name header, no seed line.

    Seeds come from ``_family_seeds(seed)``, the same derivation the atlas
    uses. ``<out_dir>/manifest.tsv`` has one row per family with columns
    ``family``, ``seed``, ``clean_png`` and ``labeled_png``.

    Args:
        seed: Atlas seed from which the per-family flow seeds are derived.
        out_dir: Output directory. The ``clean`` and ``labeled``
            subdirectories are passed to ``ensure_dir`` first.
        panel_size: Side length in pixels of each panel's plot area.
    """
    out_dir = Path(out_dir)
    clean_dir = out_dir / "clean"
    labeled_dir = out_dir / "labeled"
    ensure_dir(clean_dir)
    ensure_dir(labeled_dir)
    rows = ["family\tseed\tclean_png\tlabeled_png"]
    for family, flow_seed in zip(PAPER_FLOW_FAMILIES, _family_seeds(seed), strict=True):
        clean = make_panel(family, flow_seed, panel_size=panel_size, show_header=False)
        labeled = make_panel(family, flow_seed, panel_size=panel_size, show_header=True, show_seed=False)
        clean_path = clean_dir / f"{family}.png"
        labeled_path = labeled_dir / f"{family}.png"
        clean.save(clean_path)
        labeled.save(labeled_path)
        rows.append(f"{family}\t{flow_seed}\t{clean_path}\t{labeled_path}")
        print(f"wrote {clean_path}")
        print(f"wrote {labeled_path}")
    manifest = out_dir / "manifest.tsv"
    # The default encoding is locale-dependent (e.g. cp1252 on Windows), so
    # non-ASCII output paths could fail to encode or be unreadable elsewhere.
    manifest.write_text("\n".join(rows) + "\n", encoding="utf-8")
    print(f"wrote {manifest}")
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: render the atlas and, optionally, per-family panels.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. The default
            ``None`` keeps the normal CLI behaviour.
    """
    parser = argparse.ArgumentParser(
        description="Render the paper flow-family atlas and, optionally, standalone per-family panels.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=20260525,
        help="atlas seed from which every per-family flow seed is derived",
    )
    parser.add_argument(
        "--out",
        default="experiments/reports/figures/flow_family_atlas.png",
        help="output path of the atlas image",
    )
    parser.add_argument(
        "--panel-dir",
        default=None,
        help="if given, also write clean/ and labeled/ per-family PNGs and manifest.tsv into this directory",
    )
    parser.add_argument(
        "--panel-size",
        type=int,
        default=640,
        help="side length in pixels of the exported per-family panels (the atlas always uses 320)",
    )
    args = parser.parse_args(argv)
    make_flow_atlas(args.seed, args.out)
    if args.panel_dir is not None:
        export_flow_panels(args.seed, args.panel_dir, panel_size=args.panel_size)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|