arabeh commited on
Commit
dd48369
·
1 Parent(s): 184a045

minor edits

Browse files
Files changed (3) hide show
  1. app.py +1 -2
  2. app_v1.py +57 -39
  3. app_v2.py +67 -37
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import warnings
2
 
3
- # (optional) reduce noisy logs
4
  warnings.filterwarnings("ignore", message="Can't initialize NVML")
5
  warnings.filterwarnings("ignore", category=FutureWarning)
6
 
7
- # IMPORTANT: import spaces BEFORE torch/cuda gets touched anywhere
8
  try:
9
  import spaces # noqa: F401
10
  except Exception:
 
1
  import warnings
2
 
 
3
  warnings.filterwarnings("ignore", message="Can't initialize NVML")
4
  warnings.filterwarnings("ignore", category=FutureWarning)
5
 
6
+ # Must happen before torch/cuda is touched anywhere in imports
7
  try:
8
  import spaces # noqa: F401
9
  except Exception:
app_v1.py CHANGED
@@ -4,12 +4,20 @@ import zipfile
4
  import tempfile
5
  from functools import lru_cache
6
  import warnings
 
7
  warnings.filterwarnings("ignore", message="Can't initialize NVML")
8
  warnings.filterwarnings("ignore", category=FutureWarning)
 
9
  try:
10
- import spaces
11
  except Exception:
12
- pass
 
 
 
 
 
 
13
 
14
  import numpy as np
15
  import torch
@@ -22,16 +30,15 @@ from einops import rearrange
22
 
23
  from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
24
 
25
- # ---------------- Config ----------------
26
  REPO_ID = "BGLab/DeepONet-FlowBench-FPO"
27
  CKPTS = {
28
- "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
29
- "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
30
- "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
31
  "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
32
  }
33
  SAMPLES_DIR = Path("sample_cases")
34
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  TMP = Path(tempfile.gettempdir())
36
 
37
  RANGES = {
@@ -40,8 +47,11 @@ RANGES = {
40
  }
41
 
42
 
 
 
 
 
43
  def _tag() -> str:
44
- # unique per request (avoids filename collisions across sessions)
45
  return next(tempfile._get_candidate_names())
46
 
47
 
@@ -51,13 +61,11 @@ def _tmp(tag: str, name: str) -> str:
51
  return str(out_dir / name)
52
 
53
 
54
- # ---------------- Samples ----------------
55
  def list_samples():
56
  if not SAMPLES_DIR.is_dir():
57
  return []
58
  ids = []
59
  for p in SAMPLES_DIR.glob("sample_*_input.npy"):
60
- # sample_{id}_input.npy
61
  sid = p.stem.split("_")[1]
62
  if sid.isdigit():
63
  ids.append(sid)
@@ -65,62 +73,64 @@ def list_samples():
65
 
66
 
67
  def load_sample(sample_id: str):
68
- sdf = np.load(SAMPLES_DIR / f"sample_{sample_id}_input.npy").astype(np.float32) # [1,H,W]
69
- y = np.load(SAMPLES_DIR / f"sample_{sample_id}_output.npy").astype(np.float32) # [T,2,H,W]
70
  return sdf, y
71
 
72
 
73
- # ---------------- Model ----------------
74
- @lru_cache(maxsize=4)
75
- def load_model(history_s: int) -> GeometricDeepONetTime:
76
  ckpt_path = hf_hub_download(REPO_ID, CKPTS[str(history_s)])
77
- model = GeometricDeepONetTime.load_from_checkpoint(ckpt_path, map_location=DEVICE)
78
- return model.eval().to(DEVICE)
79
 
80
 
81
- def static_tensors(hparams, sdf_np: np.ndarray):
82
  _, H, W = sdf_np.shape
83
  x = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32)
84
  y = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32)
85
  yv, xv = np.meshgrid(y, x, indexing="ij")
86
  coords = np.stack([xv, yv], axis=0)[None] # [1,2,H,W]
87
 
88
- sdf_t = torch.from_numpy(sdf_np)[None].to(DEVICE) # [1,1,H,W]
89
- coords_t = torch.from_numpy(coords).to(DEVICE) # [1,2,H,W]
90
- re_t = torch.zeros_like(sdf_t) # [1,1,H,W]
91
  return sdf_t, coords_t, re_t, H, W
92
 
93
 
94
- # ---------------- Rollout + metrics ----------------
95
  def rollout(sample_id: str, history_s: str):
