|
|
""" |
|
|
Spherical Fourier Neural Operator (SFNO) for LILITH. |
|
|
|
|
|
Processes atmospheric state on a spherical domain using spectral methods. |
|
|
Based on NVIDIA's FourCastNet architecture. |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
try: |
|
|
import torch_harmonics as th |
|
|
TORCH_HARMONICS_AVAILABLE = True |
|
|
except ImportError: |
|
|
TORCH_HARMONICS_AVAILABLE = False |
|
|
|
|
|
|
|
|
class SpectralConv2d(nn.Module):
    """
    2D spectral convolution using the real FFT.

    The input is transformed with ``torch.fft.rfft2``, a block of low-index
    coefficients at each end of the first frequency axis is mixed across
    channels with learned complex weights, every other coefficient is set to
    zero, and the result is transformed back with ``torch.fft.irfft2``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int = 32,
        modes2: int = 32,
    ):
        """
        Initialize spectral convolution.

        Args:
            in_channels: Input channels
            out_channels: Output channels
            modes1: Number of Fourier modes kept along the first spatial
                dimension (taken from both ends of that frequency axis)
            modes2: Number of Fourier modes kept along the second spatial
                dimension (the half-spectrum axis produced by ``rfft2``)
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        # Small initial weights: complex Gaussian scaled by 1 / (in * out).
        scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            scale * torch.randn(in_channels, out_channels, modes1, modes2, dtype=torch.cfloat)
        )
        self.weights2 = nn.Parameter(
            scale * torch.randn(in_channels, out_channels, modes1, modes2, dtype=torch.cfloat)
        )

    def compl_mul2d(
        self,
        input: torch.Tensor,
        weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Mix channels independently for every retained frequency.

        Args:
            input: Coefficients of shape (batch, in_channels, x, y)
            weights: Weights of shape (in_channels, out_channels, x, y)

        Returns:
            Coefficients of shape (batch, out_channels, x, y)
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, channels, height, width)

        Returns:
            Tensor of shape (batch, out_channels, height, width)
        """
        batch_size = x.size(0)
        height, width = x.size(-2), x.size(-1)

        x_ft = torch.fft.rfft2(x)
        n_freq = x_ft.size(-1)  # width // 2 + 1 for a real FFT

        # Never ask for more modes than the spectrum holds. Without the
        # clamp, inputs smaller than the configured mode counts produce a
        # shape mismatch inside the einsum. When the input is large enough
        # the clamp is a no-op and the result is unchanged.
        rows = min(self.modes1, height)
        cols = min(self.modes2, n_freq)

        out_ft = torch.zeros(
            batch_size,
            self.out_channels,
            height,
            n_freq,
            dtype=torch.cfloat,
            device=x.device,
        )

        # Leading rows of the first frequency axis.
        out_ft[:, :, :rows, :cols] = self.compl_mul2d(
            x_ft[:, :, :rows, :cols],
            self.weights1[:, :, :rows, :cols],
        )

        # Trailing rows of the first frequency axis. The *last* rows of
        # weights2 are used so each weight stays attached to the same offset
        # from the end of the axis when fewer rows are available. Explicit
        # start indices avoid the ``-0:`` slicing pitfall.
        out_ft[:, :, height - rows:, :cols] = self.compl_mul2d(
            x_ft[:, :, height - rows:, :cols],
            self.weights2[:, :, self.modes1 - rows:, :cols],
        )

        return torch.fft.irfft2(out_ft, s=(height, width))
