arabeh commited on
Commit
dbd1ab1
·
1 Parent(s): 2234e51

Add v1/v2 apps, update README and deps

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv/
2
+ **/__pycache__/
3
+ *.pyc
README.md CHANGED
@@ -1,6 +1,5 @@
1
  ---
2
  title: DeepONet FPO Demo
3
- emoji: 🐨
4
  colorFrom: green
5
  colorTo: green
6
  sdk: gradio
@@ -11,4 +10,48 @@ license: mit
11
  short_description: 'Demo of unsteady flow around varied geometries '
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: DeepONet FPO Demo
 
3
  colorFrom: green
4
  colorTo: green
5
  sdk: gradio
 
10
  short_description: 'Demo of unsteady flow around varied geometries '
11
  ---
12
 
13
+ # DeepONet FPO Demo (FlowBench)
14
+
15
+ This Space runs **time-dependent DeepONet** checkpoints (s ∈ {1,4,8,16}) to generate **auto-regressive rollouts** of 2D velocity fields **(u, v)** around complex geometries (FPO / FlowBench).
16
+
17
+ You have two runnable apps:
18
+
19
+ - **`app_v1.py` (GT + metrics)**
20
+ - Uses `sample_cases/` containing the **full target sequence**.
21
+ - Produces:
22
+ - GT vs Pred **GIFs** for u and v
23
+ - ZIPs with all frames
24
+ - Relative L2 vs time **plot** + **CSV**
25
+ - Summary metrics (avg rel-L2 over t ≥ s)
26
+
27
+ - **`app_v2.py` (prediction-only, arbitrary rollout length)**
28
+ - Uses `sample_cases/few_timesteps/` containing **only the first 16 GT frames** (enough to seed any checkpoint).
29
+ - Produces:
30
+ - Prediction-only **GIFs** for u and v
31
+ - ZIPs with all frames
32
+ - Short run summary
33
+ - User chooses rollout length **N** (can be larger than 16; seeding uses only the first `s` frames).
34
+
35
+ ## Sample format
36
+
37
+ Each sample uses two files:
38
+
39
+ - `sample_{id}_input.npy` → SDF geometry: **[1, H, W]**
40
+ - `sample_{id}_output.npy` → velocity sequence: **[T, 2, H, W]** (channels are u, v)
41
+
42
+ For **v2**, `T = 16` and files must be located in:
43
+ - `sample_cases/few_timesteps/`
44
+
45
+ ## Checkpoints
46
+
47
+ Checkpoints are downloaded from the Hub at runtime:
48
+
49
+ - `checkpoints/time-dependent-deeponet_1in.ckpt`
50
+ - `checkpoints/time-dependent-deeponet_4in.ckpt`
51
+ - `checkpoints/time-dependent-deeponet_8in.ckpt`
52
+ - `checkpoints/time-dependent-deeponet_16in.ckpt`
53
+
54
+ Repo ID used by both apps:
55
+ ```text
56
+ arabeh/DeepONet-FlowBench-FPO
57
+ ```
app.py CHANGED
@@ -1,22 +1,12 @@
1
- import torch
2
- from huggingface_hub import hf_hub_download
3
- from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
4
 
5
- REPO_ID = "arabeh/DeepONet-FlowBench-FPO"
6
- CKPTS = {
7
- "s=1": "checkpoints/time-dependent-deeponet_1in.ckpt",
8
- "s=4": "checkpoints/time-dependent-deeponet_4in.ckpt",
9
- "s=8": "checkpoints/time-dependent-deeponet_8in.ckpt",
10
- "s=16": "checkpoints/time-dependent-deeponet_16in.ckpt",
11
- }
12
 
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
14
 
15
- def load_model(s_key="s=4"):
16
- ckpt_path = hf_hub_download(REPO_ID, CKPTS[s_key])
17
- model = GeometricDeepONetTime.load_from_checkpoint(
18
- ckpt_path,
19
- map_location=device,
20
- )
21
- model.eval()
22
- return model.to(device)
 
1
import os

