# Provenance: uploaded by consigcody94 ("Upload source code and documentation", commit 8bcb60f, verified).
"""
Spherical Fourier Neural Operator (SFNO) for LILITH.
Processes atmospheric state on a spherical domain using spectral methods.
Based on NVIDIA's FourCastNet architecture.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Try to import torch_harmonics for spherical harmonics
try:
import torch_harmonics as th
TORCH_HARMONICS_AVAILABLE = True
except ImportError:
TORCH_HARMONICS_AVAILABLE = False
class SpectralConv2d(nn.Module):
    """
    FFT-based 2D spectral convolution layer.

    Mixes channels in the Fourier domain, which gives every output location a
    global receptive field at O(N log N) cost. Only the lowest `modes1 x modes2`
    frequencies (and their mirrored high-frequency band along the first axis)
    carry learnable weights; all other modes are zeroed.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        modes1: int = 32,
        modes2: int = 32,
    ):
        """
        Initialize spectral convolution.

        Args:
            in_channels: Input channels
            out_channels: Output channels
            modes1: Retained Fourier modes along the first spatial dimension
            modes2: Retained Fourier modes along the second spatial dimension
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2
        # Complex-valued learnable filters, one per retained frequency band.
        scale = 1 / (in_channels * out_channels)
        shape = (in_channels, out_channels, modes1, modes2)
        self.weights1 = nn.Parameter(scale * torch.randn(*shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(scale * torch.randn(*shape, dtype=torch.cfloat))

    def compl_mul2d(
        self,
        input: torch.Tensor,
        weights: torch.Tensor,
    ) -> torch.Tensor:
        """Batched complex contraction over the input-channel axis."""
        # (batch, in_ch, x, y) x (in_ch, out_ch, x, y) -> (batch, out_ch, x, y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, channels, height, width)
        Returns:
            Output tensor of shape (batch, out_channels, height, width)
        """
        n_batch = x.size(0)
        n_h, n_w = x.shape[-2:]
        # Real FFT over the two trailing (spatial) dimensions.
        spec = torch.fft.rfft2(x)
        # Output spectrum: zeros everywhere except the weighted bands below.
        out_spec = torch.zeros(
            n_batch,
            self.out_channels,
            n_h,
            n_w // 2 + 1,
            dtype=torch.cfloat,
            device=x.device,
        )
        m1, m2 = self.modes1, self.modes2
        # Low-frequency band along the first axis.
        out_spec[:, :, :m1, :m2] = self.compl_mul2d(spec[:, :, :m1, :m2], self.weights1)
        # Mirrored high-frequency band (negative frequencies of the first axis).
        out_spec[:, :, -m1:, :m2] = self.compl_mul2d(spec[:, :, -m1:, :m2], self.weights2)
        # Back to the spatial domain at the original resolution.
        return torch.fft.irfft2(out_spec, s=(n_h, n_w))