96
- s = int(history_s)
97
- model = load_model(s)
 
 
 
98
 
99
  sdf, y_true = load_sample(sample_id)
100
  T, C, H, W = y_true.shape
101
  if C != 2:
102
  raise ValueError(f"Expected 2 channels (u,v), got {C}")
103
 
104
- s = min(s, T - 1) # ensure s < T
105
- sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf)
106
 
107
  y_pred = np.zeros_like(y_true)
108
  y_pred[:s] = y_true[:s]
109
- history = y_true[:s].copy() # [s,2,H,W]
110
 
111
  for t in range(s, T):
112
  branch = rearrange(history, "nb c h w -> (nb c) h w")[None] # [1,s*2,H,W]
113
- branch_t = torch.from_numpy(branch).to(DEVICE)
114
 
115
  with torch.no_grad():
116
  y_hat = model((branch_t, re_t, coords_t, sdf_t)) # [1,1,p,2]
117
 
118
- frame = y_hat[0, 0].view(H, W, 2).permute(2, 0, 1).cpu().numpy() # [2,H,W]
119
  y_pred[t] = frame
120
 
121
  history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0)
122
 
123
- return y_true, y_pred, s
124
 
125
 
126
  def rollout_errors(y_true: np.ndarray, y_pred: np.ndarray, s: int):
@@ -132,7 +142,9 @@ def rollout_errors(y_true: np.ndarray, y_pred: np.ndarray, s: int):
132
  def rel(comp: int):
133
  d = diff[:, comp].reshape(len(ts), -1)
134
  t = yt[:, comp].reshape(len(ts), -1)
135
- return np.linalg.norm(d, axis=1) / np.linalg.norm(t, axis=1)
 
 
136
 
137
  err_u = rel(0)
138
  err_v = rel(1)
@@ -140,7 +152,7 @@ def rollout_errors(y_true: np.ndarray, y_pred: np.ndarray, s: int):
140
 
141
 
142
  def pair_png(gt2d: np.ndarray, pred2d: np.ndarray, label: str, t: int) -> bytes:
143
- vmin, vmax = RANGES.get(label, (-1.0, 1.0)) # fallback if label changes
144
 
145
  fig, ax = plt.subplots(1, 2, figsize=(6.5, 2.6))
146
 
@@ -152,7 +164,6 @@ def pair_png(gt2d: np.ndarray, pred2d: np.ndarray, label: str, t: int) -> bytes:
152
  ax[1].set_title(f"{label} Pred – t={t}")
153
  ax[1].axis("off")
154
 
155
- # Colorbar height == ax[1] image height
156
  divider = make_axes_locatable(ax[1])
157
  cax = divider.append_axes("right", size="5%", pad=0.05)
158
  fig.colorbar(im2, cax=cax)