# Select which demo to serve: DEMO_VERSION=v1 opts into the GT+metrics app,
# any other value (or unset) falls back to the prediction-only v2 app.
if os.getenv("DEMO_VERSION", "v2").lower() == "v1":
    from app_v1 import demo
else:
    from app_v2 import demo

if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
app_v1.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import io
3
+ import zipfile
4
+ import tempfile
5
+ from functools import lru_cache
6
+ import numpy as np
7
+ import torch
8
+ import gradio as gr
9
+ from huggingface_hub import hf_hub_download
10
+ import matplotlib.pyplot as plt
11
+ import imageio.v2 as imageio
12
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
13
+ from einops import rearrange
14
+
15
+ from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
16
+
17
+ # ---------------- Config ----------------
18
+ REPO_ID = "arabeh/DeepONet-FlowBench-FPO"
19
+ CKPTS = {
20
+ "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
21
+ "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
22
+ "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
23
+ "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
24
+ }
25
+ SAMPLES_DIR = Path("sample_cases")
26
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ TMP = Path(tempfile.gettempdir())
28
+
29
+ RANGES = {
30
+ "u": (-2.0, 2.0),
31
+ "v": (-1.0, 1.0),
32
+ }
33
+
34
+
35
+ def _tag() -> str:
36
+ # unique per request (avoids filename collisions across sessions)
37
+ return next(tempfile._get_candidate_names())
38
+
39
+
40
def _tmp(tag: str, name: str) -> str:
    """Build an absolute temp-file path namespaced by the request tag."""
    filename = "{}_{}".format(tag, name)
    return str(TMP / filename)
42
+
43
+
44
+ # ---------------- Samples ----------------
45
def list_samples():
    """Return sorted numeric sample ids found under SAMPLES_DIR.

    Ids are parsed from files named ``sample_{id}_input.npy``; non-numeric
    ids are skipped. Returns an empty list when the directory is missing.
    """
    if not SAMPLES_DIR.is_dir():
        return []
    found = {
        path.stem.split("_")[1]
        for path in SAMPLES_DIR.glob("sample_*_input.npy")
        if path.stem.split("_")[1].isdigit()
    }
    return sorted(found, key=int)
55
+
56
+
57
def load_sample(sample_id: str):
    """Load one sample: SDF geometry [1,H,W] and GT velocity sequence [T,2,H,W]."""
    input_path = SAMPLES_DIR / f"sample_{sample_id}_input.npy"
    output_path = SAMPLES_DIR / f"sample_{sample_id}_output.npy"
    geometry = np.load(input_path).astype(np.float32)   # [1,H,W]
    sequence = np.load(output_path).astype(np.float32)  # [T,2,H,W]
    return geometry, sequence
61
+
62
+
63
+ # ---------------- Model ----------------
64
@lru_cache(maxsize=4)
def load_model(history_s: int) -> GeometricDeepONetTime:
    """Download (once, cached) and load the checkpoint for history length s."""
    local_ckpt = hf_hub_download(REPO_ID, CKPTS[str(history_s)])
    net = GeometricDeepONetTime.load_from_checkpoint(local_ckpt, map_location=DEVICE)
    net.eval()
    return net.to(DEVICE)
69
+
70
+
71
def static_tensors(hparams, sdf_np: np.ndarray):
    """Build the per-sample tensors that stay fixed across rollout steps.

    Returns (sdf, coords, re, H, W): sdf is [1,1,H,W]; coords is the
    [1,2,H,W] (x,y) grid spanning the physical domain taken from hparams;
    re is an all-zero placeholder field shaped like sdf.
    """
    _, H, W = sdf_np.shape
    xs = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32)
    ys = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32)
    grid_y, grid_x = np.meshgrid(ys, xs, indexing="ij")
    coords_np = np.stack([grid_x, grid_y], axis=0)[None]  # [1,2,H,W]

    sdf_t = torch.from_numpy(sdf_np)[None].to(DEVICE)  # [1,1,H,W]
    coords_t = torch.from_numpy(coords_np).to(DEVICE)  # [1,2,H,W]
    re_t = torch.zeros_like(sdf_t)                     # [1,1,H,W]
    return sdf_t, coords_t, re_t, H, W
