from pathlib import Path import io import zipfile import tempfile from functools import lru_cache import warnings warnings.filterwarnings("ignore", message="Can't initialize NVML") warnings.filterwarnings("ignore", category=FutureWarning) try: import spaces # must be imported before torch/cuda usage on ZeroGPU except Exception: class spaces: # type: ignore @staticmethod def GPU(*args, **kwargs): def deco(fn): return fn return deco import numpy as np import torch import gradio as gr from huggingface_hub import hf_hub_download import matplotlib.pyplot as plt import imageio.v2 as imageio from mpl_toolkits.axes_grid1 import make_axes_locatable from einops import rearrange from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime REPO_ID = "BGLab/DeepONet-FlowBench-FPO" CKPTS = { "1": "checkpoints/time-dependent-deeponet_1in.ckpt", "4": "checkpoints/time-dependent-deeponet_4in.ckpt", "8": "checkpoints/time-dependent-deeponet_8in.ckpt", "16": "checkpoints/time-dependent-deeponet_16in.ckpt", } SAMPLES_DIR = Path("sample_cases") / "few_timesteps" TMP = Path(tempfile.gettempdir()) RANGES = { "u": (-2.0, 2.0), "v": (-1.0, 1.0), } def _device_str() -> str: return "cuda" if torch.cuda.is_available() else "cpu" def _tag() -> str: return next(tempfile._get_candidate_names()) def _tmp(tag: str, name: str) -> str: out_dir = TMP / f"deeponet_fpo_{tag}" out_dir.mkdir(parents=True, exist_ok=True) return str(out_dir / name) def list_samples(): if not SAMPLES_DIR.is_dir(): return [] ids = [] for p in SAMPLES_DIR.glob("sample_*_input.npy"): sid = p.stem.split("_")[1] if sid.isdigit(): ids.append(sid) return sorted(set(ids), key=int) def load_sample(sample_id: str): sdf = np.load(SAMPLES_DIR / f"sample_{sample_id}_input.npy").astype(np.float32) # [1,H,W] y16 = np.load(SAMPLES_DIR / f"sample_{sample_id}_output.npy").astype(np.float32) # [16,2,H,W] return sdf, y16 @lru_cache(maxsize=8) def load_model(history_s: int, device_str: str) -> GeometricDeepONetTime: device = torch.device(device_str) ckpt_path = hf_hub_download(REPO_ID, CKPTS[str(history_s)]) model = GeometricDeepONetTime.load_from_checkpoint(ckpt_path, map_location=device) return model.eval().to(device) def static_tensors(hparams, sdf_np: np.ndarray, device: torch.device): _, H, W = sdf_np.shape x = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32) y = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32) yv, xv = np.meshgrid(y, x, indexing="ij") coords = np.stack([xv, yv], axis=0)[None] # [1,2,H,W] sdf_t = torch.from_numpy(sdf_np)[None].to(device) # [1,1,H,W] coords_t = torch.from_numpy(coords).to(device) # [1,2,H,W] re_t = torch.zeros_like(sdf_t) # [1,1,H,W] return sdf_t, coords_t, re_t, H, W def rollout_pred(sample_id: str, history_s: str, n_steps: int): s = int(history_s) n_steps = int(n_steps) if n_steps <= 0: raise ValueError("Number of rollout steps must be a positive integer.") if n_steps < s: n_steps = s dev_str = _device_str() device = torch.device(dev_str) model = load_model(s, dev_str) sdf, y16 = load_sample(sample_id) if y16.ndim != 4 or y16.shape[1] != 2: raise ValueError(f"Expected y shape [T,2,H,W], got {y16.shape}") if y16.shape[0] < s: raise ValueError(f"Sample only has {y16.shape[0]} timesteps, but checkpoint needs s={s}.") _, _, H, W = y16.shape sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf, device) seed = y16[:s].copy() y_out = np.zeros((n_steps, 2, H, W), dtype=np.float32) y_out[:s] = seed history = seed.copy() for t in range(s, n_steps): branch = rearrange(history, "nb c h w -> (nb c) h w")[None] # [1,s*2,H,W] branch_t = torch.from_numpy(branch).to(device) with torch.no_grad(): y_hat = model((branch_t, re_t, coords_t, sdf_t)) # [1,1,p,2] frame = ( y_hat[0, 0] .view(H, W, 2) .permute(2, 0, 1) .cpu() .numpy() .astype(np.float32) ) # [2,H,W] y_out[t] = frame history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0) return y_out, s, dev_str def single_png(field2d: np.ndarray, label: str, t: int) -> bytes: vmin, vmax = RANGES.get(label, (-1.0, 1.0)) fig, ax = plt.subplots(1, 1, figsize=(3.4, 2.8)) im = ax.imshow(field2d, origin="lower", vmin=vmin, vmax=vmax) ax.set_title(f"{label} – t={t}") ax.axis("off") divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="6%", pad=0.05) fig.colorbar(im, cax=cax) buf = io.BytesIO() fig.savefig(buf, format="png", bbox_inches="tight", dpi=120) plt.close(fig) return buf.getvalue() def write_gif(tag: str, y: np.ndarray, comp: int, label: str) -> str: path = _tmp(tag, f"{label}_rollout.gif") with imageio.get_writer(path, mode="I", duration=0.1, loop=0) as w: for t in range(y.shape[0]): png = single_png(y[t, comp], label, t) w.append_data(imageio.imread(io.BytesIO(png))) return path def write_zip(tag: str, y: np.ndarray, comp: int, label: str) -> str: path = _tmp(tag, f"{label}_frames.zip") with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf: for t in range(y.shape[0]): zf.writestr(f"{label}_frame_{t:03d}.png", single_png(y[t, comp], label, t)) return path def run_v2(sample_id: str, history_s: str, n_steps: int): tag = _tag() y, s, dev_str = rollout_pred(sample_id, history_s, n_steps) u_gif = write_gif(tag, y, comp=0, label="u") v_gif = write_gif(tag, y, comp=1, label="v") u_zip = write_zip(tag, y, comp=0, label="u") v_zip = write_zip(tag, y, comp=1, label="v") summary = ( f"Device: {dev_str}\n" f"Seeded with s={s} timesteps from {SAMPLES_DIR}.\n" f"Generated rollout length N={y.shape[0]} (seed frames t