|
|
|
|
|
|
|
|
class SphericalConv(nn.Module):
    """
    Spherical convolution using spherical harmonics.

    When ``torch_harmonics`` is importable, the input is moved to spherical
    harmonic space, every degree ``l`` is mixed across channels with a learned
    weight, and the result is transformed back to the grid. Otherwise the
    module falls back to the FFT-based ``SpectralConv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        nlat: int = 721,
        nlon: int = 1440,
        lmax: Optional[int] = None,
    ):
        """
        Initialize spherical convolution.

        Args:
            in_channels: Input channels
            out_channels: Output channels
            nlat: Number of latitude points
            nlon: Number of longitude points
            lmax: Maximum spherical harmonic degree; defaults to ``nlat // 2``
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or nlat // 2

        if TORCH_HARMONICS_AVAILABLE:
            self.sht = th.RealSHT(nlat, nlon, lmax=self.lmax)
            self.isht = th.InverseRealSHT(nlat, nlon, lmax=self.lmax)

            # One real weight per (in, out, degree). torch_harmonics returns
            # coefficients shaped (..., lmax, mmax), so the weight axis must
            # line up with the degree axis. The previous flat
            # (lmax + 1)(lmax + 2) / 2 axis matched neither spectral axis and
            # made the einsum in forward() fail.
            self.spectral_weights = nn.Parameter(
                torch.randn(in_channels, out_channels, self.sht.lmax) / math.sqrt(in_channels)
            )
        else:
            # FFT fallback when torch_harmonics is not installed.
            self.spectral_conv = SpectralConv2d(
                in_channels, out_channels,
                modes1=min(32, nlat // 2),
                modes2=min(32, nlon // 2),
            )
            self.spectral_weights = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, channels, nlat, nlon)

        Returns:
            Tensor of shape (batch, out_channels, nlat, nlon)
        """
        if TORCH_HARMONICS_AVAILABLE and self.spectral_weights is not None:
            # Grid -> complex coefficients of shape (batch, in, lmax, mmax).
            x_spec = self.sht(x)

            # The weights are stored as real numbers; cast them to the complex
            # coefficient dtype because einsum needs matching operand dtypes.
            weights = self.spectral_weights.to(x_spec.dtype)

            # Mix channels per degree l, broadcasting over the order m.
            out_spec = torch.einsum("bilm,iol->bolm", x_spec, weights)

            # Coefficients -> grid.
            return self.isht(out_spec)

        return self.spectral_conv(x)
|
|
|
|
|
|
|
|
class SFNOBlock(nn.Module):
    """
    Single block of the Spherical Fourier Neural Operator.

    Applies, each with a pre-normalization and a residual connection, a
    spectral convolution followed by a pointwise (1x1 convolution) MLP.
    """

    def __init__(
        self,
        channels: int,
        nlat: int = 64,
        nlon: int = 128,
        mlp_ratio: float = 2.0,
        dropout: float = 0.1,
        use_spherical: bool = True,
    ):
        """
        Initialize SFNO block.

        Args:
            channels: Number of channels
            nlat: Number of latitude points
            nlon: Number of longitude points
            mlp_ratio: MLP hidden dimension ratio
            dropout: Dropout probability
            use_spherical: Use spherical harmonics (if available)
        """
        super().__init__()

        self.channels = channels
        # The spherical path is only taken when torch_harmonics was imported.
        self.use_spherical = use_spherical and TORCH_HARMONICS_AVAILABLE

        if self.use_spherical:
            self.spectral = SphericalConv(channels, channels, nlat, nlon)
        else:
            self.spectral = SpectralConv2d(
                channels, channels,
                modes1=min(32, nlat // 2),
                modes2=min(32, nlon // 2),
            )

        # Pointwise MLP implemented with 1x1 convolutions.
        hidden_dim = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, channels, 1),
            nn.Dropout(dropout),
        )

        # GroupNorm requires the channel count to be divisible by the group
        # count; min(32, channels) alone fails for widths such as 48.
        num_groups = self._num_groups(channels)
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.norm2 = nn.GroupNorm(num_groups, channels)

    @staticmethod
    def _num_groups(channels: int, max_groups: int = 32) -> int:
        """Return the largest group count <= max_groups that divides channels."""
        groups = min(max_groups, channels)
        while groups > 1 and channels % groups:
            groups -= 1
        return groups

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with residual connections.

        Args:
            x: Input tensor with ``channels`` channels in dimension 1

        Returns:
            Tensor of the same shape as ``x``
        """
        # Spectral branch.
        x = x + self.spectral(self.norm1(x))

        # Pointwise MLP branch.
        x = x + self.mlp(self.norm2(x))

        return x
|
|
|
|
|
|
|
|
class SphericalFourierNeuralOperator(nn.Module):
    """
    Full Spherical Fourier Neural Operator.

    A 1x1 input projection, a stack of ``SFNOBlock`` layers and a 1x1 output
    projection, operating on gridded (latitude x longitude) fields.

    Based on:
    - FourCastNet (NVIDIA)
    - Spherical Fourier Neural Operators (Bonev et al.)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        output_dim: int = 256,
        num_layers: int = 4,
        nlat: int = 64,
        nlon: int = 128,
        mlp_ratio: float = 2.0,
        dropout: float = 0.1,
        use_spherical: bool = True,
    ):
        """
        Initialize SFNO.

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden dimension
            output_dim: Output dimension
            num_layers: Number of SFNO blocks
            nlat: Number of latitude points in grid
            nlon: Number of longitude points in grid
            mlp_ratio: MLP expansion ratio
            dropout: Dropout probability
            use_spherical: Use spherical harmonics
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.nlat = nlat
        self.nlon = nlon

        # GroupNorm requires hidden_dim to be divisible by the group count;
        # min(32, hidden_dim) alone fails for widths such as 48.
        num_groups = self._num_groups(hidden_dim)

        # Input projection.
        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 1)
        self.input_norm = nn.GroupNorm(num_groups, hidden_dim)

        # SFNO blocks.
        self.blocks = nn.ModuleList([
            SFNOBlock(
                channels=hidden_dim,
                nlat=nlat,
                nlon=nlon,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                use_spherical=use_spherical,
            )
            for _ in range(num_layers)
        ])

        # Output projection.
        self.output_norm = nn.GroupNorm(num_groups, hidden_dim)
        self.output_proj = nn.Conv2d(hidden_dim, output_dim, 1)

        self.gradient_checkpointing = False

    @staticmethod
    def _num_groups(channels: int, max_groups: int = 32) -> int:
        """Return the largest group count <= max_groups that divides channels."""
        groups = min(max_groups, channels)
        while groups > 1 and channels % groups:
            groups -= 1
        return groups

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency."""
        self.gradient_checkpointing = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process gridded atmospheric state.

        Args:
            x: Input tensor of shape (batch, input_dim, nlat, nlon)

        Returns:
            Output tensor of shape (batch, output_dim, nlat, nlon)
        """
        h = self.input_proj(x)
        h = self.input_norm(h)

        use_checkpoint = self.gradient_checkpointing and self.training
        if use_checkpoint:
            # Imported explicitly: ``import torch`` alone is not guaranteed
            # to expose the ``torch.utils.checkpoint`` submodule.
            from torch.utils.checkpoint import checkpoint

        for block in self.blocks:
            if use_checkpoint:
                h = checkpoint(block, h, use_reentrant=False)
            else:
                h = block(h)

        h = self.output_norm(h)
        h = self.output_proj(h)

        return h

    def forward_multiscale(
        self,
        x: torch.Tensor,
        scales: Tuple[int, ...] = (1, 2, 4),
    ) -> torch.Tensor:
        """
        Multi-scale processing for capturing different spatial patterns.

        For every factor in ``scales`` the input is average-pooled by that
        factor (factors <= 1 use the input as is), passed through ``forward``,
        bilinearly resized back to the input resolution, and the per-scale
        outputs are averaged.

        NOTE(review): the blocks are constructed for a fixed (nlat, nlon)
        grid; when the spherical-harmonic path is active, pooled inputs may
        not match the transform's grid. Confirm before using scales > 1 there.

        Args:
            x: Input tensor of shape (batch, input_dim, height, width)
            scales: Downsampling factors to use; must not be empty

        Returns:
            Mean of the per-scale outputs, at the spatial size of ``x``

        Raises:
            ValueError: If ``scales`` is empty
        """
        if not scales:
            raise ValueError("scales must contain at least one downsampling factor")

        # Resize to the actual input resolution so every scale (including the
        # un-pooled one) has the same shape when stacked, even if the input
        # grid differs from (self.nlat, self.nlon).
        target_size = tuple(x.shape[-2:])

        outputs = []
        for scale in scales:
            if scale > 1:
                # Downsample, process, then upsample back.
                x_scaled = F.avg_pool2d(x, scale)
                h = self.forward(x_scaled)
                h = F.interpolate(
                    h, size=target_size, mode="bilinear", align_corners=False
                )
            else:
                h = self.forward(x)

            outputs.append(h)

        # Average across scales.
        return torch.stack(outputs).mean(dim=0)
|
|
|