82
+
83
+
84
+ # ---------------- Rollout + metrics ----------------
85
def rollout(sample_id: str, history_s: str):
    """Auto-regressively roll the model over the sample's full GT sequence.

    The first ``s`` ground-truth frames seed the history window; every later
    frame is predicted from the preceding ``s`` frames (GT or predicted).
    Returns (y_true, y_pred, s) where both arrays are [T,2,H,W].
    """
    s = int(history_s)
    model = load_model(s)

    sdf, y_true = load_sample(sample_id)
    T, C, H, W = y_true.shape
    if C != 2:
        raise ValueError(f"Expected 2 channels (u,v), got {C}")

    s = min(s, T - 1)  # ensure s < T so at least one step is actually predicted
    sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf)

    # Seed the prediction buffer with the GT frames used as history.
    y_pred = np.zeros_like(y_true)
    y_pred[:s] = y_true[:s]
    history = y_true[:s].copy()  # [s,2,H,W]

    for t in range(s, T):
        # Flatten the s history frames into the branch input: [1, s*2, H, W].
        branch = rearrange(history, "nb c h w -> (nb c) h w")[None]  # [1,s*2,H,W]
        branch_t = torch.from_numpy(branch).to(DEVICE)

        with torch.no_grad():
            y_hat = model((branch_t, re_t, coords_t, sdf_t))  # [1,1,p,2]

        # Reshape the flat per-point prediction back into a [2,H,W] field.
        frame = y_hat[0, 0].view(H, W, 2).permute(2, 0, 1).cpu().numpy()  # [2,H,W]
        y_pred[t] = frame

        # Slide the window forward: drop the oldest frame, append the newest.
        history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0)

    return y_true, y_pred, s
114
+
115
+
116
def rollout_errors(y_true: np.ndarray, y_pred: np.ndarray, s: int):
    """Per-timestep relative L2 errors for u and v over the predicted window.

    Only timesteps t >= s are scored (the first s frames are the GT seed).
    Returns (timesteps, err_u, err_v, mean_err_u, mean_err_v).
    """
    ts = np.arange(s, y_true.shape[0])
    n_frames = len(ts)

    def _rel_l2(channel: int) -> np.ndarray:
        # Flatten each frame to a vector, then norm(pred - true) / norm(true).
        pred = y_pred[s:, channel].reshape(n_frames, -1)
        true = y_true[s:, channel].reshape(n_frames, -1)
        return np.linalg.norm(pred - true, axis=1) / np.linalg.norm(true, axis=1)

    err_u = _rel_l2(0)
    err_v = _rel_l2(1)
    return ts, err_u, err_v, float(err_u.mean()), float(err_v.mean())
130
+
131
+
132
def pair_png(gt2d: np.ndarray, pred2d: np.ndarray, label: str, t: int) -> bytes:
    """Render side-by-side GT/prediction heatmaps for one frame as PNG bytes."""
    vmin, vmax = RANGES.get(label, (-1.0, 1.0))  # fallback if label changes

    fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(6.5, 2.6))

    ax_gt.imshow(gt2d, origin="lower", vmin=vmin, vmax=vmax)
    ax_gt.set_title(f"{label} GT – t={t}")
    ax_gt.axis("off")

    pred_im = ax_pred.imshow(pred2d, origin="lower", vmin=vmin, vmax=vmax)
    ax_pred.set_title(f"{label} Pred – t={t}")
    ax_pred.axis("off")

    # Size the colorbar to exactly match the prediction panel's height.
    cax = make_axes_locatable(ax_pred).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(pred_im, cax=cax)

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", dpi=110)
    plt.close(fig)
    return buffer.getvalue()
154
+
155
+
156
+
157
def write_gif(tag: str, y_true: np.ndarray, y_pred: np.ndarray, comp: int, label: str) -> str:
    """Encode the GT-vs-pred rollout for one velocity component as an animated GIF."""
    out_path = _tmp(tag, f"{label}_rollout.gif")
    with imageio.get_writer(out_path, mode="I", duration=0.1, loop=0) as writer:
        for step in range(y_true.shape[0]):
            frame_png = pair_png(y_true[step, comp], y_pred[step, comp], label, step)
            writer.append_data(imageio.imread(io.BytesIO(frame_png)))
    return out_path