class SphericalConv(nn.Module):
    """
    Spherical convolution using spherical harmonics.

    Properly handles the geometry of the sphere, avoiding polar distortion
    that occurs with standard 2D convolutions on lat-lon grids. Falls back
    to a plain FFT-based SpectralConv2d when torch_harmonics is missing.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        nlat: int = 721,
        nlon: int = 1440,
        lmax: Optional[int] = None,
    ):
        """
        Initialize spherical convolution.

        Args:
            in_channels: Input channels
            out_channels: Output channels
            nlat: Number of latitude points
            nlon: Number of longitude points
            lmax: Maximum spherical harmonic degree (defaults to nlat // 2)
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nlat = nlat
        self.nlon = nlon
        # NOTE: `or` also replaces an explicit lmax=0 with nlat // 2.
        self.lmax = lmax or nlat // 2
        if TORCH_HARMONICS_AVAILABLE:
            # Use torch_harmonics for proper spherical harmonics
            self.sht = th.RealSHT(nlat, nlon, lmax=self.lmax)
            self.isht = th.InverseRealSHT(nlat, nlon, lmax=self.lmax)
            # Learnable spectral weights, one per (in, out, coefficient) triple.
            # n_coeffs is the triangular count of (l, m) pairs with m <= l.
            # NOTE(review): forward() contracts these weights against the
            # *last* axis of the SHT output; confirm that axis actually has
            # length n_coeffs (rather than an (lmax, mmax) grid) for the
            # torch_harmonics version in use.
            n_coeffs = (self.lmax + 1) * (self.lmax + 2) // 2
            self.spectral_weights = nn.Parameter(
                torch.randn(in_channels, out_channels, n_coeffs) / math.sqrt(in_channels)
            )
        else:
            # Fallback to FFT-based convolution
            self.spectral_conv = SpectralConv2d(
                in_channels, out_channels,
                modes1=min(32, nlat // 2),
                modes2=min(32, nlon // 2),
            )
            # Sentinel checked by forward() to select the fallback path.
            self.spectral_weights = None
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (batch, in_channels, nlat, nlon)
        Returns:
            Output tensor of shape (batch, out_channels, nlat, nlon)
        """
        if TORCH_HARMONICS_AVAILABLE and self.spectral_weights is not None:
            batch_size = x.size(0)  # NOTE(review): unused; kept as-is.
            # Transform to spectral domain
            x_spec = self.sht(x)
            # Apply learnable weights per spectral coefficient (last axis)
            out_spec = torch.einsum("bixk,iok->boxk", x_spec, self.spectral_weights)
            # Transform back
            x = self.isht(out_spec)
        else:
            x = self.spectral_conv(x)
        return x
class SFNOBlock(nn.Module):
    """
    Single block of the Spherical Fourier Neural Operator.

    Pre-norm residual structure: a spectral (global) mixing path followed by a
    pointwise MLP path, each wrapped in a residual connection.
    """
    def __init__(
        self,
        channels: int,
        nlat: int = 64,
        nlon: int = 128,
        mlp_ratio: float = 2.0,
        dropout: float = 0.1,
        use_spherical: bool = True,
    ):
        """
        Initialize SFNO block.

        Args:
            channels: Number of channels
            nlat: Number of latitude points
            nlon: Number of longitude points
            mlp_ratio: MLP hidden dimension ratio
            dropout: Dropout probability
            use_spherical: Use spherical harmonics (if available)
        """
        super().__init__()
        self.channels = channels
        self.use_spherical = use_spherical and TORCH_HARMONICS_AVAILABLE
        # Spectral convolution: proper spherical harmonics when available,
        # otherwise a plain FFT-based spectral convolution.
        if self.use_spherical:
            self.spectral = SphericalConv(channels, channels, nlat, nlon)
        else:
            self.spectral = SpectralConv2d(
                channels, channels,
                modes1=min(32, nlat // 2),
                modes2=min(32, nlon // 2),
            )
        # Pointwise (1x1 conv) MLP
        hidden_dim = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, channels, 1),
            nn.Dropout(dropout),
        )
        # Normalization.  GroupNorm requires num_groups to evenly divide
        # channels; the previous min(32, channels) raised for e.g. channels=48.
        # Keep the original group count whenever it was valid, otherwise fall
        # back to gcd(32, channels) (which reduces to 1 group for odd counts).
        num_groups = min(32, channels)
        if channels % num_groups:
            num_groups = math.gcd(32, channels)
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.norm2 = nn.GroupNorm(num_groups, channels)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with residual connections.

        Args:
            x: Input tensor of shape (batch, channels, nlat, nlon)
        Returns:
            Output tensor of same shape
        """
        # Global spectral mixing path
        h = self.norm1(x)
        h = self.spectral(h)
        x = x + h
        # Pointwise MLP path
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + h
        return x
class SphericalFourierNeuralOperator(nn.Module):
    """
    Full Spherical Fourier Neural Operator.

    A neural operator that learns atmospheric dynamics on a spherical domain
    using spectral methods for efficient global communication.
    Based on:
    - FourCastNet (NVIDIA)
    - Spherical Fourier Neural Operators (Bonev et al.)
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        output_dim: int = 256,
        num_layers: int = 4,
        nlat: int = 64,
        nlon: int = 128,
        mlp_ratio: float = 2.0,
        dropout: float = 0.1,
        use_spherical: bool = True,
    ):
        """
        Initialize SFNO.

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden dimension
            output_dim: Output dimension
            num_layers: Number of SFNO blocks
            nlat: Number of latitude points in grid
            nlon: Number of longitude points in grid
            mlp_ratio: MLP expansion ratio
            dropout: Dropout probability
            use_spherical: Use spherical harmonics
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.nlat = nlat
        self.nlon = nlon
        # GroupNorm requires num_groups to evenly divide hidden_dim; the
        # previous min(32, hidden_dim) raised for e.g. hidden_dim=48.  Keep
        # the original group count whenever it was valid, otherwise fall back
        # to gcd(32, hidden_dim).
        num_groups = min(32, hidden_dim)
        if hidden_dim % num_groups:
            num_groups = math.gcd(32, hidden_dim)
        # Input projection
        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 1)
        self.input_norm = nn.GroupNorm(num_groups, hidden_dim)
        # SFNO blocks
        self.blocks = nn.ModuleList([
            SFNOBlock(
                channels=hidden_dim,
                nlat=nlat,
                nlon=nlon,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                use_spherical=use_spherical,
            )
            for _ in range(num_layers)
        ])
        # Output projection
        self.output_norm = nn.GroupNorm(num_groups, hidden_dim)
        self.output_proj = nn.Conv2d(hidden_dim, output_dim, 1)
        # Gradient checkpointing (off by default; see enable_gradient_checkpointing)
        self.gradient_checkpointing = False
    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency."""
        self.gradient_checkpointing = True
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process gridded atmospheric state.

        Args:
            x: Input tensor of shape (batch, input_dim, nlat, nlon)
        Returns:
            Output tensor of shape (batch, output_dim, nlat, nlon)
        """
        # Explicit import: torch.utils.checkpoint is a submodule that is not
        # guaranteed to be bound by a bare `import torch` on all versions.
        from torch.utils.checkpoint import checkpoint
        # Input projection
        h = self.input_proj(x)
        h = self.input_norm(h)
        # Apply SFNO blocks (recompute activations in backward when enabled)
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                h = checkpoint(block, h, use_reentrant=False)
            else:
                h = block(h)
        # Output projection
        h = self.output_norm(h)
        h = self.output_proj(h)
        return h
    def forward_multiscale(
        self,
        x: torch.Tensor,
        scales: Tuple[int, ...] = (1, 2, 4),
    ) -> torch.Tensor:
        """
        Multi-scale processing for capturing different spatial patterns.

        Args:
            x: Input tensor of shape (batch, input_dim, H, W)
            scales: Downsampling factors to use
        Returns:
            Combined multi-scale output at the input resolution (H, W)
        """
        # Upsample back to the *input's* spatial size rather than the
        # configured (self.nlat, self.nlon): the scale == 1 branch keeps the
        # input resolution, so all branches must match it for torch.stack.
        # (Identical behavior when the input is on the configured grid.)
        target_size = x.shape[-2:]
        outputs = []
        for scale in scales:
            if scale > 1:
                # Downsample
                x_scaled = F.avg_pool2d(x, scale)
                # Process
                h = self.forward(x_scaled)
                # Upsample back; align_corners=False matches the prior default.
                h = F.interpolate(h, size=target_size, mode="bilinear", align_corners=False)
            else:
                h = self.forward(x)
            outputs.append(h)
        # Combine scales (simple average, could be learned)
        return torch.stack(outputs).mean(dim=0)