"""Extended benchmark definitions for SciML experiments.

Adds KdV, wave, and corrected 2D benchmarks on top of ``data/prepare.py``.

Supported benchmarks (see ``EXT_BENCHMARK_INFO`` for details):

    "kdv_1d"           Korteweg-de Vries soliton dynamics (prepare.py ETDRK4 solver)
    "wave_1d"          1D wave equation u_tt = c^2 u_xx (prepare.py Stormer-Verlet)
    "darcy_2d"         2D Darcy flow with a true variable-coefficient solve (PCG)
    "ns_2d"            2D Navier-Stokes (vorticity form), CFL-stable IC amplitude
    "swe_2d"           linearized 2D shallow water
    "allen_cahn_2d"    2D Allen-Cahn phase separation
    "mhd_2d"           2D incompressible MHD (vorticity / magnetic potential)
    "burgers_nu_01"    1D Burgers, nu = 0.1 (prepare.py solver)
    "burgers_nu_001"   1D Burgers, nu = 0.01 (prepare.py solver)
    "poisson_2d"       2D periodic Poisson equation
    "reionization_1d"  toy 1D ionization-front model
    "ellipse_2d"       approximate potential-flow pressure around an ellipse

Why darcy_2d and ns_2d?
    prepare.py's darcy_2d solver uses only mean(a), so it loses all spatial
    information in a; its source term f also uses a fixed seed independent
    of a, so u is uncorrelated with a.
    prepare.py's ns_2d solver uses IC scale=1.0, giving CFL ~ 61 and an
    immediate NaN.
    Both benchmarks are broken at the data level and prepare.py is read-only,
    so this module provides corrected, learnable versions.

Data interface matches prepare.py:
    make_ext_dataloader(benchmark, split, batch_size)
    evaluate_l2_rel_ext(benchmark, model)

References:
    KdV: Tran et al. (2023) "Factorized Fourier Neural Operators"
    Wave: Rahman et al. (2022) U-NO
    Darcy fix: Li et al. (2020) FNO paper, original Darcy benchmark setup
    NS fix: standard semi-implicit spectral NS with CFL-stable parameters
"""
import functools
import math
import os
import sys
import time

import numpy as np
import torch

from data.prepare import (
    GRID_SIZE,
    TIME_BUDGET,
    N_TRAIN,
    N_VAL,
    VAL_SEED,
    TRAIN_SEED,
    CACHE_DIR,
    PDEDataset,
    solve_kdv_batch,
    solve_wave_batch,
    _random_ic,
    _random_ic_np,
    _random_ic_2d,
)
from core.device import DEVICE, TORCH_DEVICE  # noqa: F401  (DEVICE is unused in this module)

# ── Constants ─────────────────────────────────────────────────────────────────

EXT_BENCHMARKS = {
    "kdv_1d",
    "wave_1d",
    "darcy_2d",
    "ns_2d",
    "swe_2d",
    "allen_cahn_2d",
    "mhd_2d",
    "burgers_nu_01",
    "burgers_nu_001",
    "poisson_2d",
    "reionization_1d",
    "ellipse_2d",
}

# NOTE: "ns_hre_2d" has metadata here and in EXT_SOTA, but this module has no
# generator for it.
EXT_N_CHANNELS = {
    "kdv_1d": 1,
    "wave_1d": 1,
    "darcy_2d": 1,
    "ns_2d": 1,
    "ns_hre_2d": 1,
    "swe_2d": 1,
    "allen_cahn_2d": 1,
    "mhd_2d": 2,  # input channels: vorticity (w) and magnetic potential (a); target is w only
    "burgers_nu_01": 1,
    "burgers_nu_001": 1,
    "poisson_2d": 1,
    "reionization_1d": 1,
    "ellipse_2d": 1,
}

# Bump whenever a solver or IC generator changes, so on-disk caches produced by
# an older version are regenerated instead of being silently reused.
_EXT_CACHE_VERSION = 2

# KdV parameters
KDV_T = 1.0       # final time
KDV_NSTEPS = 1000  # ETDRK4 steps

# Wave parameters
WAVE_C = 1.0       # wave speed
WAVE_T = 1.0       # final time
WAVE_NSTEPS = 400  # Störmer-Verlet steps

# Darcy fix parameters
DARCY_FIX_N_ITER = 40  # max PCG iterations
DARCY_FIX_MODES_F = 5  # source term Fourier modes

# NS fix parameters — reduces CFL from ~61 to ~0.6
NS_SCALE = 0.1     # IC vorticity amplitude (vs 1.0 in prepare.py -> 10x smaller)
NS_NSTEPS = 1000   # time steps (vs 100) — gives dt=0.001, CFL≈0.6
NS_NU = 1e-2       # kinematic viscosity (same as original)
NS_T = 1.0         # final time

# Allen-Cahn parameters
AC_EPSILON = 0.01
AC_T = 0.5
AC_NSTEPS = 200

# SWE parameters
SWE_G = 9.81
SWE_T = 0.2
SWE_NSTEPS = 200

# MHD parameters
MHD_NU = 1e-3
MHD_ETA = 1e-3
MHD_T = 0.5
MHD_NSTEPS = 500


# ── Spectral helpers ──────────────────────────────────────────────────────────