164
+
165
+
166
def write_zip(tag: str, y_true: np.ndarray, y_pred: np.ndarray, comp: int, label: str) -> str:
    """Bundle every GT-vs-pred frame PNG for one component into a ZIP archive."""
    out_path = _tmp(tag, f"{label}_frames.zip")
    with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for step in range(y_true.shape[0]):
            archive.writestr(
                f"{label}_frame_{step:03d}.png",
                pair_png(y_true[step, comp], y_pred[step, comp], label, step),
            )
    return out_path
172
+
173
+
174
def write_error_assets(tag: str, ts: np.ndarray, err_u: np.ndarray, err_v: np.ndarray):
    """Write the rel-L2-vs-time curves as a PNG plot and a CSV; return (png, csv) paths."""
    png_path = _tmp(tag, "relL2_vs_time.png")
    csv_path = _tmp(tag, "relL2_vs_time.csv")

    # CSV: one row per scored timestep, columns t / rel-L2(u) / rel-L2(v).
    np.savetxt(
        csv_path,
        np.c_[ts, err_u, err_v],
        delimiter=",",
        header="timestep,rel_L2_u,rel_L2_v",
        comments="",
    )

    fig, ax = plt.subplots(figsize=(5, 3))
    for series, name in ((err_u, "u"), (err_v, "v")):
        ax.plot(ts, series, label=name)
    ax.set_xlabel("Timestep")
    ax.set_ylabel("Relative L2")
    ax.set_title("Rollout rel. L2 vs time")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig(png_path, dpi=120, bbox_inches="tight")
    plt.close(fig)

    return png_path, csv_path
198
+
199
+
200
+ # ---------------- Gradio callback ----------------
201
def predict_rollout(sample_id: str, history_s: str):
    """Gradio entry point: run the rollout and materialize all downloadable assets.

    Returns the 9-tuple expected by the Interface outputs: u GIF (shown +
    download), u frames ZIP, same for v, error plot PNG, error CSV, and a
    summary text.
    """
    tag = _tag()

    y_true, y_pred, s = rollout(sample_id, history_s)
    ts, err_u, err_v, avg_u, avg_v = rollout_errors(y_true, y_pred, s)

    u_gif = write_gif(tag, y_true, y_pred, 0, "u")
    v_gif = write_gif(tag, y_true, y_pred, 1, "v")
    u_zip = write_zip(tag, y_true, y_pred, 0, "u")
    v_zip = write_zip(tag, y_true, y_pred, 1, "v")
    err_png, csv = write_error_assets(tag, ts, err_u, err_v)

    metrics = (
        f"Rollout relative L2 error (averaged over t ≥ {s}):\n"
        f" u: {avg_u:.3e}\n"
        f" v: {avg_v:.3e}"
    )

    # GIF paths appear twice on purpose: inline display (Image) + download (File).
    return (u_gif, u_gif, u_zip, v_gif, v_gif, v_zip, err_png, csv, metrics)
220
+
221
+
222
# ---------------- UI ----------------
# Fall back to a dummy id so the Radio widget is constructible even when
# sample_cases/ is absent (e.g. a checkout without the LFS data).
sample_choices = list_samples() or ["0"]

demo = gr.Interface(
    fn=predict_rollout,
    inputs=[
        gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
        gr.Radio(["1", "4", "8", "16"], value="16", label="History length s"),
    ],
    # Order must match the 9-tuple returned by predict_rollout.
    # Each GIF appears twice: rendered inline (Image) and as a download (File).
    outputs=[
        gr.Image(type="filepath", label="u rollout (GIF)"),
        gr.File(label="Download u rollout (GIF)"),
        gr.File(label="Download all u frames (ZIP)"),
        gr.Image(type="filepath", label="v rollout (GIF)"),
        gr.File(label="Download v rollout (GIF)"),
        gr.File(label="Download all v frames (ZIP)"),
        gr.Image(type="filepath", label="Relative L2 vs time"),
        gr.File(label="Download L2 vs time (CSV)"),
        gr.Textbox(label="Summary metrics"),
    ],
    title="Time-Dependent DeepONet – FPO Rollout Demo",
    description=(
        "Auto-regressive 60-step rollout of u and v fields for a selected sample. "
        "Choose history length s (1, 4, 8, 16). Download videos/frames and relative error vs time (CSV)."
    ),
)

