|
|
""" |
|
|
Climate Embedding Module for LILITH. |
|
|
|
|
|
Encodes climate indices and large-scale patterns that influence |
|
|
long-range weather predictability: |
|
|
- ENSO (El Nino Southern Oscillation) |
|
|
- MJO (Madden-Julian Oscillation) |
|
|
- NAO (North Atlantic Oscillation) |
|
|
- Seasonal cycles |
|
|
- Solar position |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Dict |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class SeasonalEmbedding(nn.Module):
    """
    Encodes seasonal and cyclical time features.

    Builds sine/cosine harmonics of the annual cycle (period 365.25 days)
    and, when an hour is supplied, of the daily cycle (period 24 hours),
    appends two solar terms (normalized declination and a scaled
    equation-of-time approximation), then projects the stacked features to
    ``d_model`` dimensions and applies layer normalization.
    """

    def __init__(self, d_model: int, max_harmonics: int = 4):
        """
        Initialize seasonal embedding.

        Args:
            d_model: Embedding dimension
            max_harmonics: Number of harmonic frequencies
        """
        super().__init__()

        self.d_model = d_model
        self.max_harmonics = max_harmonics

        # Feature layout produced by forward():
        #   2 * max_harmonics  annual sin/cos pairs
        #   2 * max_harmonics  daily sin/cos pairs (zeros when hour is None)
        #   2                  solar terms (declination, equation of time)
        n_features = 4 * max_harmonics + 2

        self.proj = nn.Linear(n_features, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
        year: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode time features.

        Args:
            day_of_year: Day of year (1-366), shape (batch, seq_len).
                Integer and floating point tensors are both accepted.
            hour: Hour of day (0-23), same shape as ``day_of_year`` - optional
            year: Year, shape (batch, seq_len) - optional. Accepted for
                interface compatibility; it is not used by the encoding.

        Returns:
            Seasonal embedding of shape (batch, seq_len, d_model), in the
            dtype of the projection layer.
        """
        # Do the trigonometry in floating point even for integer inputs so
        # that every feature (including the zero placeholders below) shares
        # one dtype.
        if torch.is_floating_point(day_of_year):
            day = day_of_year
        else:
            day = day_of_year.float()

        features = []

        # Annual cycle harmonics.
        for k in range(1, self.max_harmonics + 1):
            angle = 2 * math.pi * k * day / 365.25
            features.extend([torch.sin(angle), torch.cos(angle)])

        # Daily cycle harmonics, or zero placeholders so the feature count
        # stays fixed at 4 * max_harmonics + 2.
        if hour is not None:
            hour_f = hour.to(day.dtype)
            for k in range(1, self.max_harmonics + 1):
                angle = 2 * math.pi * k * hour_f / 24.0
                features.extend([torch.sin(angle), torch.cos(angle)])
        else:
            placeholder = torch.zeros_like(day)
            features.extend([placeholder] * (2 * self.max_harmonics))

        # Both solar terms share the same angle, offset by day 81.
        B = 2 * math.pi * (day - 81) / 365.25

        # Declination is 23.45 * sin(B) degrees; dividing by 23.45 maps it
        # to [-1, 1].
        declination = 23.45 * torch.sin(B)
        features.append(declination / 23.45)

        # Equation-of-time approximation (presumably in minutes), scaled
        # by 1/15 to keep it roughly within [-1, 1].
        eot = 9.87 * torch.sin(2 * B) - 7.53 * torch.cos(B) - 1.5 * torch.sin(B)
        features.append(eot / 15.0)

        # Stack along a new trailing feature axis and match the projection
        # dtype so float64 inputs or a half-precision module do not hit a
        # dtype mismatch inside nn.Linear.
        encoding = torch.stack(features, dim=-1).to(self.proj.weight.dtype)

        return self.norm(self.proj(encoding))
