File size: 2,012 Bytes
54fa103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""WaveBench — 2D Wave propagation with high-frequency components.

PDE on [0,2π)², periodic:
    u_tt = c^2 ∇²u

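Simulates harmonic wave propagation to probe the spectral bias of learned models.
Exact Fourier-mode solution:
    u_hat(k, T) = u_hat_0(k) * cos(c|k|T) + u_t_hat_0(k) * sin(c|k|T) / (c|k|)
(at k = 0 the second term is taken in the limit, i.e. u_t_hat_0(0) * T).

Initial conditions are released from rest (u_t(0) = 0), so only the cosine term
contributes: u_hat(k, T) = u_hat_0(k) * cos(c|k|T).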
"""

import math
import torch
import numpy as np

from data.prepare import _random_ic_2d
from core.device import DEVICE, TORCH_DEVICE

# Wave speed c and final time T used by the reference solver.
C_SPEED: float = 2.0
T_FINAL: float = 1.0

# Descriptive metadata for this benchmark.
# NOTE(review): make_dataset appends a trailing channel axis ([B, N, N, 1]);
# confirm whether in_shape/out_shape below are meant to include it.
METADATA = dict(
    pde="u_tt = c²∇²u (High-frequency harmonic wave propagation)",
    domain="[0,2π)², periodic",
    solver="Analytic Fourier propagator (exact)",
    t_final=T_FINAL,
    n_steps=1,
    in_shape="B,N,N",
    out_shape="B,N,N",
    notes="Tests model ability to resolve high-frequency waves (spectral bias).",
)

def make_ic(n: int, N: int, rng: np.random.RandomState) -> torch.Tensor:
    """Draw ``n`` random initial displacement fields at resolution ``N``.

    Sampling is delegated to ``_random_ic_2d`` with ``n_modes=12``,
    ``scale=0.5`` and ``offset=0.0`` using the caller's ``rng``. The NumPy
    result is wrapped with ``torch.from_numpy`` and moved to ``TORCH_DEVICE``.
    """
    # NOTE(review): these arguments are chosen to emphasise higher-frequency
    # content (per the original comment); their exact meaning is defined in
    # data.prepare._random_ic_2d — confirm there. Output presumably (n, N, N).
    sample = _random_ic_2d(n, N, rng, n_modes=12, scale=0.5, offset=0.0)
    as_tensor = torch.from_numpy(sample)
    return as_tensor.to(TORCH_DEVICE)

def solve_batch(u0: torch.Tensor, T: float = T_FINAL) -> torch.Tensor:
    """Propagate a batch of initial fields to time ``T`` with the exact wave propagator.

    Each field is released from rest (u_t(0) = 0) on the periodic domain
    [0, 2π)², so every Fourier mode evolves as
    ``u_hat(k, T) = u_hat_0(k) * cos(c|k|T)`` with ``c = C_SPEED``.

    Args:
        u0: Initial displacement of shape ``(B, Nx, Ny)``. It is cast to
            float32 before the FFT. The grid does not have to be square.
        T: Final time (defaults to ``T_FINAL``).

    Returns:
        Float32 tensor of shape ``(B, Nx, Ny)`` on the same device as ``u0``.

    Raises:
        ValueError: If ``u0`` is not a rank-3 tensor.
    """
    if u0.ndim != 3:
        raise ValueError(
            f"solve_batch expects u0 of shape (B, Nx, Ny), got {tuple(u0.shape)}"
        )
    _, nx, ny = u0.shape
    # Build the propagator where the data lives, so batches on any device work.
    device = u0.device

    # On a [0, 2π) grid the FFT wavenumbers are integers: fftfreq(n, d=1/n)
    # yields 0, 1, ..., -1. Each axis gets its own frequencies (non-square grids).
    kx = torch.fft.fftfreq(nx, d=1.0 / nx, device=device)
    ky = torch.fft.fftfreq(ny, d=1.0 / ny, device=device)
    kx_grid, ky_grid = torch.meshgrid(kx, ky, indexing="ij")
    omega = C_SPEED * torch.sqrt(kx_grid**2 + ky_grid**2)

    # Released from rest: only the cosine term of the exact solution survives.
    # Shape (Nx, Ny), broadcast over the batch axis.
    propagator = torch.cos(omega * T)

    # NOTE(review): the transforms run in single precision, so this "exact"
    # reference carries float32 round-off. Confirm that is intended before
    # switching to float64 (not every torch backend supports it).
    u0_f32 = u0.to(torch.float32)
    u_hat = torch.fft.fft2(u0_f32, dim=(-2, -1))
    # The propagator is even in k, so the spectrum stays Hermitian and the
    # inverse transform is real up to round-off; keep the real part.
    uT = torch.fft.ifft2(u_hat * propagator, dim=(-2, -1)).real
    return uT.to(torch.float32)

def make_dataset(n: int, seed: int, N: int = 64) -> tuple[torch.Tensor, torch.Tensor]:
    """Build ``n`` (initial field, field at ``T_FINAL``) training pairs.

    A NumPy ``RandomState`` seeded with ``seed`` is handed to ``make_ic``.
    Targets come from ``solve_batch`` at its default final time. Both tensors
    get a trailing singleton channel axis, giving ``[B, N, N, 1]``.
    """
    seeded_rng = np.random.RandomState(seed)
    initial_fields = make_ic(n, N, seeded_rng)
    final_fields = solve_batch(initial_fields)
    # (B, N, N) -> (B, N, N, 1): append the channel axis expected downstream.
    return initial_fields.unsqueeze(-1), final_fields.unsqueeze(-1)