if __name__ == "__main__":
    demo.launch()
251
+
app_v2.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import io
3
+ import zipfile
4
+ import tempfile
5
+ from functools import lru_cache
6
+
7
+ import numpy as np
8
+ import torch
9
+ import gradio as gr
10
+ from huggingface_hub import hf_hub_download
11
+ import matplotlib.pyplot as plt
12
+ import imageio.v2 as imageio
13
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
14
+ from einops import rearrange
15
+
16
+ from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
17
+
18
+ # ---------------- Config ----------------
19
+ REPO_ID = "arabeh/DeepONet-FlowBench-FPO"
20
+ CKPTS = {
21
+ "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
22
+ "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
23
+ "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
24
+ "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
25
+ }
26
+
27
+ # v2 samples live here (only 16 GT timesteps per sample)
28
+ SAMPLES_DIR = Path("sample_cases") / "few_timesteps"
29
+
30
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ TMP = Path(tempfile.gettempdir())
32
+
33
+ RANGES = {
34
+ "u": (-2.0, 2.0),
35
+ "v": (-1.0, 1.0),
36
+ }
37
+
38
+
39
+ def _tag() -> str:
40
+ return next(tempfile._get_candidate_names())
41
+
42
+
43
def _tmp(tag: str, name: str) -> str:
    """Build an absolute temp-file path namespaced by the request tag."""
    filename = "{}_{}".format(tag, name)
    return str(TMP / filename)
45
+
46
+
47
+ # ---------------- Samples ----------------
48
def list_samples():
    """Return sorted numeric sample ids found under SAMPLES_DIR.

    Ids are parsed from files named ``sample_{id}_input.npy``; non-numeric
    ids are skipped. Returns an empty list when the directory is missing.
    """
    if not SAMPLES_DIR.is_dir():
        return []
    found = {
        path.stem.split("_")[1]
        for path in SAMPLES_DIR.glob("sample_*_input.npy")
        if path.stem.split("_")[1].isdigit()
    }
    return sorted(found, key=int)
57
+
58
+
59
def load_sample(sample_id: str):
    """Load one v2 sample: SDF geometry [1,H,W] and the 16-frame seed [16,2,H,W]."""
    input_path = SAMPLES_DIR / f"sample_{sample_id}_input.npy"
    output_path = SAMPLES_DIR / f"sample_{sample_id}_output.npy"
    geometry = np.load(input_path).astype(np.float32)  # [1,H,W]
    seed_seq = np.load(output_path).astype(np.float32)  # [16,2,H,W]
    return geometry, seed_seq
63
+
64
+
65
+ # ---------------- Model ----------------
66
@lru_cache(maxsize=4)
def load_model(history_s: int) -> GeometricDeepONetTime:
    """Download (once, cached) and load the checkpoint for history length s."""
    local_ckpt = hf_hub_download(REPO_ID, CKPTS[str(history_s)])
    net = GeometricDeepONetTime.load_from_checkpoint(local_ckpt, map_location=DEVICE)
    net.eval()
    return net.to(DEVICE)
71
+
72
+
73
def static_tensors(hparams, sdf_np: np.ndarray):
    """Build the per-sample tensors that stay fixed across rollout steps.

    Returns (sdf, coords, re, H, W): sdf is [1,1,H,W]; coords is the
    [1,2,H,W] (x,y) grid spanning the physical domain taken from hparams;
    re is an all-zero placeholder field shaped like sdf.
    """
    _, H, W = sdf_np.shape

    xs = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32)
    ys = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32)
    grid_y, grid_x = np.meshgrid(ys, xs, indexing="ij")
    coords_np = np.stack([grid_x, grid_y], axis=0)[None]  # [1,2,H,W]

    sdf_t = torch.from_numpy(sdf_np)[None].to(DEVICE)  # [1,1,H,W]
    coords_t = torch.from_numpy(coords_np).to(DEVICE)  # [1,2,H,W]
    re_t = torch.zeros_like(sdf_t)                     # [1,1,H,W]
    return sdf_t, coords_t, re_t, H, W