def _wavenumbers_2d(n: int, period: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Return angular wavenumber grids ``(kx, ky)`` ("ij" indexing) for an n×n periodic grid.

    ``period`` is the side length of the square domain, so the values are
    ``2π·m / period`` for the FFT integer frequencies ``m``: ``period=1.0``
    gives ``2π·m`` (domain [0, 1]²), ``period=2π`` gives integer ``m``
    (domain [0, 2π)²).
    """
    k = (2 * math.pi / period) * torch.fft.fftfreq(n, d=1.0 / n, device=TORCH_DEVICE)
    kx, ky = torch.meshgrid(k, k, indexing="ij")
    return kx, ky


# ── 2D Solvers (corrected) ────────────────────────────────────────────────────

def solve_darcy_2d_batch(
    a: torch.Tensor,
    f: torch.Tensor,
    n_iter: int = DARCY_FIX_N_ITER,
) -> torch.Tensor:
    """Solve -∇·(a(x,y)∇u) = f on [0,1]² with periodic BCs.

    Uses preconditioned conjugate gradient (PCG) with the constant-coefficient
    Poisson operator P = mean(a)·(-Δ) as preconditioner. The periodic operator
    is singular on constants, so residuals and preconditioned residuals are
    kept in the zero-mean subspace. The effective source is therefore
    f - mean(f), and the returned u has zero mean.

    Args:
        a: permeability field, shape (B, N, N).
        f: source term, shape (B, N, N).
        n_iter: maximum number of PCG iterations (stops early once r·z ~ 0).

    Returns:
        float32 tensor of shape (B, N, N).
    """
    B, N, _ = a.shape
    a_d = a.to(torch.float32)
    f_d = f.to(torch.float32)

    # Physical wavenumbers on [0,1]²: d/dx ↔ multiply by 2πi·k_int
    kx, ky = _wavenumbers_2d(N, 1.0)
    lap_pos = kx ** 2 + ky ** 2
    lap_pos[0, 0] = 1.0  # placeholder for the DC mode; it is projected out below
    a_mean = a_d.mean(dim=(1, 2), keepdim=True)

    def apply_A(v):
        """Compute A·v = -∇·(a∇v) via spectral differentiation."""
        v_hat = torch.fft.fft2(v, dim=(1, 2))
        vx = torch.fft.ifft2(1j * kx[None] * v_hat, dim=(1, 2)).real
        vy = torch.fft.ifft2(1j * ky[None] * v_hat, dim=(1, 2)).real
        return -torch.fft.ifft2(
            1j * kx[None] * torch.fft.fft2(a_d * vx, dim=(1, 2))
            + 1j * ky[None] * torch.fft.fft2(a_d * vy, dim=(1, 2)),
            dim=(1, 2),
        ).real

    def apply_P_inv(r):
        """Preconditioner solve: P⁻¹r = r̂ / (mean(a) · |k|²), projected to zero mean."""
        r_hat = torch.fft.fft2(r, dim=(1, 2))
        Pr = torch.fft.ifft2(r_hat / (a_mean * lap_pos[None]), dim=(1, 2)).real
        Pr -= Pr.mean(dim=(1, 2), keepdim=True)
        return Pr

    u = torch.zeros((B, N, N), dtype=torch.float32, device=TORCH_DEVICE)
    r = f_d - apply_A(u)
    r -= r.mean(dim=(1, 2), keepdim=True)
    z = apply_P_inv(r)
    p = z.clone()
    rz_old = torch.sum(r * z, dim=(1, 2), keepdim=True)

    for _ in range(n_iter):
        Ap = apply_A(p)
        pAp = torch.sum(p * Ap, dim=(1, 2), keepdim=True)
        alpha = rz_old / (pAp + 1e-16)
        u += alpha * p
        r -= alpha * Ap
        # Project residual to zero-mean space to avoid drift
        r -= r.mean(dim=(1, 2), keepdim=True)
        z = apply_P_inv(r)
        rz_new = torch.sum(r * z, dim=(1, 2), keepdim=True)
        if torch.max(torch.abs(rz_new)) < 1e-18:
            break
        beta = rz_new / (rz_old + 1e-16)
        p = z + beta * p
        rz_old = rz_new

    u -= u.mean(dim=(1, 2), keepdim=True)
    return u.to(torch.float32)


def solve_ns_2d_batch(
    w0: torch.Tensor,
    nu: float = NS_NU,
    T: float = NS_T,
    n_steps: int = NS_NSTEPS,
) -> torch.Tensor:
    """Semi-implicit pseudo-spectral 2D Navier-Stokes solver (vorticity form).

    Solves ω_t + u·∇ω = ν Δω on the periodic box [0, 2π)², with stream
    function -Δψ = ω and velocity u = (ψ_y, -ψ_x) (so that
    ω = ∂v/∂x - ∂u/∂y). Each step is forward Euler for advection and
    backward Euler for diffusion. The 2/3 rule (keep |k| <= N/3 per axis) is
    applied to the vorticity fed into the nonlinear term and to the product.

    Args:
        w0: initial vorticity, shape (B, N, N).
        nu: kinematic viscosity.
        T: final time.
        n_steps: number of time steps (dt = T / n_steps).

    Returns:
        ω(T) as a float32 tensor of shape (B, N, N).
    """
    _, N, _ = w0.shape
    dt = T / n_steps

    kx, ky = _wavenumbers_2d(N, 2 * math.pi)  # integer wavenumbers on [0, 2π)²
    k_sq = kx ** 2 + ky ** 2                  # Fourier symbol of -Δ
    k_sq_safe = k_sq.clone()
    k_sq_safe[0, 0] = 1.0                     # avoid 0/0; ψ's DC mode is zeroed below
    k_cut = N / 3.0
    dealias = ((kx.abs() <= k_cut) & (ky.abs() <= k_cut)).to(torch.float32)
    # Backward-Euler diffusion factor; k_sq[0, 0] = 0 keeps the mean vorticity unchanged.
    implicit_diffusion = 1.0 + dt * nu * k_sq

    w_hat = torch.fft.fft2(w0.to(torch.float32), dim=(1, 2))
    for _ in range(n_steps):
        w_hat_d = w_hat * dealias

        # Stream function: -Δψ = ω
        psi_hat = w_hat_d / k_sq_safe
        psi_hat[:, 0, 0] = 0.0

        # Velocity: u = (∂ψ/∂y, -∂ψ/∂x)
        u = torch.fft.ifft2(1j * ky * psi_hat).real
        v = torch.fft.ifft2(-1j * kx * psi_hat).real

        # Non-linear term: (u·∇)ω, dealiased
        wx = torch.fft.ifft2(1j * kx * w_hat_d).real
        wy = torch.fft.ifft2(1j * ky * w_hat_d).real
        advection_hat = torch.fft.fft2(u * wx + v * wy) * dealias

        # Semi-implicit step: diffusion implicit, advection explicit
        w_hat = (w_hat - dt * advection_hat) / implicit_diffusion

    return torch.fft.ifft2(w_hat, dim=(1, 2)).real.to(torch.float32)


def solve_allen_cahn_2d_batch(u0: torch.Tensor, epsilon: float = AC_EPSILON,
                              T: float = AC_T, n_steps: int = AC_NSTEPS) -> torch.Tensor:
    """Semi-implicit spectral solver for 2D Allen-Cahn on [0, 1]², periodic.

    Solves u_t = ε Δu + u - u³ with the reaction term explicit (Euler) and
    diffusion implicit (backward Euler), using physical wavenumbers 2π·m for
    the unit square.

    Args:
        u0: initial field, shape (B, N, N).
        epsilon: diffusion coefficient ε.
        T: final time.
        n_steps: number of time steps.

    Returns:
        u(T) as a float32 tensor of shape (B, N, N).
    """
    _, N, _ = u0.shape
    dt = T / n_steps
    kx, ky = _wavenumbers_2d(N, 1.0)
    implicit_diffusion = 1.0 + dt * epsilon * (kx ** 2 + ky ** 2)

    u_hat = torch.fft.fft2(u0.to(torch.float32), dim=(1, 2))
    for _ in range(n_steps):
        u = torch.fft.ifft2(u_hat).real
        reaction_hat = torch.fft.fft2(u ** 3 - u)
        u_hat = (u_hat - dt * reaction_hat) / implicit_diffusion
    return torch.fft.ifft2(u_hat).real.to(torch.float32)


def solve_swe_2d_batch(h0: torch.Tensor, T: float = SWE_T, n_steps: int = SWE_NSTEPS,
                       g: float = SWE_G) -> torch.Tensor:
    """Linearized 2D shallow-water solver on the periodic box [0, 2π)².

    Evolves h_t = -(u_x + v_y), u_t = -g h_x, v_t = -g h_y, starting from rest
    (u = v = 0), with Fourier-spectral derivatives. Time stepping is
    symplectic Euler: height first, then velocities from the updated height.
    This keeps every wave mode bounded while sqrt(g)·|k|·dt < 2, whereas plain
    forward Euler amplifies oscillatory modes on every step.

    Args:
        h0: initial height, shape (B, N, N).
        T: final time.
        n_steps: number of time steps.
        g: gravitational acceleration.

    Returns:
        h(T) as a float32 tensor of shape (B, N, N).
    """
    _, N, _ = h0.shape
    dt = T / n_steps
    kx, ky = _wavenumbers_2d(N, 2 * math.pi)  # integer wavenumbers
    ikx, iky = 1j * kx, 1j * ky

    h_hat = torch.fft.fft2(h0.to(torch.float32), dim=(1, 2))
    u_hat = torch.zeros_like(h_hat)
    v_hat = torch.zeros_like(h_hat)
    for _ in range(n_steps):
        # Continuity (linearized about unit depth)
        h_hat = h_hat - dt * (ikx * u_hat + iky * v_hat)
        # Momentum, driven by the freshly updated height
        u_hat = u_hat - dt * g * ikx * h_hat
        v_hat = v_hat - dt * g * iky * h_hat
    return torch.fft.ifft2(h_hat).real.to(torch.float32)


def solve_mhd_2d_batch(w0: torch.Tensor, a0: torch.Tensor, T: float = MHD_T,
                       n_steps: int = MHD_NSTEPS) -> torch.Tensor:
    """Pseudo-spectral solver for 2D incompressible MHD (vorticity-potential form).

    Evolves
        ω_t + u·∇ω = B·∇j + ν Δω,        a_t + u·∇a = η Δa,
    with -Δψ = ω, u = (ψ_y, -ψ_x), B = (a_y, -a_x) and current j = -Δa.
    Nonlinear terms are explicit (Euler) and diffusion implicit; there is no
    dealiasing. Wavenumbers are ``torch.fft.fftfreq(N)`` (cycles per grid
    cell), i.e. the grid spacing is 2π, not 1/N.

    Args:
        w0: initial vorticity, shape (B, N, N).
        a0: initial magnetic potential, shape (B, N, N).
        T: final time.
        n_steps: number of time steps.

    Returns:
        Final vorticity ω(T) only, float32 tensor of shape (B, N, N).
    """
    _, N, _ = w0.shape
    dt = T / n_steps
    k = torch.fft.fftfreq(N, device=TORCH_DEVICE)
    kx, ky = torch.meshgrid(k, k, indexing="ij")
    k_sq = kx ** 2 + ky ** 2      # symbol of -Δ
    k_sq_safe = k_sq.clone()
    k_sq_safe[0, 0] = 1.0         # avoid 0/0; ψ's DC mode is zeroed below
    viscous = 1.0 + dt * MHD_NU * k_sq
    resistive = 1.0 + dt * MHD_ETA * k_sq

    def deriv(field_hat, k_axis):
        """Real-space derivative of a spectral field along the axis of ``k_axis``."""
        return torch.fft.ifft2(1j * k_axis * field_hat).real

    w_hat = torch.fft.fft2(w0.to(torch.float32), dim=(1, 2))
    a_hat = torch.fft.fft2(a0.to(torch.float32), dim=(1, 2))
    for _ in range(n_steps):
        psi_hat = w_hat / k_sq_safe
        psi_hat[:, 0, 0] = 0.0
        u = deriv(psi_hat, ky)
        v = -deriv(psi_hat, kx)
        bx = deriv(a_hat, ky)
        by = -deriv(a_hat, kx)
        j_hat = k_sq * a_hat      # current density j = -Δa

        advect_w = u * deriv(w_hat, kx) + v * deriv(w_hat, ky)
        lorentz = bx * deriv(j_hat, kx) + by * deriv(j_hat, ky)
        advect_a = u * deriv(a_hat, kx) + v * deriv(a_hat, ky)

        w_hat = (w_hat - dt * torch.fft.fft2(advect_w - lorentz)) / viscous
        a_hat = (a_hat - dt * torch.fft.fft2(advect_a)) / resistive
    return torch.fft.ifft2(w_hat).real.to(torch.float32)


def solve_poisson_2d_batch(f: torch.Tensor) -> torch.Tensor:
    """Solve -Δu = f on [0, 1]² with periodic BCs using a spectral method.

    The periodic problem is only solvable for zero-mean f, so the DC mode of f
    is discarded and the returned u has zero mean. The input tensor is not
    modified.

    Args:
        f: source term, shape (B, N, N).

    Returns:
        float32 tensor of shape (B, N, N).
    """
    _, N, _ = f.shape
    kx, ky = _wavenumbers_2d(N, 1.0)
    lap_pos = kx ** 2 + ky ** 2
    lap_pos[0, 0] = 1.0  # avoid div by zero; the DC mode is zeroed just below
    u_hat = torch.fft.fft2(f.to(torch.float32), dim=(1, 2)) / lap_pos[None]
    u_hat[:, 0, 0] = 0.0
    return torch.fft.ifft2(u_hat, dim=(1, 2)).real.to(torch.float32)


def solve_reionization_1d_batch(f: torch.Tensor, T: float = 0.5, n_steps: int = 100,
                                c: float = 1.0, alpha: float = 0.1) -> torch.Tensor:
    """Toy 1D ionization-front model u_t + c·u_x = -alpha·u² on [0, 1), periodic.

    Explicit Euler in time with a first-order backward (upwind for c > 0)
    difference in space. The advective part is stable only while the CFL
    number c·dt·N <= 1, where dt = T / n_steps.

    Args:
        f: initial field u(x, 0), shape (B, N).
        T: final time.
        n_steps: number of time steps.
        c: advection speed (assumed positive).
        alpha: recombination coefficient.

    Returns:
        u(T) as a float32 tensor of shape (B, N).
    """
    _, N = f.shape
    dt = T / n_steps
    dx = 1.0 / N
    u = f.to(torch.float32)
    for _ in range(n_steps):
        u_x = (u - torch.roll(u, 1, dims=1)) / dx
        u = u - dt * (c * u_x + alpha * u ** 2)
    return u.to(torch.float32)


def _ellipse_frame(params: torch.Tensor, N: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rotated coordinates and implicit "SDF" of ellipses sampled on [-1, 1]².

    Args:
        params: (B, 3) tensor of (a, b, alpha) = semi-axes and rotation angle.
        N: grid points per axis.

    Returns:
        ``(sdf, xr, yr)``, each (B, N, N) float32 on TORCH_DEVICE, with "ij"
        indexing (entry [i, j] is the point (x_i, y_j)).
        ``sdf = sqrt((xr/a)² + (yr/b)²) - 1`` is negative inside the ellipse.
        It is an implicit function, not a true Euclidean distance.
    """
    geom = params.to(device=TORCH_DEVICE, dtype=torch.float32)
    a = geom[:, 0, None, None]
    b = geom[:, 1, None, None]
    alpha = geom[:, 2, None, None]
    coords = torch.linspace(-1, 1, N, device=TORCH_DEVICE)
    xx, yy = torch.meshgrid(coords, coords, indexing="ij")
    cos_t, sin_t = torch.cos(alpha), torch.sin(alpha)
    xr = xx * cos_t + yy * sin_t
    yr = -xx * sin_t + yy * cos_t
    sdf = torch.sqrt((xr / a) ** 2 + (yr / b) ** 2) - 1.0
    return sdf, xr, yr


def solve_ellipse_2d_batch(
    params: torch.Tensor, N: int = 64
) -> torch.Tensor:
    """Semi-analytical potential-flow pressure around an ellipse in 2D.

    Uses an approximate velocity magnitude around the ellipse, zero inside it,
    and Bernoulli pressure p = 0.5·(1 - |v|²) (unit density and free stream).

    Args:
        params: (B, 3) tensor of (a, b, alpha).
        N: grid points per axis on [-1, 1]² ("ij" indexing, see ``_ellipse_frame``).

    Returns:
        float32 pressure tensor of shape (B, N, N) on TORCH_DEVICE.
    """
    geom = params.to(device=TORCH_DEVICE, dtype=torch.float32)
    a = geom[:, 0, None, None]
    b = geom[:, 1, None, None]
    sdf, xr, yr = _ellipse_frame(geom, N)
    theta = torch.atan2(yr, xr)
    # Potential flow around ellipse (velocity magnitude approximation)
    v_mag = 1.0 + (a + b) / (a * torch.sin(theta) ** 2 + b * torch.cos(theta) ** 2 + 1e-6)
    v_mag = v_mag.masked_fill(sdf < 0, 0.0)  # interior: fluid at rest
    return 0.5 * (1.0 - v_mag ** 2)


# ── IC generators ─────────────────────────────────────────────────────────────

def _ellipse_ic(n: int, rng: np.random.RandomState) -> np.ndarray:
    """Random ellipse parameters [a, b, alpha], shape (n, 3), float32."""
    a = rng.uniform(0.2, 0.5, size=(n, 1))
    b = rng.uniform(0.1, 0.3, size=(n, 1))
    alpha = rng.uniform(0, np.pi, size=(n, 1))
    return np.concatenate([a, b, alpha], axis=1).astype(np.float32)


def _kdv_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Random smooth KdV ICs from prepare.py's 1D Fourier-series generator (n_modes=8)."""
    return _random_ic_np(n, N, rng, n_modes=8)


def _wave_ic(n: int, N: int, rng: np.random.RandomState) -> tuple[np.ndarray, np.ndarray]:
    """Random ICs for the wave equation: (u0, ∂u/∂t|₀)."""
    u0 = _random_ic_np(n, N, rng, n_modes=8)
    ut0 = np.zeros_like(u0)  # released from rest: u(x,0)=u0(x), u_t(x,0)=0
    return u0, ut0


def _darcy_fix_ic(n: int, N: int, rng: np.random.RandomState
                  ) -> tuple[np.ndarray, np.ndarray]:
    """ICs for the corrected Darcy benchmark: (permeability a, source f)."""
    # Log-normal permeability: exp of a random field with zero offset and scale 0.5
    z = _random_ic_2d(n, N, rng, n_modes=5, scale=0.5, offset=0.0)
    a = np.exp(z)
    # Fixed source f (same for all samples in all splits); its own RNG leaves `rng` untouched.
    f_rng = np.random.RandomState(12345)
    f_single = _random_ic_2d(1, N, f_rng, n_modes=DARCY_FIX_MODES_F, scale=1.0, offset=0.0)
    # Materialize a writable, contiguous copy (torch warns on read-only broadcast views).
    f = np.ascontiguousarray(np.broadcast_to(f_single, (n, N, N)))
    return a, f


def _ns_fix_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """ICs for the corrected NS benchmark: vorticity with small amplitude."""
    return _random_ic_2d(n, N, rng, n_modes=4, scale=NS_SCALE, offset=0.0)


def _swe_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Initial SWE height: small random bumps (scale 0.1) around offset 1.0."""
    return _random_ic_2d(n, N, rng, n_modes=3, scale=0.1, offset=1.0)


def _allen_cahn_ic(n: int, N: int, rng: np.random.RandomState) -> np.ndarray:
    """Initial Allen-Cahn order parameter: random field (n_modes=8, scale 0.5)."""
    return _random_ic_2d(n, N, rng, n_modes=8, scale=0.5, offset=0.0)


def _mhd_ic(n: int, N: int, rng: np.random.RandomState) -> tuple[np.ndarray, np.ndarray]:
    """Initial MHD fields: (vorticity w0, magnetic potential a0), both scale 0.1."""
    w0 = _random_ic_2d(n, N, rng, n_modes=4, scale=0.1, offset=0.0)
    a0 = _random_ic_2d(n, N, rng, n_modes=4, scale=0.1, offset=0.0)
    return w0, a0


# ── Per-benchmark chunk generators ────────────────────────────────────────────
# Each takes (n, rng) and returns (inputs, targets) as NumPy arrays. They draw
# from `rng` in a fixed order, so datasets are reproducible for a given seed.

def _np_to_device(arr: np.ndarray) -> torch.Tensor:
    """Wrap a NumPy array as a tensor on TORCH_DEVICE."""
    return torch.from_numpy(arr).to(TORCH_DEVICE)


def _gen_kdv_1d(n, rng):
    """KdV: u0 -> u(KDV_T)."""
    u0 = _kdv_ic(n, GRID_SIZE, rng)
    u_T = solve_kdv_batch(_np_to_device(u0), T=KDV_T, n_steps=KDV_NSTEPS)
    return u0, u_T.cpu().numpy()


def _gen_wave_1d(n, rng):
    """Wave: u0 (released from rest) -> u(WAVE_T)."""
    u0, ut0 = _wave_ic(n, GRID_SIZE, rng)
    u_T = solve_wave_batch(_np_to_device(u0), _np_to_device(ut0),
                           c=WAVE_C, T=WAVE_T, n_steps=WAVE_NSTEPS)
    return u0, u_T.cpu().numpy()


def _gen_darcy_2d(n, rng):
    """Darcy: permeability a -> pressure u (fixed source f)."""
    a, f = _darcy_fix_ic(n, GRID_SIZE, rng)
    u = solve_darcy_2d_batch(_np_to_device(a), _np_to_device(f))
    return a[..., None], u.cpu().numpy()[..., None]


def _gen_ns_2d(n, rng):
    """Navier-Stokes: ω0 -> ω(NS_T)."""
    w0 = _ns_fix_ic(n, GRID_SIZE, rng)
    w_T = solve_ns_2d_batch(_np_to_device(w0))
    return w0[..., None], w_T.cpu().numpy()[..., None]


def _gen_swe_2d(n, rng):
    """Shallow water: h0 -> h(SWE_T)."""
    h0 = _swe_ic(n, GRID_SIZE, rng)
    h_T = solve_swe_2d_batch(_np_to_device(h0))
    return h0[..., None], h_T.cpu().numpy()[..., None]


def _gen_allen_cahn_2d(n, rng):
    """Allen-Cahn: u0 -> u(AC_T)."""
    u0 = _allen_cahn_ic(n, GRID_SIZE, rng)
    u_T = solve_allen_cahn_2d_batch(_np_to_device(u0))
    return u0[..., None], u_T.cpu().numpy()[..., None]


def _gen_mhd_2d(n, rng):
    """MHD: (w0, a0) stacked on the last axis -> w(MHD_T)."""
    w0, a0 = _mhd_ic(n, GRID_SIZE, rng)
    w_T = solve_mhd_2d_batch(_np_to_device(w0), _np_to_device(a0))
    return np.stack([w0, a0], axis=-1), w_T.cpu().numpy()[..., None]


def _gen_burgers(n, rng, nu):
    """Burgers (prepare.py solver) with viscosity ``nu``: u0 -> u(T)."""
    from data.prepare import solve_burgers_batch  # imported lazily, as before

    u0 = _random_ic_np(n, GRID_SIZE, rng)
    u_T = solve_burgers_batch(_np_to_device(u0), nu=nu)
    return u0, u_T.cpu().numpy()


def _gen_poisson_2d(n, rng):
    """Poisson: zero-mean source f -> u."""
    f = _random_ic_2d(n, GRID_SIZE, rng, n_modes=5, scale=1.0)
    # Store the zero-mean source actually solved for, independent of device.
    f = f - f.mean(axis=(1, 2), keepdims=True)
    u = solve_poisson_2d_batch(_np_to_device(f))
    return f[..., None], u.cpu().numpy()[..., None]


def _gen_reionization_1d(n, rng):
    """Reionization toy model: u0 -> u(T)."""
    u0 = _random_ic_np(n, GRID_SIZE, rng, n_modes=3) + 0.5
    u_T = solve_reionization_1d_batch(_np_to_device(u0))
    return u0, u_T.cpu().numpy()


def _gen_ellipse_2d(n, rng):
    """Ellipse: implicit SDF on the grid -> approximate pressure (same "ij" grid)."""
    params_np = _ellipse_ic(n, rng)
    params_t = _np_to_device(params_np)
    sdf, _, _ = _ellipse_frame(params_t, GRID_SIZE)
    p = solve_ellipse_2d_batch(params_t, GRID_SIZE)
    return sdf.cpu().numpy()[..., None], p.cpu().numpy()[..., None]


_EXT_GENERATORS = {
    "kdv_1d": _gen_kdv_1d,
    "wave_1d": _gen_wave_1d,
    "darcy_2d": _gen_darcy_2d,
    "ns_2d": _gen_ns_2d,
    "swe_2d": _gen_swe_2d,
    "allen_cahn_2d": _gen_allen_cahn_2d,
    "mhd_2d": _gen_mhd_2d,
    "burgers_nu_01": functools.partial(_gen_burgers, nu=0.1),
    "burgers_nu_001": functools.partial(_gen_burgers, nu=0.01),
    "poisson_2d": _gen_poisson_2d,
    "reionization_1d": _gen_reionization_1d,
    "ellipse_2d": _gen_ellipse_2d,
}


# ── Dataset generation ────────────────────────────────────────────────────────

def _generate_ext_dataset(benchmark: str, n: int, seed: int) -> tuple:
    """Generate ``n`` (input, target) samples for ``benchmark`` with RNG ``seed``.

    Samples are produced in chunks (100 for 2D benchmarks, 1000 otherwise), so
    progress can be printed for large requests.

    Raises:
        ValueError: if ``benchmark`` is not a known extended benchmark.
    """
    try:
        generate_chunk = _EXT_GENERATORS[benchmark]
    except KeyError:
        raise ValueError(f"Unknown extended benchmark: {benchmark!r}") from None

    rng = np.random.RandomState(seed)
    chunk_size = 100 if "2d" in benchmark else 1000
    show_progress = n > chunk_size

    all_inputs = []
    all_targets = []
    for start in range(0, n, chunk_size):
        if show_progress:
            print(f" [{benchmark}] Generating samples {start}/{n}...", end="\r", flush=True)
        inp, tgt = generate_chunk(min(chunk_size, n - start), rng)
        all_inputs.append(inp)
        all_targets.append(tgt)

    if show_progress:
        print(f" [{benchmark}] Generating samples {n}/{n}... Done.", flush=True)

    return np.concatenate(all_inputs, axis=0), np.concatenate(all_targets, axis=0)


def _load_npz_pair(path: str) -> tuple:
    """Load (inputs, targets) from an .npz cache, closing the file afterwards."""
    with np.load(path) as data:
        return data["inputs"], data["targets"]


def _save_npz_atomic(path: str, inputs: np.ndarray, targets: np.ndarray) -> None:
    """Write an .npz cache via a temp file + rename, so readers never see a partial file."""
    tmp_path = f"{os.path.splitext(path)[0]}.tmp{os.getpid()}.npz"
    np.savez(tmp_path, inputs=inputs, targets=targets)
    os.replace(tmp_path, path)


def _get_ext_val_cache(benchmark: str) -> str:
    """On-disk cache path for the validation set (keyed on size, grid, seed, version)."""
    return os.path.join(
        CACHE_DIR,
        f"{benchmark}_val_n{N_VAL}_N{GRID_SIZE}_seed{VAL_SEED}_ext_v{_EXT_CACHE_VERSION}.npz",
    )


_ext_val_cache: dict = {}


def _load_or_gen_ext_val(benchmark: str) -> tuple:
    """Return the validation set, loading it from disk or generating and caching it."""
    if benchmark in _ext_val_cache:
        return _ext_val_cache[benchmark]
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache = _get_ext_val_cache(benchmark)
    if os.path.exists(cache):
        inp, tgt = _load_npz_pair(cache)
    else:
        print(f"Generating {benchmark} val set ({N_VAL} samples, seed={VAL_SEED})…")
        t0 = time.time()
        inp, tgt = _generate_ext_dataset(benchmark, N_VAL, VAL_SEED)
        _save_npz_atomic(cache, inp, tgt)
        print(f"  Cached {N_VAL} samples in {time.time()-t0:.1f}s → {cache}")
    _ext_val_cache[benchmark] = (inp, tgt)
    return inp, tgt


def _get_ext_train_cache_path(benchmark: str) -> str:
    """On-disk cache path for the training set (keyed on size, grid, seed, version)."""
    return os.path.join(
        CACHE_DIR,
        f"{benchmark}_train_n{N_TRAIN}_N{GRID_SIZE}_seed{TRAIN_SEED}_ext_v{_EXT_CACHE_VERSION}.npz",
    )


_ext_train_cache: dict = {}


def _get_ext_train(benchmark: str) -> tuple:
    """Return the training set, memoized in-process and cached on disk."""
    if benchmark not in _ext_train_cache:
        os.makedirs(CACHE_DIR, exist_ok=True)
        cache_path = _get_ext_train_cache_path(benchmark)
        if os.path.exists(cache_path):
            _ext_train_cache[benchmark] = _load_npz_pair(cache_path)
        else:
            print(f"Generating {benchmark} train data ({N_TRAIN} samples)…")
            t0 = time.time()
            inputs, targets = _generate_ext_dataset(benchmark, N_TRAIN, TRAIN_SEED)
            _save_npz_atomic(cache_path, inputs, targets)
            print(f"  {N_TRAIN} samples in {time.time()-t0:.1f}s → {cache_path}")
            _ext_train_cache[benchmark] = (inputs, targets)
    return _ext_train_cache[benchmark]


# ── Public dataloader (same interface as prepare.make_dataloader) ─────────────

def make_ext_dataloader(benchmark: str, split: str, batch_size: int,
                        seed: int | None = None, **kwargs):
    """Build a loader yielding (inputs, targets) PyTorch tensor batches.

    * ``split="val"``: a finite, unshuffled ``DataLoader`` over the fixed
      validation set.
    * ``split="train"``: an infinite generator cycling over a shuffled
      ``DataLoader``; shuffling is seeded with ``seed`` (default 12345).

    Extra keyword arguments are accepted for interface compatibility and ignored.

    Raises:
        ValueError: if ``split`` is not "train" or "val", or the benchmark is unknown.
    """
    if split not in ("train", "val"):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    if split == "val":
        inp, tgt = _load_or_gen_ext_val(benchmark)
        dataset = PDEDataset(torch.from_numpy(inp), torch.from_numpy(tgt))
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

    inp, tgt = _get_ext_train(benchmark)
    dataset = PDEDataset(torch.from_numpy(inp), torch.from_numpy(tgt))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
        generator=torch.Generator().manual_seed(seed if seed is not None else 12345),
    )

    def infinite_loader():
        while True:
            yield from loader

    return infinite_loader()


def evaluate_l2_rel_ext(benchmark: str, model, batch_size: int = 64) -> float:
    """Relative L2 error of ``model`` on the fixed validation set.

    Per sample, the RMS error and the RMS of the target are taken over all
    non-batch axes. The result is sum(errors) / sum(target norms) over the whole
    set, i.e. a norm-weighted average of per-sample relative errors.
    NOTE(review): intended to match prepare.py's metric — confirm there.

    The model runs under ``torch.no_grad()`` in eval mode. Its previous
    train/eval mode is restored afterwards, even on error.
    """
    val_loader = make_ext_dataloader(benchmark, "val", batch_size)
    total_err = 0.0
    total_norm = 0.0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(TORCH_DEVICE), y.to(TORCH_DEVICE)
                y_pred = model(x)
                diff = (y_pred - y).float()
                y_f = y.float()
                axes = tuple(range(1, y.ndim))
                err = torch.sqrt(torch.mean(diff ** 2, dim=axes))
                nrm = torch.sqrt(torch.mean(y_f ** 2, dim=axes))
                total_err += torch.sum(err).item()
                total_norm += torch.sum(nrm).item()
    finally:
        model.train(was_training)
    return total_err / max(total_norm, 1e-8)


# ── Extended SOTA targets ─────────────────────────────────────────────────────

EXT_SOTA = {
    "kdv_1d": 0.010,         # FNO on KdV, Tran et al. 2023
    "wave_1d": 0.005,        # Wave equation: easier than Burgers, FNO near-exact
    "darcy_2d": 0.0108,      # Li et al. 2020 FNO on Darcy (proper solver)
    "ns_2d": 0.0128,         # Li et al. 2020 FNO on NS (T=1, nu=1e-2)
    "ns_hre_2d": 0.0700,     # Estimated SOTA for Re=1000
    "swe_2d": 0.0020,        # FNO on SWE
    "allen_cahn_2d": 0.020,  # SOTA near 0.02
    "mhd_2d": 0.0350,        # MHD targets from PhysicsNeMo
}

# ── Benchmark metadata ────────────────────────────────────────────────────────

EXT_BENCHMARK_INFO = {
    "kdv_1d": {
        "pde": "u_t + u·u_x + u_xxx = 0 (Korteweg-de Vries)",
        "domain": "[0, 2π), periodic",
        "ic_type": "smooth random Fourier series",
        "solver": "ETDRK4 (exponential time differencing Runge-Kutta 4)",
        "t_final": KDV_T,
        "n_steps": KDV_NSTEPS,
        "sota_model": "FNO",
        "notes": "Soliton dynamics; FNO handles well due to periodicity",
    },
    "wave_1d": {
        "pde": "u_tt = c² u_xx (1D wave, c=1)",
        "domain": "[0, 2π), periodic",
        "ic_type": "smooth random Fourier series u0; released from rest (du/dt = 0)",
        "solver": "Störmer-Verlet (symplectic, energy-conserving)",
        "t_final": WAVE_T,
        "n_steps": WAVE_NSTEPS,
        "sota_model": "FNO",
        "notes": "Linear PDE; FNO can achieve near-zero error easily",
    },
    "darcy_2d": {
        "pde": "-∇·(a(x,y)∇u) = f (2D Darcy flow)",
        "domain": "[0, 1]², periodic",
        "ic_type": "log-normal permeability a = exp(z), z random field (scale 0.5); "
                   "one fixed source f shared by all samples",
        "solver": "Preconditioned conjugate gradient, mean(a)·(-Δ) preconditioner (max 40 iters)",
        "t_final": None,
        "n_steps": DARCY_FIX_N_ITER,
        "sota_model": "FNO",
        "notes": "Fixed: proper variable-coeff solve; prepare.py used only mean(a)",
        "known_issue_in_prepare_py": "solve_darcy_2d_batch uses a_avg (scalar) → u independent of spatial a; "
                                     "f uses fixed seed=42 → u uncorrelated with model input a",
    },
    "ns_2d": {
        "pde": "w_t + (u*grad)w = nu * Laplacian(w) (2D NS, vorticity form)",
        "domain": "[0, 2π)², periodic",
        "ic_type": "small-amplitude vorticity (scale=0.1) → CFL≈0.6 < 1",
        "solver": "Semi-implicit Euler, 2/3-rule dealiasing, n_steps=1000",
        "t_final": NS_T,
        "n_steps": NS_NSTEPS,
        "sota_model": "FNO",
        "notes": "Fixed: IC scale 1.0→0.1 reduces CFL from 61 to ~0.6",
        "known_issue_in_prepare_py": "solve_ns_2d_batch uses IC scale=1.0 → max_velocity≈95 → "
                                     "CFL≈61 → semi-implicit Euler explodes to NaN on step 1",
    },
    "swe_2d": {
        "pde": "Linearized shallow water: h_t = -(u_x + v_y), u_t = -g·h_x, v_t = -g·h_y",
        "domain": "[0, 2π)^2, periodic",
        "ic_type": "Random height bumps (scale 0.1, offset 1.0); fluid initially at rest",
        "solver": "Fourier spectral + symplectic Euler",
        "t_final": SWE_T,
        "n_steps": SWE_NSTEPS,
        "sota_model": "MemNO",
        "notes": "Tests multi-scale wave dynamics",
    },
    "allen_cahn_2d": {
        "pde": "u_t = ε·Δu + u - u³ (Allen-Cahn phase separation, ε=0.01)",
        "domain": "[0, 1]^2, periodic",
        "ic_type": "Random field (n_modes=8, scale 0.5)",
        "solver": "Semi-implicit spectral",
        "t_final": AC_T,
        "n_steps": AC_NSTEPS,
        "sota_model": "FNO",
        "notes": "Tests sharp interface capture",
    },
    "mhd_2d": {
        "pde": "Magnetohydrodynamics (vorticity-potential)",
        "domain": "periodic square; wavenumbers torch.fft.fftfreq(N) (grid spacing 2π)",
        "ic_type": "Random vorticity and magnetic-potential fields (scale 0.1)",
        "solver": "Dual-field pseudo-spectral, semi-implicit Euler",
        "t_final": MHD_T,
        "n_steps": MHD_NSTEPS,
        "sota_model": "TFNO",
        "notes": "Coupled fluid-magnetic dynamics; input (w0, a0), target w(T)",
    },
    "poisson_2d": {
        "pde": "-Δu = f (2D Poisson equation)",
        "domain": "[0, 1]^2, periodic",
        "ic_type": "Zero-mean random source f",
        "solver": "Exact spectral solver",
        "t_final": None,
        "n_steps": None,
        "sota_model": "IterativeFNO",
        "notes": "Fundamental elliptic PDE; tests precision and convergence stability",
    },
    "reionization_1d": {
        "pde": "u_t + c·u_x = -alpha·u^2 (toy ionization-front model, c=1, alpha=0.1)",
        "domain": "[0, 1), periodic",
        "ic_type": "Random Fourier field u0 (3 modes) shifted by +0.5",
        "solver": "Upwind scheme + recombination",
        "t_final": 0.5,
        "n_steps": 100,
        "sota_model": "PINN",
        "notes": "Non-linear advection-reaction; mimics ionization front propagation",
    },
    "ellipse_2d": {
        "pde": "Incompressible laminar flow around ellipse (surface pressure)",
        "domain": "[-1, 1]^2",
        "ic_type": "Random ellipse geometry (a, b, alpha); input is its implicit SDF",
        "solver": "Potential flow analytical approximation",
        "t_final": None,
        "n_steps": None,
        "sota_model": "SAR",
        "notes": "Tests geometry-to-distribution mapping as proposed in Lino & Thuerey (2026)",
    },
}


if __name__ == "__main__":
    print("Extended benchmarks available:", sorted(EXT_BENCHMARKS))
    for bm in sorted(EXT_BENCHMARKS):
        info = EXT_BENCHMARK_INFO.get(bm, {"pde": "Unknown", "solver": "Unknown"})
        print(f"\n{bm}:")
        print(f"  PDE   : {info['pde']}")
        print(f"  Solver: {info['solver']}")
        print(f"  SOTA  : ~{EXT_SOTA.get(bm, 0.0):.4f} rel-L2")

        # Quick smoke test: generate 4 samples
        t0 = time.time()
        inp, tgt = _generate_ext_dataset(bm, 4, seed=0)
        elapsed = time.time() - t0
        print(f"  Shape : in={inp.shape} → out={tgt.shape}")
        print(f"  Gen   : {elapsed:.2f}s for 4 samples")
        print(f"  NaN?  : in={np.isnan(inp).any()} out={np.isnan(tgt).any()}")

        if bm == "darcy_2d":
            from scipy.stats import pearsonr
            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                r, _ = pearsonr(inp[0].flatten(), tgt[0].flatten())
            print(f"  corr(a[0], u[0]): {r:.4f}  (should be non-trivial for learnable data)")