|
|
|
|
|
|
|
|
class ClimateIndexEmbedding(nn.Module):
    """
    Embeds climate indices (ENSO, NAO, MJO, etc.).

    These indices capture large-scale climate patterns that influence
    weather predictability on subseasonal to seasonal timescales.

    The index vector is extended with a two-element cyclical encoding of
    the MJO phase (sin, cos) and passed through a small MLP followed by
    layer normalization.
    """

    def __init__(
        self,
        d_model: int,
        n_indices: int = 10,
        hidden_dim: int = 64,
    ):
        """
        Initialize climate index embedding.

        Args:
            d_model: Output embedding dimension
            n_indices: Maximum number of climate indices
            hidden_dim: Hidden dimension
        """
        super().__init__()

        self.d_model = d_model
        self.n_indices = n_indices

        # Names of the supported indices. Presumably this is the expected
        # column order of `indices`; nothing in this class enforces it.
        self.index_names = [
            "nino34",
            "nino12",
            "soi",
            "mjo_amp",
            "mjo_phase",
            "nao",
            "ao",
            "pdo",
            "amo",
            "qbo",
        ]

        # "+ 2" accounts for the (sin, cos) MJO phase features appended in
        # forward().
        self.mlp = nn.Sequential(
            nn.Linear(n_indices + 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        indices: torch.Tensor,
        mjo_phase: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Embed climate indices.

        Args:
            indices: Climate index values of shape (..., n_indices),
                typically (batch, seq_len, n_indices).
                Values should be normalized/standardized
            mjo_phase: MJO phase (1-8) with the leading shape of
                ``indices``, typically (batch, seq_len) - optional

        Returns:
            Embedding of shape (..., d_model)
        """
        if mjo_phase is not None:
            # Map the eight phases onto the unit circle so phase 8 and
            # phase 1 end up adjacent. Cast to the dtype of `indices` so the
            # concatenation below does not silently change precision.
            phase_rad = 2 * math.pi * mjo_phase.to(indices.dtype) / 8.0
            phase_features = torch.stack(
                [torch.sin(phase_rad), torch.cos(phase_rad)], dim=-1
            )
        else:
            # No phase available: pad with zeros that share the dtype,
            # device and leading shape of `indices`.
            phase_features = indices.new_zeros((*indices.shape[:-1], 2))

        x = torch.cat([indices, phase_features], dim=-1)

        return self.mlp(x)
|
|
|
|
|
|
|
|
class SolarPositionEmbedding(nn.Module):
    """
    Encodes solar position for each location and time.

    Critical for capturing diurnal cycles and their geographic variation.

    Six features are derived per location/time (sin and cos of solar
    altitude, sin and cos of solar azimuth, normalized day length and
    normalized declination), projected to ``d_model`` dimensions and
    layer-normalized.
    """

    def __init__(self, d_model: int):
        """
        Initialize solar position embedding.

        Args:
            d_model: Output embedding dimension
        """
        super().__init__()

        # Must match the number of features stacked in forward().
        n_features = 6
        self.proj = nn.Linear(n_features, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        lat: torch.Tensor,
        lon: torch.Tensor,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute solar position features.

        1-D ``lat``/``lon`` get a trailing axis of size 1. ``day_of_year``
        and ``hour`` get a trailing axis when they have fewer dimensions
        than ``lat``. All features are then broadcast to a common shape, so
        the inputs must be broadcast-compatible after that alignment.

        Args:
            lat: Latitude in degrees, shape (batch, n_locations) or (batch,)
            lon: Longitude in degrees, shape (batch, n_locations) or (batch,)
            day_of_year: Day of year, shape (batch,) or (batch, seq_len)
            hour: Hour of day (0-24), shape (batch,) or (batch, seq_len).
                When omitted the hour angle is taken as zero.

        Returns:
            Solar position embedding of shape (*broadcast_shape, d_model),
            e.g. (batch, 1, d_model) for all-(batch,) inputs or
            (batch, seq_len, d_model) for (batch,) coordinates combined
            with (batch, seq_len) times.
        """
        if lat.dim() == 1:
            lat = lat.unsqueeze(-1)
        if lon.dim() == 1:
            lon = lon.unsqueeze(-1)

        # as_tensor leaves tensors untouched (no copy, autograd history
        # kept) while still accepting plain Python numbers.
        day = torch.as_tensor(day_of_year, device=lat.device)
        if day.dim() < lat.dim():
            # Align the batch axis with lat/lon: (batch,) -> (batch, 1).
            day = day.unsqueeze(-1)

        # Solar declination in degrees.
        declination = 23.45 * torch.sin(
            torch.deg2rad(360 / 365.25 * (day - 81))
        )

        lat_rad = torch.deg2rad(lat)
        dec_rad = torch.deg2rad(declination)

        if hour is not None:
            hour_t = torch.as_tensor(hour, device=lat.device)
            if hour_t.dim() < lon.dim():
                # Without this, a (batch,) hour would broadcast against
                # the (batch, 1) longitude into a (batch, batch) grid.
                hour_t = hour_t.unsqueeze(-1)

            # Shift the clock by longitude (15 degrees per hour) and convert
            # the offset from 12:00 into an angle.
            solar_time = hour_t + lon / 15.0
            hour_angle = torch.deg2rad((solar_time - 12.0) * 15.0)
        else:
            hour_angle = torch.zeros_like(lat_rad)

        # Solar altitude; the clamp guards arcsin against rounding just
        # outside [-1, 1].
        sin_alt = (
            torch.sin(lat_rad) * torch.sin(dec_rad) +
            torch.cos(lat_rad) * torch.cos(dec_rad) * torch.cos(hour_angle)
        )
        solar_altitude = torch.arcsin(torch.clamp(sin_alt, -1, 1))

        # Solar azimuth; 1e-8 avoids division by zero at the poles or at
        # zenith. NOTE(review): arccos only returns [0, pi], so positive and
        # negative hour angles map to the same azimuth. Left as-is to
        # keep the feature values unchanged.
        cos_azimuth = (
            torch.sin(dec_rad) - torch.sin(lat_rad) * sin_alt
        ) / (torch.cos(lat_rad) * torch.cos(solar_altitude) + 1e-8)
        solar_azimuth = torch.arccos(torch.clamp(cos_azimuth, -1, 1))

        # Day length in hours. Clamping the cosine to [-1, 1] keeps arccos
        # defined, which bounds the result to [0, 24].
        cos_sunset = torch.clamp(-torch.tan(lat_rad) * torch.tan(dec_rad), -1, 1)
        day_length = 2 * torch.arccos(cos_sunset) / math.pi * 12.0

        # Individual features can have different shapes (declination only
        # depends on the day), so broadcast them to a common shape and
        # cast to the projection dtype before stacking.
        target_dtype = self.proj.weight.dtype
        components = torch.broadcast_tensors(
            torch.sin(solar_altitude),
            torch.cos(solar_altitude),
            torch.sin(solar_azimuth),
            torch.cos(solar_azimuth),
            day_length / 24.0,
            declination / 23.45,
        )
        features = torch.stack([c.to(target_dtype) for c in components], dim=-1)

        return self.norm(self.proj(features))
|
|
|
|
|
|
|
|
class ClimateEmbedding(nn.Module):
    """
    Combined climate embedding module.

    Integrates:
    1. Seasonal/cyclical time features
    2. Climate indices (ENSO, MJO, NAO, etc.)
    3. Solar position

    The enabled component embeddings are concatenated along the last axis
    and fused back to ``d_model`` dimensions by a small MLP.
    """

    def __init__(
        self,
        d_model: int,
        use_climate_indices: bool = True,
        use_solar_position: bool = True,
        max_harmonics: int = 4,
    ):
        """
        Initialize climate embedding.

        Args:
            d_model: Output dimension
            use_climate_indices: Include climate index embedding
            use_solar_position: Include solar position embedding
            max_harmonics: Number of harmonics for seasonal encoding
        """
        super().__init__()

        self.d_model = d_model
        self.use_climate_indices = use_climate_indices
        self.use_solar_position = use_solar_position

        # The seasonal component is always present; the other two are
        # optional and stored as None when disabled.
        self.seasonal = SeasonalEmbedding(d_model, max_harmonics)
        self.climate_indices = (
            ClimateIndexEmbedding(d_model) if use_climate_indices else None
        )
        self.solar = SolarPositionEmbedding(d_model) if use_solar_position else None

        # The fusion input width depends on how many components are enabled,
        # so forward() must always emit one d_model-wide slice per enabled
        # component.
        n_components = 1 + int(use_climate_indices) + int(use_solar_position)
        self.fusion = nn.Sequential(
            nn.Linear(d_model * n_components, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
        lat: Optional[torch.Tensor] = None,
        lon: Optional[torch.Tensor] = None,
        climate_indices: Optional[torch.Tensor] = None,
        mjo_phase: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute climate embedding.

        An enabled component whose inputs are missing contributes an
        all-zero slice shaped like the seasonal embedding. The solar
        component needs both ``lat`` and ``lon``; if only one is given it is
        treated as missing.

        Args:
            day_of_year: Day of year (1-366)
            hour: Hour of day (0-23) - optional
            lat: Latitude in degrees - for solar position
            lon: Longitude in degrees - for solar position
            climate_indices: Climate index values - optional
            mjo_phase: MJO phase (1-8) - optional

        Returns:
            Combined climate embedding. Component embeddings must share
            their leading shape for the concatenation to succeed.
        """
        seasonal_emb = self.seasonal(day_of_year, hour)
        embeddings = [seasonal_emb]

        if self.climate_indices is not None:
            if climate_indices is not None:
                embeddings.append(self.climate_indices(climate_indices, mjo_phase))
            else:
                # zeros_like keeps dtype and device aligned with the seasonal
                # embedding, so half/bfloat16 modules are not promoted to
                # float32 by the concatenation.
                embeddings.append(torch.zeros_like(seasonal_emb))

        if self.solar is not None:
            if lat is not None and lon is not None:
                embeddings.append(self.solar(lat, lon, day_of_year, hour))
            else:
                embeddings.append(torch.zeros_like(seasonal_emb))

        combined = torch.cat(embeddings, dim=-1)
        return self.fusion(combined)
|
|
|