85
+
86
+
87
+ # ---------------- Rollout ----------------
88
def rollout_pred(sample_id: str, history_s: str, n_steps: int):
    """Prediction-only auto-regressive rollout of arbitrary length.

    Seeds the history with the first ``s`` GT frames of the sample, then
    predicts frames until ``n_steps`` total frames exist. Returns
    (y_out, s) with ``y_out`` of shape [n_steps, 2, H, W].

    Raises ValueError for non-positive ``n_steps``, malformed sample shapes,
    or a sample shorter than the checkpoint's required seed length.
    """
    s = int(history_s)
    n_steps = int(n_steps)

    if n_steps <= 0:
        raise ValueError("Number of rollout steps must be a positive integer.")
    if n_steps < s:
        n_steps = s  # must have at least s frames to seed

    model = load_model(s)
    sdf, y16 = load_sample(sample_id)

    # Expect [16,2,H,W] (or more), but we ONLY use first s to seed the model.
    if y16.ndim != 4 or y16.shape[1] != 2:
        raise ValueError(f"Expected y shape [T,2,H,W], got {y16.shape}")
    if y16.shape[0] < s:
        raise ValueError(f"Sample only has {y16.shape[0]} timesteps, but checkpoint needs s={s}.")

    _, _, H, W = y16.shape
    sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf)

    seed = y16[:s].copy()  # [s,2,H,W] (GT seed only)
    y_out = np.zeros((n_steps, 2, H, W), dtype=np.float32)
    y_out[:s] = seed

    history = seed.copy()
    for t in range(s, n_steps):
        # Flatten the s history frames into the branch input: [1, s*2, H, W].
        branch = rearrange(history, "nb c h w -> (nb c) h w")[None]  # [1,s*2,H,W]
        branch_t = torch.from_numpy(branch).to(DEVICE)

        with torch.no_grad():
            y_hat = model((branch_t, re_t, coords_t, sdf_t))  # [1,1,p,2]

        # Reshape the flat per-point prediction back into a [2,H,W] field.
        frame = y_hat[0, 0].view(H, W, 2).permute(2, 0, 1).cpu().numpy().astype(np.float32)  # [2,H,W]
        y_out[t] = frame
        # Slide the window forward: drop the oldest frame, append the newest.
        history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0)

    return y_out, s
126
+
127
+
128
+ # ---------------- Rendering (prediction-only) ----------------
129
def single_png(field2d: np.ndarray, label: str, t: int) -> bytes:
    """Render a single prediction frame as PNG bytes with a matched-height colorbar."""
    vmin, vmax = RANGES.get(label, (-1.0, 1.0))

    fig, ax = plt.subplots(1, 1, figsize=(3.4, 2.8))
    image = ax.imshow(field2d, origin="lower", vmin=vmin, vmax=vmax)
    ax.set_title(f"{label} – t={t}")
    ax.axis("off")

    # Size the colorbar to exactly match the image panel's height.
    cax = make_axes_locatable(ax).append_axes("right", size="6%", pad=0.05)
    fig.colorbar(image, cax=cax)

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", dpi=120)
    plt.close(fig)
    return buffer.getvalue()
145
+
146
+
147
def write_gif(tag: str, y: np.ndarray, comp: int, label: str) -> str:
    """Encode the prediction-only rollout for one component as an animated GIF."""
    out_path = _tmp(tag, f"{label}_rollout.gif")
    with imageio.get_writer(out_path, mode="I", duration=0.1, loop=0) as writer:
        for step in range(y.shape[0]):
            frame_png = single_png(y[step, comp], label, step)
            writer.append_data(imageio.imread(io.BytesIO(frame_png)))
    return out_path
