SciMLx_Production / data /benchmarks_ext.py
Moatasim Farooque
Remove problematic files
54fa103
"""Extended benchmark definitions for SciML experiments.
Adds KdV, Wave, and corrected 2D benchmarks on top of prepare.py.
Supported benchmarks:
"kdv_1d" - Korteweg-de Vries soliton dynamics (ETDRK4 solver)
"wave_1d" - 1D wave equation u_tt = c^2 u_xx (Stormer-Verlet)
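"darcy_2d" - 2D Darcy with proper variable-coeff solver (preconditioned conjugate gradient)
"ns_2d" - 2D Navier-Stokes with stable IC amplitude (CFL < 1)
Further benchmarks (swe_2d, allen_cahn_2d, mhd_2d, burgers_nu_01, burgers_nu_001,
poisson_2d, reionization_1d, ellipse_2d) are also generated by _generate_ext_dataset;
most of them are described in EXT_BENCHMARK_INFO.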
Why darcy_2d and ns_2d?
prepare.py's darcy_2d solver uses only mean(a) → loses all spatial info;
source term f uses a FIXED seed independent of a → u is uncorrelated with a.
prepare.py's ns_2d solver uses IC scale=1.0 → CFL≈61 → immediate NaN.
Both benchmarks are broken at the data level; prepare.py is read-only.
These fixed versions provide correct, learnable benchmarks.
Data interface is identical to prepare.py:
make_ext_dataloader(benchmark, split, batch_size)
evaluate_l2_rel_ext(benchmark, model)
References:
KdV: Tran et al. (2023) "Factorized Fourier Neural Operators"
Wave: Rahman et al. (2022) U-NO
Darcy fix: Li et al. (2020) FNO paper, original Darcy benchmark setup
NS fix: standard semi-implicit spectral NS with CFL-stable parameters
"""
import math
import os
import time
import torch
import numpy as np
from data.prepare import (
GRID_SIZE, TIME_BUDGET, N_TRAIN, N_VAL, VAL_SEED, TRAIN_SEED,
CACHE_DIR,
solve_kdv_batch, solve_wave_batch, _random_ic, _random_ic_np, _random_ic_2d,
)
# ── Constants ─────────────────────────────────────────────────────────────────
# Registry of benchmark names that _generate_ext_dataset can build.
EXT_BENCHMARKS = {
    "kdv_1d",
    "wave_1d",
    "darcy_2d",
    "ns_2d",
    "swe_2d",
    "allen_cahn_2d",
    "mhd_2d",
    "burgers_nu_01",
    "burgers_nu_001",
    "poisson_2d",
    "reionization_1d",
    "ellipse_2d",  # has a generator branch, channel entry and info entry
}
# Channel count per benchmark.
# NOTE(review): mhd_2d inputs stack 2 channels (w, a) while its target holds
# only w; confirm whether this table describes inputs or outputs.
# ns_hre_2d is listed here (and in EXT_SOTA) but has no generator branch yet,
# so it is intentionally absent from EXT_BENCHMARKS.
EXT_N_CHANNELS = {
    "kdv_1d": 1,
    "wave_1d": 1,
    "darcy_2d": 1,
    "ns_2d": 1,
    "ns_hre_2d": 1,
    "swe_2d": 1,
    "allen_cahn_2d": 1,
    "mhd_2d": 2,  # Vorticity (w) and Magnetic Potential (a)
    "burgers_nu_01": 1,
    "burgers_nu_001": 1,
    "poisson_2d": 1,
    "reionization_1d": 1,
    "ellipse_2d": 1,
}
# KdV parameters
KDV_T: float = 1.0  # final simulation time
KDV_NSTEPS: int = 1000  # number of ETDRK4 steps over [0, KDV_T]
# Wave parameters
WAVE_C: float = 1.0  # propagation speed c in u_tt = c^2 u_xx
WAVE_T: float = 1.0  # final simulation time
WAVE_NSTEPS: int = 400  # number of Störmer-Verlet steps
# Darcy fix parameters
DARCY_FIX_N_ITER: int = 40  # maximum PCG iterations in solve_darcy_2d_batch
DARCY_FIX_MODES_F: int = 5  # Fourier modes of the fixed Darcy source term f
# NS fix parameters — reduces CFL from ~61 to ~0.6
NS_SCALE: float = 0.1  # IC vorticity amplitude (1.0 in prepare.py, i.e. 10x smaller)
NS_NSTEPS: int = 1000  # time steps (100 in prepare.py) -> dt = NS_T / NS_NSTEPS = 0.001
NS_NU: float = 1e-2  # kinematic viscosity (unchanged from prepare.py)
NS_T: float = 1.0  # final simulation time
# Allen-Cahn parameters
AC_EPSILON: float = 0.01  # coefficient multiplying the Laplacian term
AC_T: float = 0.5  # final simulation time
AC_NSTEPS: int = 200  # number of semi-implicit steps
# Shallow-water parameters
SWE_G: float = 9.81  # gravitational acceleration
SWE_T: float = 0.2  # final simulation time
SWE_NSTEPS: int = 200  # number of time steps
# MHD parameters
MHD_NU: float = 1e-3  # viscosity applied to the vorticity field
MHD_ETA: float = 1e-3  # resistivity applied to the magnetic potential
MHD_T: float = 0.5  # final simulation time
MHD_NSTEPS: int = 500  # number of time steps
from core.device import DEVICE, TORCH_DEVICE
# ── 2D Solvers (corrected) ────────────────────────────────────────────────────
def solve_darcy_2d_batch(
    a: torch.Tensor,
    f: torch.Tensor,
    n_iter: int = DARCY_FIX_N_ITER,
) -> torch.Tensor:
    """Solve -∇·(a(x,y)∇u) = f on [0,1]² with periodic BCs.

    Uses Preconditioned Conjugate Gradient (PCG) with the constant-coefficient
    Poisson operator P = a_mean·(-Δ) as a preconditioner. All derivatives are
    spectral (FFT) with physical wavenumbers 2π·k on [0,1]².

    On a periodic domain the problem only has a solution for zero-mean f and
    fixes u up to a constant, so the residual is projected onto the zero-mean
    subspace every iteration and the returned u has zero mean.

    Args:
        a: coefficient field, indexed as (batch, N, N).
        f: source field, broadcastable against ``a`` (same shape in this
            file's callers). Moved to ``a``'s device if needed.
        n_iter: maximum number of PCG iterations; the loop exits early once
            every sample's r·z falls below 1e-18.

    Returns:
        float32 tensor of shape (batch, N, N) on ``a``'s device.
    """
    B, N, _ = a.shape
    # Follow the input's device rather than a global default so callers with
    # tensors elsewhere do not hit a device-mismatch error.
    device = a.device
    a_d = a.to(torch.float32)
    f_d = f.to(device=device, dtype=torch.float32)
    # Physical wavenumbers on [0,1]²: d/dx ↔ multiply by 2πi·k_int
    k_int = torch.fft.fftfreq(N, d=1.0 / N, device=device)
    kx, ky = torch.meshgrid(2 * math.pi * k_int, 2 * math.pi * k_int, indexing="ij")
    lap_pos = kx ** 2 + ky ** 2
    # Placeholder to avoid 0-division; the k = 0 component is projected out.
    lap_pos[0, 0] = 1.0
    a_mean = a_d.mean(dim=(1, 2), keepdim=True)

    def apply_A(v):
        """Compute A·v = -∇·(a∇v) via spectral differentiation."""
        v_hat = torch.fft.fft2(v, dim=(1, 2))
        vx = torch.fft.ifft2(1j * kx[None] * v_hat, dim=(1, 2)).real
        vy = torch.fft.ifft2(1j * ky[None] * v_hat, dim=(1, 2)).real
        Av = -torch.fft.ifft2(
            1j * kx[None] * torch.fft.fft2(a_d * vx, dim=(1, 2))
            + 1j * ky[None] * torch.fft.fft2(a_d * vy, dim=(1, 2)),
            dim=(1, 2),
        ).real
        return Av

    def apply_P_inv(r):
        """Preconditioned step: P⁻¹r = r̂ / (a_mean · |k|²)."""
        r_hat = torch.fft.fft2(r, dim=(1, 2))
        Pr = torch.fft.ifft2(r_hat / (a_mean * lap_pos[None]), dim=(1, 2)).real
        Pr -= Pr.mean(dim=(1, 2), keepdim=True)  # project to zero-mean space
        return Pr

    u = torch.zeros((B, N, N), dtype=torch.float32, device=device)
    r = f_d - apply_A(u)
    r -= r.mean(dim=(1, 2), keepdim=True)
    z = apply_P_inv(r)
    p = z.clone()
    rz_old = torch.sum(r * z, dim=(1, 2), keepdim=True)
    for _ in range(n_iter):
        Ap = apply_A(p)
        pAp = torch.sum(p * Ap, dim=(1, 2), keepdim=True)
        alpha = rz_old / (pAp + 1e-16)
        u += alpha * p
        r -= alpha * Ap
        # Project residual to zero-mean space to avoid drift
        r -= r.mean(dim=(1, 2), keepdim=True)
        z = apply_P_inv(r)
        rz_new = torch.sum(r * z, dim=(1, 2), keepdim=True)
        if torch.max(torch.abs(rz_new)) < 1e-18:
            break
        beta = rz_new / (rz_old + 1e-16)
        p = z + beta * p
        rz_old = rz_new
    u -= u.mean(dim=(1, 2), keepdim=True)
    return u.to(torch.float32)
def solve_ns_2d_batch(
    w0: torch.Tensor,
    nu: float = NS_NU,
    T: float = NS_T,
    n_steps: int = NS_NSTEPS,
) -> torch.Tensor:
    """Stable 2D Navier-Stokes solver (vorticity form) on [0, 2pi)^2.

    Integrates w_t + (u·∇)w = nu·Δw with periodic boundaries using a
    pseudo-spectral discretisation. Advection is explicit and evaluated on
    2/3-rule dealiased fields. Diffusion is implicit (semi-implicit Euler).

    Convention: the stream function solves Δψ = w and (u, v) = (ψ_y, -ψ_x),
    so ``w`` is minus the usual curl v_x - u_y. The evolution equation keeps
    the same form under that sign flip.

    Args:
        w0: initial field, shape (B, N, N), sampled on an N×N grid over [0, 2π)².
        nu: kinematic viscosity.
        T: final time.
        n_steps: number of time steps (dt = T / n_steps).

    Returns:
        float32 tensor of shape (B, N, N) on ``w0``'s device.

    NOTE: this revision fixes three defects that change the generated data,
    and cached ns_2d datasets are not invalidated automatically:

    * Derivatives now use the integer wavenumbers of a 2π-periodic box. The
      old ``fftfreq(N)`` (cycles/sample) symbols described a box of length
      2πN. That left the advection term unchanged but divided the effective
      viscosity by N².
    * The dealias cutoff ``2N//3`` could never be exceeded because |k| ≤ N/2.
    * The k = 0 placeholder used for the ψ inversion also entered the
      diffusion denominator, so the mean grew by 1/(1 - dt·nu) every step.
    """
    _, N, _ = w0.shape
    device = w0.device
    dt = T / n_steps
    # Integer wavenumbers: on [0, 2π) with N points, d/dx ↔ i·k.
    k = torch.fft.fftfreq(N, d=1.0 / N, device=device)
    k1, k2 = torch.meshgrid(k, k, indexing="ij")
    laplacian = -(k1 ** 2 + k2 ** 2)  # Fourier symbol of Δ; exactly 0 at k = 0
    # Copy with a non-zero k = 0 entry, used only for ψ = Δ⁻¹w.
    lap_for_inverse = laplacian.clone()
    lap_for_inverse[0, 0] = 1.0
    # 2/3 rule: keep only |k| < N/3 (two thirds of the Nyquist wavenumber N/2)
    # so the quadratic advection term is free of aliasing onto retained modes.
    dealias = ((k1.abs() < N / 3) & (k2.abs() < N / 3)).to(torch.float32)
    # Implicit diffusion factor; equals 1 at k = 0 so the mean is untouched.
    implicit = 1.0 - dt * nu * laplacian

    w_hat = torch.fft.fft2(w0.to(torch.float32), dim=(1, 2))
    for _ in range(n_steps):
        w_hat_d = w_hat * dealias
        # Stream function: Δψ = w
        psi_hat = w_hat_d / lap_for_inverse
        psi_hat[:, 0, 0] = 0.0
        # Velocity: u = (∂ψ/∂y, -∂ψ/∂x)
        u = torch.fft.ifft2(1j * k2 * psi_hat).real
        v = torch.fft.ifft2(-1j * k1 * psi_hat).real
        # Non-linear term: (u·∇)w
        wx = torch.fft.ifft2(1j * k1 * w_hat_d).real
        wy = torch.fft.ifft2(1j * k2 * w_hat_d).real
        nonlin = torch.fft.fft2(u * wx + v * wy)
        # Semi-implicit step: diffusion implicit, advection explicit
        w_hat = (w_hat - dt * nonlin) / implicit
    return torch.fft.ifft2(w_hat, dim=(1, 2)).real.to(torch.float32)
def solve_allen_cahn_2d_batch(u0: torch.Tensor, epsilon: float = AC_EPSILON, T: float = AC_T, n_steps: int = AC_NSTEPS) -> torch.Tensor:
    """Integrate 2D Allen-Cahn, u_t = epsilon·Δu - (u³ - u), with an IMEX spectral scheme.

    Each step treats the cubic reaction explicitly and diffusion implicitly in
    Fourier space: û ← (û - dt·F[u³ - u]) / (1 - dt·epsilon·L̂), where
    L̂ = -(k1² + k2²).

    NOTE(review): ``k`` comes from ``fftfreq(N)`` (cycles per sample) with no
    2π·N factor. L̂ is therefore the Laplacian symbol of a periodic box of
    length 2π·N, not of the [0, 1]² domain listed in EXT_BENCHMARK_INFO.
    Diffusion appears ~(2πN)² weaker than ``epsilon`` suggests, so the update
    is close to the pointwise ODE u_t = u - u³. Confirm the intended scaling.

    Args:
        u0: initial field, shape (B, N, N).
        epsilon: coefficient of the Laplacian term.
        T: final time.
        n_steps: number of steps (dt = T / n_steps).

    Returns:
        float32 tensor of shape (B, N, N).
    """
    _, n_grid, _ = u0.shape
    step = T / n_steps
    freqs = torch.fft.fftfreq(n_grid, device=TORCH_DEVICE)
    kx, ky = torch.meshgrid(freqs, freqs, indexing="ij")
    symbol = -(kx ** 2 + ky ** 2)
    # Same denominator every step, so compute it once.
    implicit_denominator = 1.0 - step * epsilon * symbol
    state_hat = torch.fft.fft2(u0.to(torch.float32), dim=(1, 2))
    for _ in range(n_steps):
        state = torch.fft.ifft2(state_hat).real
        reaction_hat = torch.fft.fft2(state**3 - state)
        state_hat = (state_hat - step * reaction_hat) / implicit_denominator
    return torch.fft.ifft2(state_hat).real.to(torch.float32)
def solve_swe_2d_batch(h0: torch.Tensor, T: float = SWE_T, n_steps: int = SWE_NSTEPS) -> torch.Tensor:
    """Spectral solver for 2D Shallow Water Equations (linearized height).

    Evolves the linear system, with continuity linearized at unit mean depth:

        h_t = -(u_x + v_y),   u_t = -g·h_x,   v_t = -g·h_y,   g = SWE_G,

    starting from rest (u = v = 0). Spatial derivatives are spectral.

    Time stepping is symplectic (semi-implicit) Euler. Velocities are advanced
    with the current height, then the height with the *updated* velocities.
    Plain forward Euler, which updates every field from the old state,
    amplifies each wave mode of frequency ω by sqrt(1 + (ω·dt)²) per step and
    is unconditionally unstable for this oscillatory system. Symplectic Euler
    is stable for ω·dt < 2.

    Derivatives use integer wavenumbers (``fftfreq(N) * N``), i.e. a
    2π-periodic box. NOTE(review): EXT_BENCHMARK_INFO lists [0, 1]^2; confirm
    which domain is intended.

    Args:
        h0: initial height, shape (B, N, N).
        T: final time.
        n_steps: number of steps (dt = T / n_steps).

    Returns:
        float32 height field of shape (B, N, N) on ``h0``'s device.
    """
    _, N, _ = h0.shape
    dt = T / n_steps
    k = torch.fft.fftfreq(N, device=h0.device)
    k1, k2 = torch.meshgrid(k, k, indexing="ij")
    # Spectral derivatives
    ik1, ik2 = 1j * k1 * N, 1j * k2 * N
    h_hat = torch.fft.fft2(h0.to(torch.float32), dim=(1, 2))
    u_hat = torch.zeros_like(h_hat)
    v_hat = torch.zeros_like(h_hat)
    for _ in range(n_steps):
        # Momentum from the current height ...
        u_hat = u_hat - dt * (ik1 * SWE_G * h_hat)
        v_hat = v_hat - dt * (ik2 * SWE_G * h_hat)
        # ... then continuity from the updated velocities (symplectic Euler).
        h_hat = h_hat - dt * (ik1 * u_hat + ik2 * v_hat)
    return torch.fft.ifft2(h_hat).real.to(torch.float32)
def solve_mhd_2d_batch(w0: torch.Tensor, a0: torch.Tensor, T: float = MHD_T, n_steps: int = MHD_NSTEPS) -> torch.Tensor:
    """Spectral solver for 2D incompressible MHD (vorticity-potential form).

    Fields: ``w`` (with Δψ = w) and magnetic potential ``a``. Velocity is
    (u, v) = (ψ_y, -ψ_x) and magnetic field (bx, by) = (a_y, -a_x). Each step
    applies

        w_t = -[u·∇w - b·∇(Δa)] + MHD_NU·Δw
        a_t = -u·∇a + MHD_ETA·Δa

    with the brackets explicit and diffusion implicit in Fourier space.

    This revision fixes two defects:

    * ∇(Δa) is differentiated from the Fourier coefficients of Δa. Previously
      the physical-space Δa field was passed to ``ifft2`` as though it were a
      spectrum, which made the Lorentz term meaningless.
    * The k = 0 placeholder used to invert Δ for ψ no longer enters the
      diffusion denominators, which used to grow both means by 1/(1 - dt·ν)
      per step.

    NOTE(review): wavenumbers are ``fftfreq(N)`` with no 2π·N factor (a box of
    length 2πN). Relative to a 2π-periodic box this appears to divide ν and η
    by N² and the Lorentz term by N⁴, weakening the w–a coupling. Confirm the
    intended domain before rescaling.

    Args:
        w0: initial ``w``, shape (B, N, N).
        a0: initial magnetic potential, shape (B, N, N).
        T: final time.
        n_steps: number of steps (dt = T / n_steps).

    Returns:
        float32 ``w`` field of shape (B, N, N) on ``w0``'s device (``a`` is
        not returned).
    """
    _, N, _ = w0.shape
    device = w0.device
    dt = T / n_steps
    k = torch.fft.fftfreq(N, device=device)
    k1, k2 = torch.meshgrid(k, k, indexing="ij")
    lap = -(k1 ** 2 + k2 ** 2)  # Laplacian symbol, exactly 0 at k = 0
    lap_for_inverse = lap.clone()
    lap_for_inverse[0, 0] = 1.0  # only for ψ = Δ⁻¹w; ψ's mean is zeroed below
    visc_den = 1.0 - dt * MHD_NU * lap
    resist_den = 1.0 - dt * MHD_ETA * lap

    def spec_dx(f_hat):
        """∂/∂x of a field given by its Fourier coefficients, in physical space."""
        return torch.fft.ifft2(1j * k1 * f_hat).real

    def spec_dy(f_hat):
        """∂/∂y of a field given by its Fourier coefficients, in physical space."""
        return torch.fft.ifft2(1j * k2 * f_hat).real

    w_hat = torch.fft.fft2(w0.to(torch.float32), dim=(1, 2))
    a_hat = torch.fft.fft2(a0.to(torch.float32), dim=(1, 2))
    for _ in range(n_steps):
        psi_hat = w_hat / lap_for_inverse
        psi_hat[:, 0, 0] = 0.0
        u = spec_dy(psi_hat)
        v = -spec_dx(psi_hat)
        bx = spec_dy(a_hat)
        by = -spec_dx(a_hat)
        lap_a_hat = lap * a_hat  # spectrum of Δa (minus the current density)
        advect_w = u * spec_dx(w_hat) + v * spec_dy(w_hat)
        lorentz = bx * spec_dx(lap_a_hat) + by * spec_dy(lap_a_hat)
        nonlin_w = torch.fft.fft2(advect_w - lorentz)
        nonlin_a = torch.fft.fft2(u * spec_dx(a_hat) + v * spec_dy(a_hat))
        w_hat = (w_hat - dt * nonlin_w) / visc_den
        a_hat = (a_hat - dt * nonlin_a) / resist_den
    return torch.fft.ifft2(w_hat).real.to(torch.float32)
def solve_poisson_2d_batch(f: torch.Tensor) -> torch.Tensor:
    """Solve -Δu = f on [0, 1]² with periodic BCs using spectral method.

    The mean of ``f`` is removed first, because the periodic problem is only
    solvable for zero-mean data. The returned ``u`` has zero mean. The input
    tensor is never modified.

    Args:
        f: source field, shape (B, N, N).

    Returns:
        float32 tensor of shape (B, N, N) on ``f``'s device.
    """
    _, N, _ = f.shape
    # Out-of-place mean removal: ``f.to(float32)`` returns ``f`` itself when it
    # is already float32, so an in-place ``-=`` would overwrite the caller's
    # tensor and, via torch.from_numpy on CPU, the numpy dataset inputs.
    f_d = f.to(torch.float32)
    f_d = f_d - f_d.mean(dim=(1, 2), keepdim=True)
    # Physical wavenumbers on [0,1]²: d/dx ↔ multiply by 2πi·k_int
    k_int = torch.fft.fftfreq(N, d=1.0 / N, device=f.device)
    kx, ky = torch.meshgrid(2 * math.pi * k_int, 2 * math.pi * k_int, indexing="ij")
    lap_pos = kx ** 2 + ky ** 2
    lap_pos[0, 0] = 1.0  # avoid div by zero for DC mode
    f_hat = torch.fft.fft2(f_d, dim=(1, 2))
    u_hat = f_hat / lap_pos[None]
    u_hat[:, 0, 0] = 0.0  # set DC mode to zero
    u = torch.fft.ifft2(u_hat, dim=(1, 2)).real
    return u.to(torch.float32)
def solve_reionization_1d_batch(f: torch.Tensor, T: float = 0.5, n_steps: int = 100) -> torch.Tensor:
    """Simple 1D ionization front model.

    Integrates u_t + c·u_x = -alpha·u² (c = 1, alpha = 0.1) on a periodic grid
    of spacing 1/N, starting from u(x, 0) = f(x). Time stepping is explicit
    Euler and the advection term uses first-order upwind differences.

    Explicit upwind is only stable for a CFL number c·dt/dx ≤ 1. Each of the
    ``n_steps`` outer steps is therefore split into the fewest equal substeps
    that satisfy this. When the CFL number is already ≤ 1 there is exactly
    one substep and the result equals the plain scheme.

    Args:
        f: initial field, shape (B, N).
        T: final time.
        n_steps: nominal number of time steps.

    Returns:
        float32 tensor of shape (B, N).
    """
    _, N = f.shape
    c = 1.0
    alpha = 0.1
    dx = 1.0 / N
    dt_outer = T / n_steps
    n_sub = max(1, math.ceil(c * dt_outer / dx))
    dt = dt_outer / n_sub
    u = f.to(torch.float32)
    for _ in range(n_steps * n_sub):
        # Upwind advection (c > 0 → backward difference)
        u_shifted = torch.roll(u, 1, dims=1)
        u_x = (u - u_shifted) / dx
        u = u - dt * (c * u_x + alpha * u**2)
    return u.to(torch.float32)
def solve_ellipse_2d_batch(
    params: torch.Tensor,
    N: int = 64
) -> torch.Tensor:
    """Semi-analytical potential flow solver for an ellipse in 2D.

    For each ellipse ``(a, b, alpha)`` (semi-axes and rotation angle), builds
    an N×N grid on [-1, 1]². ``p[k, i, j]`` sits at (x_i, y_j), i.e.
    ``indexing="ij"``. It then evaluates

        v = 1 + (a + b) / (a·sin²θ + b·cos²θ + 1e-6)   outside the ellipse,
        v = 0                                          inside (sdf < 0),
        p = 0.5·(1 - v²)                               (Bernoulli, U = ρ = 1),

    where θ is the polar angle in the ellipse's rotated frame.

    NOTE(review): this ``v`` depends only on θ, not on distance from the body,
    so it does not recover the free-stream speed far away. Treat it as a
    heuristic field rather than a true potential-flow solution.

    Args:
        params: tensor of shape (B, 3) holding (a, b, alpha) per row.
        N: grid resolution per axis.

    Returns:
        float32 tensor of shape (B, N, N) on ``params``' device.
    """
    device = params.device
    coords = torch.linspace(-1, 1, N, device=device)
    # xx[i, j] = coords[i] (x), yy[i, j] = coords[j] (y)
    xx, yy = torch.meshgrid(coords, coords, indexing="ij")
    prm = params.to(torch.float32)
    # (B, 1, 1) views so the whole batch broadcasts against the (N, N) grid.
    a = prm[:, 0, None, None]
    b = prm[:, 1, None, None]
    alpha = prm[:, 2, None, None]
    cos_al, sin_al = torch.cos(alpha), torch.sin(alpha)
    # Rotate coordinates into the ellipse frame
    xr = xx * cos_al + yy * sin_al
    yr = -xx * sin_al + yy * cos_al
    # Ellipse SDF-like level function (negative inside)
    sdf = torch.sqrt((xr / a) ** 2 + (yr / b) ** 2) - 1.0
    theta = torch.atan2(yr, xr)
    v_mag = 1.0 + (a + b) / (a * torch.sin(theta) ** 2 + b * torch.cos(theta) ** 2 + 1e-6)
    v_mag = torch.where(sdf < 0, torch.zeros_like(v_mag), v_mag)  # interior
    # Bernoulli pressure: p = 0.5 * rho * (U^2 - v^2)
    return 0.5 * (1.0 - v_mag ** 2)
def _ellipse_ic(n: int, rng: np.random.RandomState) -> np.ndarray:
"""Random ellipse parameters [a, b, alpha]."""
a = rng.uniform(0.2, 0.5, size=(n, 1))
b = rng.uniform(0.1, 0.3, size=(n, 1))
alpha = rng.uniform(0, np.pi, size=(n, 1))
return np.concatenate([a, b, alpha], axis=1).astype(np.float32)
# ── IC generators ─────────────────────────────────────────────────────────────
def _kdv_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Draw ``n`` smooth random KdV initial conditions on an ``N``-point grid.

    Delegates to the same Fourier-series generator (``_random_ic_np``) used by
    the Burgers benchmarks, with 8 modes.
    """
    n_modes = 8
    initial = _random_ic_np(n, N, rng, n_modes=n_modes)
    return initial
def _wave_ic(n: int, N: int, rng: np.random.RandomState) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(u0, du/dt at t=0)`` for the wave benchmark.

    The displacement is an 8-mode random Fourier series. The initial velocity
    is identically zero, so the string is released from rest.
    """
    displacement = _random_ic_np(n, N, rng, n_modes=8)
    velocity = np.zeros_like(displacement)
    return displacement, velocity
def _darcy_fix_ic(n: int, N: int, rng: np.random.RandomState
                  ) -> tuple[np.ndarray, np.ndarray]:
    """ICs for corrected Darcy benchmark.

    Returns:
        ``(a, f)``, each of shape (n, N, N).

        * ``a = exp(z)``, where ``z`` is a 5-mode zero-offset random field with
          scale 0.5 drawn from ``rng``, giving a positive, per-sample
          permeability.
        * ``f`` is one fixed source field (its own ``RandomState(12345)``,
          ``DARCY_FIX_MODES_F`` modes) repeated for every sample and every split.
    """
    # GRF with zero mean and scale 0.5
    z = _random_ic_2d(n, N, rng, n_modes=5, scale=0.5, offset=0.0)
    a = np.exp(z)
    # Generate fixed source f (same for all samples in all splits)
    f_rng = np.random.RandomState(12345)
    f_single = _random_ic_2d(1, N, f_rng, n_modes=DARCY_FIX_MODES_F, scale=1.0, offset=0.0)
    # np.broadcast_to returns a read-only, zero-stride view. torch.from_numpy
    # warns on such arrays and writing through them is undefined, so
    # materialise a writable copy.
    f = np.broadcast_to(f_single, (n, N, N)).copy()
    return a, f
def _ns_fix_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Draw small-amplitude initial vorticity fields for the corrected NS benchmark.

    Uses a 4-mode zero-offset random field scaled by ``NS_SCALE``.
    """
    amplitude = NS_SCALE
    return _random_ic_2d(n, N, rng, n_modes=4, scale=amplitude, offset=0.0)
def _swe_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Draw initial heights: a 3-mode random field of scale 0.1 around mean level 1.0."""
    field = _random_ic_2d(n, N, rng, n_modes=3, scale=0.1, offset=1.0)
    return field
def _allen_cahn_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Draw Allen-Cahn initial conditions: an 8-mode zero-offset random field of scale 0.5."""
    field = _random_ic_2d(n, N, rng, n_modes=8, scale=0.5, offset=0.0)
    return field
def _mhd_ic(n: int, N: int, rng: np.random.RandomState) -> tuple[np.ndarray, np.ndarray]:
    """Draw ``(w0, a0)`` for MHD: two independent 4-mode zero-offset fields of scale 0.1.

    ``w0`` is drawn from ``rng`` before ``a0``.
    """
    field_kwargs = dict(n_modes=4, scale=0.1, offset=0.0)
    vorticity = _random_ic_2d(n, N, rng, **field_kwargs)
    potential = _random_ic_2d(n, N, rng, **field_kwargs)
    return vorticity, potential
# ── Dataset generation ─────────────────────────────────────────────────────────
def _generate_ext_dataset(benchmark: str, n: int, seed: int) -> tuple:
    """Generate ``n`` (input, target) sample pairs for an extended benchmark.

    Initial conditions are drawn from ``np.random.RandomState(seed)``. Work is
    done in chunks of 100 samples for names containing "2d" and 1000 otherwise.
    Each chunk is solved on ``TORCH_DEVICE`` and copied back to numpy.

    Returns:
        ``(inputs, targets)`` numpy arrays concatenated along axis 0. The 2D
        benchmarks carry a trailing channel axis and the 1D ones do not.

    Raises:
        ValueError: for a benchmark name with no generator branch (when n > 0).
    """
    import sys

    rng = np.random.RandomState(seed)
    # To keep the dashboard alive and provide visibility, we generate in chunks
    chunk_size = 100 if "2d" in benchmark else 1000
    all_inputs = []
    all_targets = []
    for i in range(0, n, chunk_size):
        curr_n = min(chunk_size, n - i)
        if i > 0 or n > chunk_size:
            print(f" [{benchmark}] Generating samples {i}/{n}...", end="\r")
            sys.stdout.flush()
        if benchmark == "kdv_1d":
            # _kdv_ic returns a numpy array; the solver and .cpu() need a tensor.
            inp_np = _kdv_ic(curr_n, GRID_SIZE, rng)
            inp_t = torch.from_numpy(inp_np).to(TORCH_DEVICE)
            tgt_t = solve_kdv_batch(inp_t, T=KDV_T, n_steps=KDV_NSTEPS)
            inp, tgt = inp_np, tgt_t.cpu().numpy()
        elif benchmark == "wave_1d":
            u0_np, ut0_np = _wave_ic(curr_n, GRID_SIZE, rng)
            u0_t, ut0_t = torch.from_numpy(u0_np).to(TORCH_DEVICE), torch.from_numpy(ut0_np).to(TORCH_DEVICE)
            tgt_t = solve_wave_batch(u0_t, ut0_t, c=WAVE_C, T=WAVE_T, n_steps=WAVE_NSTEPS)
            inp, tgt = u0_np, tgt_t.cpu().numpy()
        elif benchmark == "darcy_2d":
            a_np, f_np = _darcy_fix_ic(curr_n, GRID_SIZE, rng)
            a_t, f_t = torch.from_numpy(a_np).to(TORCH_DEVICE), torch.from_numpy(f_np).to(TORCH_DEVICE)
            tgt_t = solve_darcy_2d_batch(a_t, f_t)
            inp, tgt = a_np[..., None], tgt_t.cpu().numpy()[..., None]
        elif benchmark == "ns_2d":
            w0_np = _ns_fix_ic(curr_n, GRID_SIZE, rng)
            w0_t = torch.from_numpy(w0_np).to(TORCH_DEVICE)
            tgt_t = solve_ns_2d_batch(w0_t)
            inp, tgt = w0_np[..., None], tgt_t.cpu().numpy()[..., None]
        elif benchmark == "swe_2d":
            h0_np = _swe_ic(curr_n, GRID_SIZE, rng)
            h0_t = torch.from_numpy(h0_np).to(TORCH_DEVICE)
            tgt_t = solve_swe_2d_batch(h0_t)
            inp, tgt = h0_np[..., None], tgt_t.cpu().numpy()[..., None]
        elif benchmark == "allen_cahn_2d":
            u0_np = _allen_cahn_ic(curr_n, GRID_SIZE, rng)
            u0_t = torch.from_numpy(u0_np).to(TORCH_DEVICE)
            tgt_t = solve_allen_cahn_2d_batch(u0_t)
            inp, tgt = u0_np[..., None], tgt_t.cpu().numpy()[..., None]
        elif benchmark == "mhd_2d":
            w0_np, a0_np = _mhd_ic(curr_n, GRID_SIZE, rng)
            w0_t, a0_t = torch.from_numpy(w0_np).to(TORCH_DEVICE), torch.from_numpy(a0_np).to(TORCH_DEVICE)
            tgt_t = solve_mhd_2d_batch(w0_t, a0_t)
            inp = np.stack([w0_np, a0_np], axis=-1)
            tgt = tgt_t.cpu().numpy()[..., None]
        elif benchmark == "burgers_nu_01":
            inp_np = _random_ic_np(curr_n, GRID_SIZE, rng)
            from data.prepare import solve_burgers_batch
            inp_t = torch.from_numpy(inp_np).to(TORCH_DEVICE)
            tgt_t = solve_burgers_batch(inp_t, nu=0.1)
            inp, tgt = inp_np, tgt_t.cpu().numpy()
        elif benchmark == "burgers_nu_001":
            inp_np = _random_ic_np(curr_n, GRID_SIZE, rng)
            from data.prepare import solve_burgers_batch
            inp_t = torch.from_numpy(inp_np).to(TORCH_DEVICE)
            tgt_t = solve_burgers_batch(inp_t, nu=0.01)
            inp, tgt = inp_np, tgt_t.cpu().numpy()
        elif benchmark == "poisson_2d":
            f_np = _random_ic_2d(curr_n, GRID_SIZE, rng, n_modes=5, scale=1.0)
            f_t = torch.from_numpy(f_np).to(TORCH_DEVICE)
            tgt_t = solve_poisson_2d_batch(f_t)
            inp, tgt = f_np[..., None], tgt_t.cpu().numpy()[..., None]
        elif benchmark == "reionization_1d":
            f_np = _random_ic_np(curr_n, GRID_SIZE, rng, n_modes=3) * 1.0 + 0.5
            f_t = torch.from_numpy(f_np).to(TORCH_DEVICE)
            tgt_t = solve_reionization_1d_batch(f_t)
            inp, tgt = f_np, tgt_t.cpu().numpy()
        elif benchmark == "ellipse_2d":
            params_np = _ellipse_ic(curr_n, rng)
            params_t = torch.from_numpy(params_np).to(TORCH_DEVICE)
            # Input is the SDF of the ellipse. indexing="ij" puts x on axis 0
            # and y on axis 1, matching the grid solve_ellipse_2d_batch uses
            # for the target. The default "xy" indexing transposed the input
            # relative to the target.
            x = np.linspace(-1, 1, GRID_SIZE)
            y = np.linspace(-1, 1, GRID_SIZE)
            xx, yy = np.meshgrid(x, y, indexing="ij")
            inp_list = []
            for j in range(curr_n):
                a, b, alpha = params_np[j]
                xr = xx * np.cos(alpha) + yy * np.sin(alpha)
                yr = -xx * np.sin(alpha) + yy * np.cos(alpha)
                sdf = np.sqrt((xr/a)**2 + (yr/b)**2) - 1.0
                inp_list.append(sdf[..., None])
            inp = np.array(inp_list)
            tgt_t = solve_ellipse_2d_batch(params_t, GRID_SIZE)
            tgt = tgt_t.cpu().numpy()[..., None]
        else:
            raise ValueError(f"Unknown extended benchmark: {benchmark!r}")
        all_inputs.append(inp)
        all_targets.append(tgt)
    if n > chunk_size:
        print(f" [{benchmark}] Generating samples {n}/{n}... Done.")
        sys.stdout.flush()
    return np.concatenate(all_inputs, axis=0), np.concatenate(all_targets, axis=0)
def _get_ext_val_cache(benchmark: str) -> str:
    """Return the on-disk cache path for a benchmark's validation split.

    The filename encodes both the grid size and the sample count (``N_VAL``).
    Changing either therefore regenerates the cache instead of silently
    reusing a stale file.
    """
    return os.path.join(CACHE_DIR, f"{benchmark}_val_N{GRID_SIZE}_n{N_VAL}_ext.npz")
def _load_or_gen_ext_val(benchmark: str) -> tuple:
    """Return ``(inputs, targets)`` for a benchmark's fixed validation split.

    Loads the on-disk ``.npz`` cache if present. Otherwise generates
    ``N_VAL`` samples with ``VAL_SEED`` and writes the cache.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache = _get_ext_val_cache(benchmark)
    if os.path.exists(cache):
        # The context manager closes the NpzFile handle; indexing reads each
        # array fully into memory first.
        with np.load(cache) as data:
            return data["inputs"], data["targets"]
    print(f"Generating {benchmark} val set ({N_VAL} samples, seed={VAL_SEED})…")
    t0 = time.time()
    inp, tgt = _generate_ext_dataset(benchmark, N_VAL, VAL_SEED)
    # Write to a temporary .npz and rename atomically, so an interrupted run
    # never leaves a truncated cache that later loads would fail on.
    tmp_path = f"{cache}.{os.getpid()}.tmp.npz"
    np.savez(tmp_path, inputs=inp, targets=tgt)
    os.replace(tmp_path, cache)
    print(f" Cached {N_VAL} samples in {time.time()-t0:.1f}s → {cache}")
    return inp, tgt
def _get_ext_train_cache_path(benchmark: str) -> str:
    """Return the on-disk cache path for a benchmark's training split.

    The filename encodes both the grid size and ``N_TRAIN``. Without the grid
    size, a train set cached at another resolution would be loaded while the
    validation set was regenerated, leaving the two splits with mismatched
    shapes.
    """
    return os.path.join(CACHE_DIR, f"{benchmark}_train_N{GRID_SIZE}_n{N_TRAIN}_ext.npz")
# In-process memo of loaded/generated training splits, keyed by benchmark name.
_ext_train_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {}


def _get_ext_train(benchmark: str) -> tuple:
    """Return ``(inputs, targets)`` for a benchmark's training split.

    Lookup order:
      1. the in-process memo;
      2. the on-disk ``.npz`` cache;
      3. fresh generation of ``N_TRAIN`` samples with ``TRAIN_SEED``, which is
         then written to disk.
    """
    if benchmark in _ext_train_cache:
        return _ext_train_cache[benchmark]
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = _get_ext_train_cache_path(benchmark)
    if os.path.exists(cache_path):
        # The context manager closes the NpzFile handle; indexing reads each
        # array fully into memory first.
        with np.load(cache_path) as data:
            result = (data["inputs"], data["targets"])
    else:
        print(f"Generating {benchmark} train data ({N_TRAIN} samples)…")
        t0 = time.time()
        inputs, targets = _generate_ext_dataset(benchmark, N_TRAIN, TRAIN_SEED)
        # Atomic write: an interrupted run never leaves a truncated cache.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp.npz"
        np.savez(tmp_path, inputs=inputs, targets=targets)
        os.replace(tmp_path, cache_path)
        print(f" {N_TRAIN} samples in {time.time()-t0:.1f}s → {cache_path}")
        result = (inputs, targets)
    _ext_train_cache[benchmark] = result
    return result
from data.prepare import PDEDataset
# ── Public dataloader (same interface as prepare.make_dataloader) ─────────────
def make_ext_dataloader(benchmark: str, split: str, batch_size: int,
                        seed: int | None = None, **kwargs):
    """Yield (inputs, targets) as PyTorch tensors for an extended benchmark.

    Args:
        benchmark: benchmark name.
        split: ``"val"`` returns a finite, unshuffled ``DataLoader``.
            ``"train"`` returns an endless generator that re-iterates a
            shuffled ``DataLoader``.
        batch_size: samples per batch.
        seed: seed for the training shuffle order (12345 when ``None``).
        **kwargs: accepted for interface compatibility; ignored.

    Raises:
        ValueError: if ``split`` is not ``"train"`` or ``"val"``. An ``assert``
            would be stripped under ``python -O``.
    """
    if split not in ("train", "val"):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    is_train = split == "train"
    inp, tgt = _get_ext_train(benchmark) if is_train else _load_or_gen_ext_val(benchmark)
    dataset = PDEDataset(torch.from_numpy(inp), torch.from_numpy(tgt))
    if not is_train:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        generator=torch.Generator().manual_seed(seed if seed is not None else 12345),
    )

    def infinite_loader():
        # Restart the (reshuffling) loader each time an epoch is exhausted.
        while True:
            yield from loader

    return infinite_loader()
def evaluate_l2_rel_ext(benchmark: str, model, batch_size: int = 64) -> float:
    """Relative L2 error of ``model`` on the benchmark's fixed validation set.

    For every sample, computes the RMS of ``model(x) - y`` and the RMS of
    ``y``. Returns sum(errors) / sum(norms) over the whole set, i.e. a
    norm-weighted ratio of sums rather than a mean of per-sample ratios.
    Intended to match prepare.py's metric.

    The model is run in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored on exit, so calling this mid-training does not
    silently disable dropout or batch-norm updates.
    """
    val_loader = make_ext_dataloader(benchmark, "val", batch_size)
    total_err = 0.0
    total_norm = 0.0
    was_training = getattr(model, "training", None)
    model.eval()
    try:
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(TORCH_DEVICE), y.to(TORCH_DEVICE)
                y_pred = model(x)
                diff = (y_pred - y).float()
                y_f = y.float()
                axes = tuple(range(1, y.ndim))  # reduce over all non-batch axes
                err = torch.sqrt(torch.mean(diff ** 2, dim=axes))
                nrm = torch.sqrt(torch.mean(y_f ** 2, dim=axes))
                total_err += torch.sum(err).item()
                total_norm += torch.sum(nrm).item()
    finally:
        if was_training is not None:
            model.train(was_training)
    return total_err / max(total_norm, 1e-8)
# ── Extended SOTA targets ──────────────────────────────────────────────────────
# Reference relative-L2 errors per benchmark (literature values or estimates).
EXT_SOTA = {
    "kdv_1d": 0.010,  # FNO on KdV, Tran et al. 2023
    "wave_1d": 0.005,  # Wave equation: easier than Burgers, FNO near-exact
    "darcy_2d": 0.0108,  # Li et al. 2020 FNO on Darcy (proper solver)
    "ns_2d": 0.0128,  # Li et al. 2020 FNO on NS (T=1, nu=1e-2)
    "ns_hre_2d": 0.0700,  # Estimated SOTA for Re=1000
    "swe_2d": 0.0020,  # FNO on SWE
    "allen_cahn_2d": 0.020,  # SOTA near 0.02
    "mhd_2d": 0.0350,  # MHD targets from PhysicsNeMo
}
# ── Benchmark metadata ────────────────────────────────────────────────────────
# Human-readable description of each benchmark's PDE, data and solver.
EXT_BENCHMARK_INFO = {
    "kdv_1d": {
        "pde": "u_t + u·u_x + u_xxx = 0 (Korteweg-de Vries)",
        "domain": "[0, 2π), periodic",
        "ic_type": "smooth random Fourier series",
        "solver": "ETDRK4 (exponential time differencing Runge-Kutta 4)",
        "t_final": KDV_T,
        "n_steps": KDV_NSTEPS,
        "sota_model": "FNO",
        "notes": "Soliton dynamics; FNO handles well due to periodicity",
    },
    "wave_1d": {
        "pde": "u_tt = c² u_xx (1D wave, c=1)",
        "domain": "[0, 2π), periodic",
        "ic_type": "smooth random Fourier series for u0; released from rest (du/dt = 0)",
        "solver": "Störmer-Verlet (symplectic, energy-conserving)",
        "t_final": WAVE_T,
        "n_steps": WAVE_NSTEPS,
        "sota_model": "FNO",
        "notes": "Linear PDE; FNO can achieve near-zero error easily",
    },
    "darcy_2d": {
        "pde": "-∇·(a(x,y)∇u) = f (2D Darcy flow)",
        "domain": "[0, 1]², periodic",
        "ic_type": "log-normal permeability a = exp(z), z a 5-mode random field (scale 0.5); "
                   "one fixed random source f (seed 12345) shared by all samples",
        "solver": "Preconditioned conjugate gradient, constant-coefficient spectral "
                  "preconditioner (up to 40 iters)",
        "t_final": None,
        "n_steps": DARCY_FIX_N_ITER,
        "sota_model": "FNO",
        "notes": "Fixed: proper variable-coeff solve; prepare.py used only mean(a)",
        "known_issue_in_prepare_py":
            "solve_darcy_2d_batch uses a_avg (scalar) → u independent of spatial a; "
            "f uses fixed seed=42 → u uncorrelated with model input a",
    },
    "ns_2d": {
        "pde": "w_t + (u*grad)w = nu * Laplacian(w) (2D NS, vorticity form)",
        "domain": "[0, 2π)², periodic",
        "ic_type": "small-amplitude vorticity (scale=0.1) → CFL≈0.6 < 1",
        "solver": "Semi-implicit Euler, 2/3-rule dealiasing, n_steps=1000",
        "t_final": NS_T,
        "n_steps": NS_NSTEPS,
        "sota_model": "FNO",
        "notes": "Fixed: IC scale 1.0→0.1 reduces CFL from 61 to ~0.6",
        "known_issue_in_prepare_py":
            "solve_ns_2d_batch uses IC scale=1.0 → max_velocity≈95 → "
            "CFL≈61 → semi-implicit Euler explodes to NaN on step 1",
    },
    "swe_2d": {
        "pde": "Shallow Water Equations (height-vorticity)",
        "domain": "[0, 1]^2, periodic",
        "ic_type": "Random height bumps",
        "solver": "Spectral continuity + momentum",
        "t_final": SWE_T,
        "n_steps": SWE_NSTEPS,
        "sota_model": "MemNO",
        "notes": "Tests multi-scale wave dynamics",
    },
    "allen_cahn_2d": {
        "pde": "Allen-Cahn Phase Separation",
        "domain": "[0, 1]^2, periodic",
        "ic_type": "High-frequency random noise",
        "solver": "Semi-implicit spectral",
        "t_final": AC_T,
        "n_steps": AC_NSTEPS,
        "sota_model": "FNO",
        "notes": "Tests sharp interface capture",
    },
    "mhd_2d": {
        "pde": "Magnetohydrodynamics (vorticity-potential)",
        "domain": "[0, 1]^2, periodic",
        "ic_type": "Orszag-Tang inspired random fields",
        "solver": "Dual-field spectral",
        "t_final": MHD_T,
        "n_steps": MHD_NSTEPS,
        "sota_model": "TFNO",
        "notes": "Coupled fluid-magnetic dynamics",
    },
    "poisson_2d": {
        "pde": "-Δu = f (2D Poisson equation)",
        "domain": "[0, 1]^2, periodic",
        "ic_type": "Zero-mean random source f",
        "solver": "Exact spectral solver",
        "t_final": None,
        "n_steps": None,
        "sota_model": "IterativeFNO",
        "notes": "Fundamental elliptic PDE; tests precision and convergence stabilities",
    },
    "reionization_1d": {
        "pde": "u_t + c·u_x = -alpha·u^2 (toy ionization-front model, c=1, alpha=0.1)",
        "domain": "[0, 1], periodic",
        "ic_type": "Initial field u0: 3-mode random Fourier series shifted by +0.5",
        "solver": "Upwind scheme + recombination",
        "t_final": 0.5,
        "n_steps": 100,
        "sota_model": "PINN",
        "notes": "Non-linear advection-reaction; mimics ionization front propagation",
    },
    "ellipse_2d": {
        "pde": "Incompressible laminar flow around ellipse (surface pressure)",
        "domain": "[-1, 1]^2",
        "ic_type": "Random ellipse geometry (a, b, alpha)",
        "solver": "Potential flow analytical approximation",
        "t_final": None,
        "n_steps": None,
        "sota_model": "SAR",
        "notes": "Tests geometry-to-distribution mapping as proposed in Lino & Thuerey (2026)",
    },
}
if __name__ == "__main__":
    # Smoke test: describe every registered benchmark, generate four samples of
    # each, and report shapes, timing and NaNs. For darcy_2d, additionally
    # check that the target correlates with the permeability input.
    # (``time`` is already imported at module level.)
    print("Extended benchmarks available:", sorted(EXT_BENCHMARKS))
    fallback_info = {"pde": "Unknown", "solver": "Unknown"}
    for name in sorted(EXT_BENCHMARKS):
        meta = EXT_BENCHMARK_INFO.get(name, fallback_info)
        print(f"\n{name}:")
        print(f" PDE : {meta['pde']}")
        print(f" Solver: {meta['solver']}")
        print(f" SOTA : ~{EXT_SOTA.get(name, 0.0):.4f} rel-L2")
        start = time.time()
        inputs, targets = _generate_ext_dataset(name, 4, seed=0)
        took = time.time() - start
        print(f" Shape : in={inputs.shape} → out={targets.shape}")
        print(f" Gen : {took:.2f}s for 4 samples")
        print(f" NaN? : in={np.isnan(inputs).any()} out={np.isnan(targets).any()}")
        if name != "darcy_2d":
            continue
        from scipy.stats import pearsonr
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            corr, _ = pearsonr(inputs[0].flatten(), targets[0].flatten())
        print(f" corr(a[0], u[0]): {corr:.4f} (should be non-trivial for learnable data)")