@@ -176,7 +187,10 @@ def write_zip(tag: str, y_true: np.ndarray, y_pred: np.ndarray, comp: int, label
176
  path = _tmp(tag, f"{label}_frames.zip")
177
  with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
178
  for t in range(y_true.shape[0]):
179
- zf.writestr(f"{label}_frame_{t:03d}.png", pair_png(y_true[t, comp], y_pred[t, comp], label, t))
 
 
 
180
  return path
181
 
182
 
@@ -206,11 +220,10 @@ def write_error_assets(tag: str, ts: np.ndarray, err_u: np.ndarray, err_v: np.nd
206
  return png, csv
207
 
208
 
209
- # ---------------- Gradio callback ----------------
210
  def predict_rollout(sample_id: str, history_s: str):
211
  tag = _tag()
212
 
213
- y_true, y_pred, s = rollout(sample_id, history_s)
214
  ts, err_u, err_v, avg_u, avg_v = rollout_errors(y_true, y_pred, s)
215
 
216
  u_gif = write_gif(tag, y_true, y_pred, 0, "u")
@@ -220,6 +233,7 @@ def predict_rollout(sample_id: str, history_s: str):
220
  err_png, csv = write_error_assets(tag, ts, err_u, err_v)
221
 
222
  metrics = (
 
223
  f"Rollout relative L2 error (averaged over t ≥ {s}):\n"
224
  f" u: {avg_u:.3e}\n"
225
  f" v: {avg_v:.3e}"
@@ -228,12 +242,16 @@ def predict_rollout(sample_id: str, history_s: str):
228
  return (u_gif, u_gif, u_zip, v_gif, v_gif, v_zip, err_png, csv, metrics)
229
 
230
 
231
- # ---------------- UI builder ----------------
 
 
 
 
232
  def build_demo():
233
  sample_choices = list_samples() or ["0"]
234
 
235
  return gr.Interface(
236
- fn=predict_rollout,
237
  inputs=[
238
  gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
239
  gr.Radio(["1", "4", "8", "16"], value="16", label="History length s"),
@@ -249,7 +267,7 @@ def build_demo():
249
  gr.File(label="Download L2 vs time (CSV)"),
250
  gr.Textbox(label="Summary metrics"),
251
  ],
252
- title="Time-Dependent DeepONet – FPO Rollout Demo",
253
  description=(
254
  "Auto-regressive 60-step rollout of u and v fields for a selected sample. "
255
  "Choose history length s (1, 4, 8, 16). Download videos/frames and relative error vs time (CSV)."
 
4
  import tempfile
5
  from functools import lru_cache
6
  import warnings
7
+
8
  warnings.filterwarnings("ignore", message="Can't initialize NVML")
9
  warnings.filterwarnings("ignore", category=FutureWarning)
10
+
11
  try:
12
+ import spaces # must be imported before torch/cuda usage on ZeroGPU
13
  except Exception:
14
+ class spaces: # type: ignore
15
+ @staticmethod
16
+ def GPU(*args, **kwargs):
17
+ def deco(fn):
18
+ return fn
19
+ return deco
20
+
21
 
22
  import numpy as np
23
  import torch
 
30
 
31
  from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
32
 
33
+
34
  REPO_ID = "BGLab/DeepONet-FlowBench-FPO"
35
  CKPTS = {
36
+ "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
37
+ "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
38
+ "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
39
  "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
40
  }
41
  SAMPLES_DIR = Path("sample_cases")
 
42
  TMP = Path(tempfile.gettempdir())
43
 
44
  RANGES = {
 
47
  }
48
 
49
 
50
+ def _device_str() -> str:
51
+ return "cuda" if torch.cuda.is_available() else "cpu"
52
+
53
+
54
  def _tag() -> str:
 
55
  return next(tempfile._get_candidate_names())
56
 
57
 
 
61
  return str(out_dir / name)
62
 
63
 
 
64
  def list_samples():
65
  if not SAMPLES_DIR.is_dir():
66
  return []
67
  ids = []
68
  for p in SAMPLES_DIR.glob("sample_*_input.npy"):
 
69
  sid = p.stem.split("_")[1]
70
  if sid.isdigit():
71
  ids.append(sid)
 
73
 
74
 
75
  def load_sample(sample_id: str):
76
+ sdf = np.load(SAMPLES_DIR / f"sample_{sample_id}_input.npy").astype(np.float32) # [1,H,W]
77
+ y = np.load(SAMPLES_DIR / f"sample_{sample_id}_output.npy").astype(np.float32) # [T,2,H,W]
78
  return sdf, y
79
 
80
 
81
+ @lru_cache(maxsize=8)
82
+ def load_model(history_s: int, device_str: str) -> GeometricDeepONetTime:
83
+ device = torch.device(device_str)
84
  ckpt_path = hf_hub_download(REPO_ID, CKPTS[str(history_s)])
85
+ model = GeometricDeepONetTime.load_from_checkpoint(ckpt_path, map_location=device)
86
+ return model.eval().to(device)
87
 
88
 
89
+ def static_tensors(hparams, sdf_np: np.ndarray, device: torch.device):
90
  _, H, W = sdf_np.shape
91
  x = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32)
92
  y = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32)
93
  yv, xv = np.meshgrid(y, x, indexing="ij")
94
  coords = np.stack([xv, yv], axis=0)[None] # [1,2,H,W]
95
 
96
+ sdf_t = torch.from_numpy(sdf_np)[None].to(device) # [1,1,H,W]
97
+ coords_t = torch.from_numpy(coords).to(device) # [1,2,H,W]
98
+ re_t = torch.zeros_like(sdf_t) # [1,1,H,W]
99
  return sdf_t, coords_t, re_t, H, W
100
 
101
 
 
102
  def rollout(sample_id: str, history_s: str):
103
+ s_req = int(history_s)
104
+
105
+ dev_str = _device_str()
106
+ device = torch.device(dev_str)
107
+ model = load_model(s_req, dev_str)
108
 
109
  sdf, y_true = load_sample(sample_id)
110
  T, C, H, W = y_true.shape
111
  if C != 2:
112
  raise ValueError(f"Expected 2 channels (u,v), got {C}")
113
 
114
+ s = min(s_req, T - 1)
115
+ sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf, device)
116
 
117
  y_pred = np.zeros_like(y_true)
118
  y_pred[:s] = y_true[:s]
119
+ history = y_true[:s].copy()
120
 
121
  for t in range(s, T):
122
  branch = rearrange(history, "nb c h w -> (nb c) h w")[None] # [1,s*2,H,W]
123
+ branch_t = torch.from_numpy(branch).to(device)
124
 
125
  with torch.no_grad():
126
  y_hat = model((branch_t, re_t, coords_t, sdf_t)) # [1,1,p,2]
127
 
128
+ frame = y_hat[0, 0].view(H, W, 2).permute(2, 0, 1).cpu().numpy().astype(np.float32) # [2,H,W]
129
  y_pred[t] = frame
130
 
131
  history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0)
132
 
133
+ return y_true, y_pred, s, dev_str
134
 
135
 
136
  def rollout_errors(y_true: np.ndarray, y_pred: np.ndarray, s: int):
 
142
  def rel(comp: int):
143
  d = diff[:, comp].reshape(len(ts), -1)
144
  t = yt[:, comp].reshape(len(ts), -1)
145
+ denom = np.linalg.norm(t, axis=1)
146
+ denom = np.where(denom == 0.0, 1.0, denom)
147
+ return np.linalg.norm(d, axis=1) / denom
148
 
149
  err_u = rel(0)
150
  err_v = rel(1)
 
152
 
153
 
154
  def pair_png(gt2d: np.ndarray, pred2d: np.ndarray, label: str, t: int) -> bytes:
155
+ vmin, vmax = RANGES.get(label, (-1.0, 1.0))
156
 
157
  fig, ax = plt.subplots(1, 2, figsize=(6.5, 2.6))
158
 
 
164
  ax[1].set_title(f"{label} Pred – t={t}")
165
  ax[1].axis("off")
166
 
 
167
  divider = make_axes_locatable(ax[1])
168
  cax = divider.append_axes("right", size="5%", pad=0.05)
169
  fig.colorbar(im2, cax=cax)
 
187
  path = _tmp(tag, f"{label}_frames.zip")
188
  with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
189
  for t in range(y_true.shape[0]):
190
+ zf.writestr(
191
+ f"{label}_frame_{t:03d}.png",
192
+ pair_png(y_true[t, comp], y_pred[t, comp], label, t),
193
+ )
194
  return path
195
 
196
 
 
220
  return png, csv
221
 
222
 
 
223
  def predict_rollout(sample_id: str, history_s: str):
224
  tag = _tag()
225
 
226
+ y_true, y_pred, s, dev_str = rollout(sample_id, history_s)
227
  ts, err_u, err_v, avg_u, avg_v = rollout_errors(y_true, y_pred, s)
228
 
229
  u_gif = write_gif(tag, y_true, y_pred, 0, "u")
 
233
  err_png, csv = write_error_assets(tag, ts, err_u, err_v)
234
 
235
  metrics = (
236
+ f"Device: {dev_str}\n"
237
  f"Rollout relative L2 error (averaged over t ≥ {s}):\n"
238
  f" u: {avg_u:.3e}\n"
239
  f" v: {avg_v:.3e}"
 
242
  return (u_gif, u_gif, u_zip, v_gif, v_gif, v_zip, err_png, csv, metrics)
243
 
244
 
245
+ @spaces.GPU(duration=180)
246
+ def predict_rollout_gpu(sample_id: str, history_s: str):
247
+ return predict_rollout(sample_id, history_s)
248
+
249
+
250
  def build_demo():
251
  sample_choices = list_samples() or ["0"]
252
 
253
  return gr.Interface(
254
+ fn=predict_rollout_gpu,
255
  inputs=[
256
  gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
257
  gr.Radio(["1", "4", "8", "16"], value="16", label="History length s"),
 
267
  gr.File(label="Download L2 vs time (CSV)"),
268
  gr.Textbox(label="Summary metrics"),
269
  ],
270
+ title="Time-Dependent DeepONet – FPO Rollout Demo",
271
  description=(
272
  "Auto-regressive 60-step rollout of u and v fields for a selected sample. "
273
  "Choose history length s (1, 4, 8, 16). Download videos/frames and relative error vs time (CSV)."
app_v2.py CHANGED
@@ -4,12 +4,20 @@ import zipfile
4
  import tempfile
5
  from functools import lru_cache
6
  import warnings
 
7
  warnings.filterwarnings("ignore", message="Can't initialize NVML")
8
  warnings.filterwarnings("ignore", category=FutureWarning)
 
9
  try:
10
- import spaces
11
  except Exception:
12
- pass
 
 
 
 
 
 
13
 
14
  import numpy as np
15
  import torch
@@ -22,19 +30,16 @@ from einops import rearrange
22
 
23
  from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
24
 
25
- # ---------------- Config ----------------
26
  REPO_ID = "BGLab/DeepONet-FlowBench-FPO"
27
  CKPTS = {
28
- "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
29
- "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
30
- "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
31
  "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
32
  }
33
 
34
- # v2 samples live here (only 16 GT timesteps per sample)
35
  SAMPLES_DIR = Path("sample_cases") / "few_timesteps"
36
-
37
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  TMP = Path(tempfile.gettempdir())
39
 
40
  RANGES = {
@@ -42,15 +47,21 @@ RANGES = {
42
  "v": (-1.0, 1.0),
43
  }
44
 
 
 
 
 
 
45
  def _tag() -> str:
46
  return next(tempfile._get_candidate_names())
47
 
 
48
  def _tmp(tag: str, name: str) -> str:
49
  out_dir = TMP / f"deeponet_fpo_{tag}"
50
  out_dir.mkdir(parents=True, exist_ok=True)
51
  return str(out_dir / name)
52
 
53
- # ---------------- Samples ----------------
54
  def list_samples():
55
  if not SAMPLES_DIR.is_dir():
56
  return []
@@ -61,32 +72,34 @@ def list_samples():
61
  ids.append(sid)
62
  return sorted(set(ids), key=int)
63
 
 
64
  def load_sample(sample_id: str):
65
  sdf = np.load(SAMPLES_DIR / f"sample_{sample_id}_input.npy").astype(np.float32) # [1,H,W]
66
  y16 = np.load(SAMPLES_DIR / f"sample_{sample_id}_output.npy").astype(np.float32) # [16,2,H,W]
67
  return sdf, y16
68
 
69
- # ---------------- Model ----------------
70
- @lru_cache(maxsize=4)
71
- def load_model(history_s: int) -> GeometricDeepONetTime:
 
72
  ckpt_path = hf_hub_download(REPO_ID, CKPTS[str(history_s)])
73
- model = GeometricDeepONetTime.load_from_checkpoint(ckpt_path, map_location=DEVICE)
74
- return model.eval().to(DEVICE)
75
 
76
- def static_tensors(hparams, sdf_np: np.ndarray):
77
- _, H, W = sdf_np.shape
78
 
 
 
79
  x = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32)
80
  y = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32)
81
  yv, xv = np.meshgrid(y, x, indexing="ij")
82
  coords = np.stack([xv, yv], axis=0)[None] # [1,2,H,W]
83
 
84
- sdf_t = torch.from_numpy(sdf_np)[None].to(DEVICE) # [1,1,H,W]
85
- coords_t = torch.from_numpy(coords).to(DEVICE) # [1,2,H,W]
86
  re_t = torch.zeros_like(sdf_t) # [1,1,H,W]
87
  return sdf_t, coords_t, re_t, H, W
88
 
89
- # ---------------- Rollout ----------------
90
  def rollout_pred(sample_id: str, history_s: str, n_steps: int):
91
  s = int(history_s)
92
  n_steps = int(n_steps)
@@ -94,39 +107,48 @@ def rollout_pred(sample_id: str, history_s: str, n_steps: int):
94
  if n_steps <= 0:
95
  raise ValueError("Number of rollout steps must be a positive integer.")
96
  if n_steps < s:
97
- n_steps = s # must have at least s frames to seed
 
 
 
 
98
 
99
- model = load_model(s)
100
  sdf, y16 = load_sample(sample_id)
101
 
102
- # Expect [16,2,H,W] (or more), but we ONLY use first s to seed the model.
103
  if y16.ndim != 4 or y16.shape[1] != 2:
104
  raise ValueError(f"Expected y shape [T,2,H,W], got {y16.shape}")
105
  if y16.shape[0] < s:
106
  raise ValueError(f"Sample only has {y16.shape[0]} timesteps, but checkpoint needs s={s}.")
107
 
108
  _, _, H, W = y16.shape
109
- sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf)
110
 
111
- seed = y16[:s].copy() # [s,2,H,W] (GT seed only)
112
  y_out = np.zeros((n_steps, 2, H, W), dtype=np.float32)
113
  y_out[:s] = seed
114
 
115
  history = seed.copy()
116
  for t in range(s, n_steps):
117
  branch = rearrange(history, "nb c h w -> (nb c) h w")[None] # [1,s*2,H,W]
118
- branch_t = torch.from_numpy(branch).to(DEVICE)
119
 
120
  with torch.no_grad():
121
  y_hat = model((branch_t, re_t, coords_t, sdf_t)) # [1,1,p,2]
122
 
123
- frame = y_hat[0, 0].view(H, W, 2).permute(2, 0, 1).cpu().numpy().astype(np.float32) # [2,H,W]
 
 
 
 
 
 
 
124
  y_out[t] = frame
125
  history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0)
126
 
127
- return y_out, s
 
128
 
129
- # ---------------- Rendering (prediction-only) ----------------
130
  def single_png(field2d: np.ndarray, label: str, t: int) -> bytes:
131
  vmin, vmax = RANGES.get(label, (-1.0, 1.0))
132
 
@@ -144,6 +166,7 @@ def single_png(field2d: np.ndarray, label: str, t: int) -> bytes:
144
  plt.close(fig)
145
  return buf.getvalue()
146
 
 
147
  def write_gif(tag: str, y: np.ndarray, comp: int, label: str) -> str:
148
  path = _tmp(tag, f"{label}_rollout.gif")
149
  with imageio.get_writer(path, mode="I", duration=0.1, loop=0) as w:
@@ -152,6 +175,7 @@ def write_gif(tag: str, y: np.ndarray, comp: int, label: str) -> str:
152
  w.append_data(imageio.imread(io.BytesIO(png)))
153
  return path
154
 
 
155
  def write_zip(tag: str, y: np.ndarray, comp: int, label: str) -> str:
156
  path = _tmp(tag, f"{label}_frames.zip")
157
  with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
@@ -159,10 +183,10 @@ def write_zip(tag: str, y: np.ndarray, comp: int, label: str) -> str:
159
  zf.writestr(f"{label}_frame_{t:03d}.png", single_png(y[t, comp], label, t))
160
  return path
161
 
162
- # ---------------- Gradio callback ----------------
163
  def run_v2(sample_id: str, history_s: str, n_steps: int):
164
  tag = _tag()
165
- y, s = rollout_pred(sample_id, history_s, n_steps)
166
 
167
  u_gif = write_gif(tag, y, comp=0, label="u")
168
  v_gif = write_gif(tag, y, comp=1, label="v")
@@ -170,8 +194,9 @@ def run_v2(sample_id: str, history_s: str, n_steps: int):
170
  v_zip = write_zip(tag, y, comp=1, label="v")
171
 
172
  summary = (
 
173
  f"Seeded with s={s} timesteps from {SAMPLES_DIR}.\n"
174
- f"Generated rollout length N={y.shape[0]} (frames labeled seed for t<s, pred for t≥s)."
175
  )
176
 
177
  return (
@@ -180,13 +205,18 @@ def run_v2(sample_id: str, history_s: str, n_steps: int):
180
  summary,
181
  )
182
 
183
- # ---------------- UI builder ----------------
 
 
 
 
 
184
  def build_demo():
185
  sample_choices = list_samples() or ["0"]
186
  history_choices = ["1", "4", "8", "16"]
187
 
188
  return gr.Interface(
189
- fn=run_v2,
190
  inputs=[
191
  gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
192
  gr.Radio(history_choices, value="16", label="History length s (checkpoint)"),
@@ -201,10 +231,10 @@ def build_demo():
201
  gr.File(label="Download all v frames (ZIP)"),
202
  gr.Textbox(label="Run summary"),
203
  ],
204
- title="Time-Dependent DeepONet – FPO Rollout Demo",
205
  description=(
206
- "Auto-regressive rollout of u and v fields for a selected sample. "
207
- "Choose history length s (1, 4, 8, 16). Download videos/frames."
208
  ),
209
  )
210
 
 
4
  import tempfile
5
  from functools import lru_cache
6
  import warnings
7
+
8
  warnings.filterwarnings("ignore", message="Can't initialize NVML")
9
  warnings.filterwarnings("ignore", category=FutureWarning)
10
+
11
  try:
12
+ import spaces # must be imported before torch/cuda usage on ZeroGPU
13
  except Exception:
14
+ class spaces: # type: ignore
15
+ @staticmethod
16
+ def GPU(*args, **kwargs):
17
+ def deco(fn):
18
+ return fn
19
+ return deco
20
+
21
 
22
  import numpy as np
23
  import torch
 
30
 
31
  from models.geometric_deeponet.geometric_deeponet import GeometricDeepONetTime
32
 
33
+
34
  REPO_ID = "BGLab/DeepONet-FlowBench-FPO"
35
  CKPTS = {
36
+ "1": "checkpoints/time-dependent-deeponet_1in.ckpt",
37
+ "4": "checkpoints/time-dependent-deeponet_4in.ckpt",
38
+ "8": "checkpoints/time-dependent-deeponet_8in.ckpt",
39
  "16": "checkpoints/time-dependent-deeponet_16in.ckpt",
40
  }
41
 
 
42
  SAMPLES_DIR = Path("sample_cases") / "few_timesteps"
 
 
43
  TMP = Path(tempfile.gettempdir())
44
 
45
  RANGES = {
 
47
  "v": (-1.0, 1.0),
48
  }
49
 
50
+
51
+ def _device_str() -> str:
52
+ return "cuda" if torch.cuda.is_available() else "cpu"
53
+
54
+
55
  def _tag() -> str:
56
  return next(tempfile._get_candidate_names())
57
 
58
+
59
  def _tmp(tag: str, name: str) -> str:
60
  out_dir = TMP / f"deeponet_fpo_{tag}"
61
  out_dir.mkdir(parents=True, exist_ok=True)
62
  return str(out_dir / name)
63
 
64
+
65
  def list_samples():
66
  if not SAMPLES_DIR.is_dir():
67
  return []
 
72
  ids.append(sid)
73
  return sorted(set(ids), key=int)
74
 
75
+
76
  def load_sample(sample_id: str):
77
  sdf = np.load(SAMPLES_DIR / f"sample_{sample_id}_input.npy").astype(np.float32) # [1,H,W]
78
  y16 = np.load(SAMPLES_DIR / f"sample_{sample_id}_output.npy").astype(np.float32) # [16,2,H,W]
79
  return sdf, y16
80
 
81
+
82
+ @lru_cache(maxsize=8)
83
+ def load_model(history_s: int, device_str: str) -> GeometricDeepONetTime:
84
+ device = torch.device(device_str)
85
  ckpt_path = hf_hub_download(REPO_ID, CKPTS[str(history_s)])
86
+ model = GeometricDeepONetTime.load_from_checkpoint(ckpt_path, map_location=device)
87
+ return model.eval().to(device)
88
 
 
 
89
 
90
+ def static_tensors(hparams, sdf_np: np.ndarray, device: torch.device):
91
+ _, H, W = sdf_np.shape
92
  x = np.linspace(0.0, float(hparams.domain_length_x), W, dtype=np.float32)
93
  y = np.linspace(0.0, float(hparams.domain_length_y), H, dtype=np.float32)
94
  yv, xv = np.meshgrid(y, x, indexing="ij")
95
  coords = np.stack([xv, yv], axis=0)[None] # [1,2,H,W]
96
 
97
+ sdf_t = torch.from_numpy(sdf_np)[None].to(device) # [1,1,H,W]
98
+ coords_t = torch.from_numpy(coords).to(device) # [1,2,H,W]
99
  re_t = torch.zeros_like(sdf_t) # [1,1,H,W]
100
  return sdf_t, coords_t, re_t, H, W
101
 
102
+
103
  def rollout_pred(sample_id: str, history_s: str, n_steps: int):
104
  s = int(history_s)
105
  n_steps = int(n_steps)
 
107
  if n_steps <= 0:
108
  raise ValueError("Number of rollout steps must be a positive integer.")
109
  if n_steps < s:
110
+ n_steps = s
111
+
112
+ dev_str = _device_str()
113
+ device = torch.device(dev_str)
114
+ model = load_model(s, dev_str)
115
 
 
116
  sdf, y16 = load_sample(sample_id)
117
 
 
118
  if y16.ndim != 4 or y16.shape[1] != 2:
119
  raise ValueError(f"Expected y shape [T,2,H,W], got {y16.shape}")
120
  if y16.shape[0] < s:
121
  raise ValueError(f"Sample only has {y16.shape[0]} timesteps, but checkpoint needs s={s}.")
122
 
123
  _, _, H, W = y16.shape
124
+ sdf_t, coords_t, re_t, _, _ = static_tensors(model.hparams, sdf, device)
125
 
126
+ seed = y16[:s].copy()
127
  y_out = np.zeros((n_steps, 2, H, W), dtype=np.float32)
128
  y_out[:s] = seed
129
 
130
  history = seed.copy()
131
  for t in range(s, n_steps):
132
  branch = rearrange(history, "nb c h w -> (nb c) h w")[None] # [1,s*2,H,W]
133
+ branch_t = torch.from_numpy(branch).to(device)
134
 
135
  with torch.no_grad():
136
  y_hat = model((branch_t, re_t, coords_t, sdf_t)) # [1,1,p,2]
137
 
138
+ frame = (
139
+ y_hat[0, 0]
140
+ .view(H, W, 2)
141
+ .permute(2, 0, 1)
142
+ .cpu()
143
+ .numpy()
144
+ .astype(np.float32)
145
+ ) # [2,H,W]
146
  y_out[t] = frame
147
  history = frame[None] if s == 1 else np.concatenate([history[1:], frame[None]], axis=0)
148
 
149
+ return y_out, s, dev_str
150
+
151
 
 
152
  def single_png(field2d: np.ndarray, label: str, t: int) -> bytes:
153
  vmin, vmax = RANGES.get(label, (-1.0, 1.0))
154
 
 
166
  plt.close(fig)
167
  return buf.getvalue()
168
 
169
+
170
  def write_gif(tag: str, y: np.ndarray, comp: int, label: str) -> str:
171
  path = _tmp(tag, f"{label}_rollout.gif")
172
  with imageio.get_writer(path, mode="I", duration=0.1, loop=0) as w:
 
175
  w.append_data(imageio.imread(io.BytesIO(png)))
176
  return path
177
 
178
+
179
  def write_zip(tag: str, y: np.ndarray, comp: int, label: str) -> str:
180
  path = _tmp(tag, f"{label}_frames.zip")
181
  with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
 
183
  zf.writestr(f"{label}_frame_{t:03d}.png", single_png(y[t, comp], label, t))
184
  return path
185
 
186
+
187
  def run_v2(sample_id: str, history_s: str, n_steps: int):
188
  tag = _tag()
189
+ y, s, dev_str = rollout_pred(sample_id, history_s, n_steps)
190
 
191
  u_gif = write_gif(tag, y, comp=0, label="u")
192
  v_gif = write_gif(tag, y, comp=1, label="v")
 
194
  v_zip = write_zip(tag, y, comp=1, label="v")
195
 
196
  summary = (
197
+ f"Device: {dev_str}\n"
198
  f"Seeded with s={s} timesteps from {SAMPLES_DIR}.\n"
199
+ f"Generated rollout length N={y.shape[0]} (seed frames t<s, predicted frames t≥s)."
200
  )
201
 
202
  return (
 
205
  summary,
206
  )
207
 
208
+
209
+ @spaces.GPU(duration=180)
210
+ def run_v2_gpu(sample_id: str, history_s: str, n_steps: int):
211
+ return run_v2(sample_id, history_s, n_steps)
212
+
213
+
214
  def build_demo():
215
  sample_choices = list_samples() or ["0"]
216
  history_choices = ["1", "4", "8", "16"]
217
 
218
  return gr.Interface(
219
+ fn=run_v2_gpu,
220
  inputs=[
221
  gr.Radio(sample_choices, value=sample_choices[0], label="Sample ID"),
222
  gr.Radio(history_choices, value="16", label="History length s (checkpoint)"),
 
231
  gr.File(label="Download all v frames (ZIP)"),
232
  gr.Textbox(label="Run summary"),
233
  ],
234
+ title="Time-Dependent DeepONet – FPO Rollout Demo",
235
  description=(
236
+ "Prediction-only auto-regressive rollout of u and v fields for a selected sample. "
237
+ "Choose history length s (1, 4, 8, 16) and rollout length N."
238
  ),
239
  )
240