154
+
155
+
156
def write_zip(tag: str, y: np.ndarray, comp: int, label: str) -> str:
    """Bundle every prediction frame PNG for one component into a ZIP archive."""
    out_path = _tmp(tag, f"{label}_frames.zip")
    with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for step in range(y.shape[0]):
            archive.writestr(
                f"{label}_frame_{step:03d}.png",
                single_png(y[step, comp], label, step),
            )
    return out_path
162
+
163
+ # ---------------- Gradio callback ----------------
164
def run_v2(sample_id: str, history_s: str, n_steps: int):
    """Gradio entry point: prediction-only rollout plus downloadable GIF/ZIP assets.

    Returns the 7-tuple expected by the Interface outputs: u GIF (shown +
    download), u frames ZIP, same for v, and a short run summary.
    """
    tag = _tag()
    y, s = rollout_pred(sample_id, history_s, n_steps)

    u_gif = write_gif(tag, y, comp=0, label="u")
    v_gif = write_gif(tag, y, comp=1, label="v")
    u_zip = write_zip(tag, y, comp=0, label="u")
    v_zip = write_zip(tag, y, comp=1, label="v")

    summary = (
        f"Seeded with s={s} timesteps from {SAMPLES_DIR}.\n"
        f"Generated rollout length N={y.shape[0]} (frames labeled seed for t<s, pred for t≥s)."
    )

    # GIF paths appear twice on purpose: inline display (Image) + download (File).
    return (u_gif, u_gif, u_zip, v_gif, v_gif, v_zip, summary)
183
+
184
+
185
# ---------------- UI ----------------
# Fall back to a dummy id so the Radio widget is constructible even when
# sample_cases/few_timesteps/ is absent (e.g. a checkout without LFS data).
sample_choices = list_samples() or ["0"]
history_choices = ["1", "4", "8", "16"]

demo = gr.Interface(
    fn=run_v2,
    inputs=[
        gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
        gr.Radio(history_choices, value="16", label="History length s (checkpoint)"),
        gr.Number(value=60, precision=0, label="Rollout steps N (total frames)"),
    ],
    # Order must match the 7-tuple returned by run_v2.
    # Each GIF appears twice: rendered inline (Image) and as a download (File).
    outputs=[
        gr.Image(type="filepath", label="u rollout (GIF)"),
        gr.File(label="Download u rollout (GIF)"),
        gr.File(label="Download all u frames (ZIP)"),
        gr.Image(type="filepath", label="v rollout (GIF)"),
        gr.File(label="Download v rollout (GIF)"),
        gr.File(label="Download all v frames (ZIP)"),
        gr.Textbox(label="Run summary"),
    ],
    title="Time-Dependent DeepONet – FPO Rollout Demo",
    description=(
        "Auto-regressive rollout of u and v fields for a selected sample. "
        "Choose history length s (1, 4, 8, 16). Download videos/frames."
    ),
)

if __name__ == "__main__":
    demo.launch()
models/base.py CHANGED
@@ -3,72 +3,12 @@ import torch
3
 
4
 
5
  class BaseLightningModule(pl.LightningModule):
6
- def configure_optimizers(self):
7
- return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
8
-
9
- def _masked_mse(self, y_hat, y_true, sdf):
10
- mask = (sdf > 0).flatten(1).unsqueeze(-1)
11
- se = ((y_hat - y_true) ** 2) * mask
12
- return se.sum() / mask.sum()
13
 
14
- def training_step(self, batch, batch_idx):
15
- (branch, re, coords, sdf), tgt = batch
16
- y_hat = self.model((branch, re, coords, sdf))
17
- if self.hparams.use_derivative_loss:
18
- loss = self._derivative_loss(y_hat, tgt, sdf)
19
- else:
20
- loss = self._masked_mse(y_hat, tgt, sdf)
21
-
22
- self.log('train_loss', loss)
23
- return loss
24
-
25
- def validation_step(self, batch, batch_idx):
26
- (branch, re, coords, sdf), tgt = batch
27
- y_hat = self.model((branch, re, coords, sdf))
28
- if self.hparams.use_derivative_loss:
29
- loss = self._derivative_loss(y_hat, tgt, sdf)
30
- else:
31
- loss = self._masked_mse(y_hat, tgt, sdf)
32
-
33
- self.log('val_loss', loss)
34
- return loss
35
 
36
- def _derivative_loss(self, y_hat, y_true, sdf):
37
- # --- reshape [B,1,p,C] → [B,C,H,W] ---
38
- B, _, p, C = y_hat.shape
39
- H, W = self.hparams.height, self.hparams.width
40
- yh = y_hat.squeeze(1).permute(0,2,1).reshape(B, C, H, W)
41
- yt = y_true.squeeze(1).permute(0,2,1).reshape(B, C, H, W)
42
 
43
- deriv_hat = self.deriv_calc(yh)
44
- deriv_true = self.deriv_calc(yt)
45
- fluid_mask = (sdf > 0) # [B,1,H,W]
46
- delta = self.hparams.domain_length_y / H
47
- loss = 0.0
48
- # Derivative tensors come out at resolution (H-1)x(W-1) so crop the fluid_mask to match:
49
- dm = fluid_mask[:, :, :-1, :-1].unsqueeze(1) # → [B,1,1,H-1,W-1]
50
- for key in ('u_x','u_y','v_x','v_y'):
51
- diff = deriv_hat[key] - deriv_true[key] # [B,ngp,1,H-1,W-1]
52
- # apply mask before averaging
53
- deriv_loss = delta * (diff.pow(2) * dm).sum() / dm.sum()
54
- self.log(f"deriv_loss/{key}", deriv_loss, on_step=False, on_epoch=True)
55
- loss = loss + deriv_loss
56
-
57
- inner = (sdf > 0) & (sdf <= delta) # [B,1,H,W]
58
- if inner.any().item():
59
- u_hat = yh[:, 0:1] # [B,1,H,W]
60
- v_hat = yh[:, 1:2]
61
- if self.hparams.use_zero_bc:
62
- bc_loss = 1000 * (u_hat[inner].pow(2) + v_hat[inner].pow(2)).mean()
63
-
64
- else:
65
- u_true = yt[:, 0:1]
66
- v_true = yt[:, 1:2]
67
- u_target = u_true[inner]
68
- v_target = v_true[inner]
69
- bc_loss = ((u_hat[inner] - u_target).pow(2) + (v_hat[inner] - v_target).pow(2)).mean()
70
-
71
- self.log("boundary_bc_loss", bc_loss, on_step=False, on_epoch=True)
72
- loss = loss + bc_loss
73
-
74
- return loss
 
3
 
4
 
5
class BaseLightningModule(pl.LightningModule):
    """Minimal LightningModule base for the inference-only demo.

    Training/validation logic was stripped; only optimizer construction
    remains so checkpoints load cleanly.
    """

    def configure_optimizers(self):
        # Checkpoints saved without an `lr` hparam still need an optimizer
        # object at load time, so fall back to a sensible default.
        learning_rate = getattr(self.hparams, "lr", 1e-3)
        return torch.optim.Adam(self.parameters(), lr=learning_rate)
 
 
 
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -2,6 +2,7 @@ gradio
2
  torch
3
  pytorch-lightning
4
  huggingface_hub
5
- einops
6
  numpy
7
- matplotlib
 
 
 
2
  torch
3
  pytorch-lightning
4
  huggingface_hub
 
5
  numpy
6
+ imageio
7
+ einops
8
+ matplotlib
sample_cases/few_timesteps/sample_0_input.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:698f0a4c56d4d2beb1c4e9b3ccaba1f9c104a69ec60ca9501470589ff68f69f1
3
+ size 1048704
sample_cases/few_timesteps/sample_0_output.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12f2b8972a72ecd5a74294160f6d2fe38c2dc341364f85f25e235b4808168297
3
+ size 67108992
sample_cases/few_timesteps/sample_1_input.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7af7b5ee1e43eb97b1d99fb80a012b9edc4607d5b3139afc61c30d9aca4c60ed
3
+ size 1048704
sample_cases/few_timesteps/sample_1_output.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56650a750acd232d9fb173b6860c45d080b66ff3a239c13b6957e91674e12dee
3
+ size 67108992