File size: 12,147 Bytes
8bcb60f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 |
"""
Climate Embedding Module for LILITH.
Encodes climate indices and large-scale patterns that influence
long-range weather predictability:
- ENSO (El Niño–Southern Oscillation)
- MJO (Madden-Julian Oscillation)
- NAO (North Atlantic Oscillation)
- Seasonal cycles
- Solar position
"""
import math
from typing import Optional, Dict
import torch
import torch.nn as nn
class SeasonalEmbedding(nn.Module):
    """
    Encodes seasonal and cyclical time features.

    Builds sinusoidal harmonics of the day of year and (optionally) the hour
    of day, appends an approximate solar declination and equation of time,
    then projects the stacked features to ``d_model`` and layer-normalizes.
    """

    def __init__(self, d_model: int, max_harmonics: int = 4):
        """
        Initialize seasonal embedding.

        Args:
            d_model: Embedding dimension
            max_harmonics: Number of harmonic frequencies
        """
        super().__init__()
        self.d_model = d_model
        self.max_harmonics = max_harmonics

        # Raw feature layout (same order as built in forward()):
        #   - day_of_year sin/cos per harmonic -> 2 * max_harmonics
        #   - hour_of_day sin/cos per harmonic -> 2 * max_harmonics
        #     (zero-filled when no hour is given)
        #   - solar declination                -> 1
        #   - equation of time                 -> 1
        n_features = 4 * max_harmonics + 2
        self.proj = nn.Linear(n_features, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
        year: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode time features.

        Args:
            day_of_year: Day of year (1-366), shape (batch, seq_len)
            hour: Hour of day (0-23), shape (batch, seq_len) - optional
            year: Year, shape (batch, seq_len) - optional. Currently unused;
                accepted only to keep the call signature stable.

        Returns:
            Seasonal embedding of shape (batch, seq_len, d_model), in the
            dtype of the projection layer's parameters.
        """
        harmonics = range(1, self.max_harmonics + 1)
        features = []

        # Day of year harmonics
        for k in harmonics:
            angle = 2 * math.pi * k * day_of_year / 365.25
            features.extend([torch.sin(angle), torch.cos(angle)])

        # Hour of day harmonics; all-zero block marks "hour not provided"
        if hour is not None:
            for k in harmonics:
                angle = 2 * math.pi * k * hour / 24.0
                features.extend([torch.sin(angle), torch.cos(angle)])
        else:
            features.extend(
                [torch.zeros_like(day_of_year)] * (2 * self.max_harmonics)
            )

        # Both solar terms share the same annual phase angle (day 81 is the
        # zero crossing of the declination approximation).
        B = 2 * math.pi * (day_of_year - 81) / 365.25

        # Solar declination (approximate), already normalized:
        # 23.45 * sin(B) / 23.45 == sin(B) up to rounding.
        declination = 23.45 * torch.sin(B)
        features.append(declination / 23.45)

        # Equation of time (minutes, approximate), scaled by ~15 minutes
        eot = 9.87 * torch.sin(2 * B) - 7.53 * torch.cos(B) - 1.5 * torch.sin(B)
        features.append(eot / 15.0)

        # Cast every feature to the projection's parameter dtype before
        # stacking: integer or float64 inputs (e.g. tensors created from
        # NumPy arrays) would otherwise hit a dtype mismatch in nn.Linear.
        target_dtype = self.proj.weight.dtype
        encoding = torch.stack(
            [feature.to(target_dtype) for feature in features], dim=-1
        )

        # Project to model dimension
        return self.norm(self.proj(encoding))
class ClimateIndexEmbedding(nn.Module):
    """
    Embeds climate indices (ENSO, NAO, MJO, etc.).

    The index vector is concatenated with a sin/cos encoding of the MJO
    phase (or two zeros when no phase is given) and passed through a small
    MLP followed by layer normalization.
    """

    def __init__(
        self,
        d_model: int,
        n_indices: int = 10,
        hidden_dim: int = 64,
    ):
        """
        Initialize climate index embedding.

        Args:
            d_model: Output embedding dimension
            n_indices: Maximum number of climate indices
            hidden_dim: Hidden dimension
        """
        super().__init__()
        self.d_model = d_model
        self.n_indices = n_indices

        # Reference list of index names; not read by forward().
        self.index_names = [
            "nino34",     # ENSO: Nino 3.4 SST anomaly
            "nino12",     # ENSO: Nino 1+2 region
            "soi",        # Southern Oscillation Index
            "mjo_amp",    # MJO amplitude
            "mjo_phase",  # MJO phase (1-8, encoded as sin/cos)
            "nao",        # North Atlantic Oscillation
            "ao",         # Arctic Oscillation
            "pdo",        # Pacific Decadal Oscillation
            "amo",        # Atlantic Multidecadal Oscillation
            "qbo",        # Quasi-Biennial Oscillation
        ]

        # Embedding network
        self.mlp = nn.Sequential(
            nn.Linear(n_indices + 2, hidden_dim),  # +2 for MJO phase sin/cos
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        indices: torch.Tensor,
        mjo_phase: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Embed climate indices.

        Args:
            indices: Climate index values of shape (..., n_indices),
                typically (batch, seq_len, n_indices).
                Values should be normalized/standardized
            mjo_phase: MJO phase (1-8) of shape ``indices.shape[:-1]``,
                typically (batch, seq_len) - optional

        Returns:
            Embedding of shape (..., d_model), in the dtype of the MLP
            parameters.
        """
        # Work in the MLP's parameter dtype so float64 / reduced-precision
        # inputs do not trigger a dtype mismatch in the first nn.Linear.
        target_dtype = self.mlp[0].weight.dtype

        if mjo_phase is not None:
            # The phase is cyclical (8 wraps to 0), so encode it on the
            # unit circle.
            phase_rad = 2 * math.pi * mjo_phase / 8.0
            mjo_features = torch.stack(
                [torch.sin(phase_rad), torch.cos(phase_rad)], dim=-1
            ).to(target_dtype)
        else:
            # Zero-fill the two MJO slots; derive the shape from the leading
            # dims of `indices` so any rank (not only 3-D) is accepted.
            mjo_features = indices.new_zeros(
                (*indices.shape[:-1], 2), dtype=target_dtype
            )

        x = torch.cat([indices.to(target_dtype), mjo_features], dim=-1)
        return self.mlp(x)
class SolarPositionEmbedding(nn.Module):
    """
    Encodes solar position for each location and time.

    Computes solar altitude, azimuth, day length and declination from
    latitude, longitude, day of year and (optionally) hour, then projects
    the six resulting features to ``d_model`` and layer-normalizes.
    """

    def __init__(self, d_model: int):
        """
        Initialize solar position embedding.

        Args:
            d_model: Output embedding dimension
        """
        super().__init__()
        # Features: sin/cos of solar altitude, sin/cos of solar azimuth,
        # day length (fraction of 24 h), declination (fraction of 23.45 deg)
        n_features = 6
        self.proj = nn.Linear(n_features, d_model)
        self.norm = nn.LayerNorm(d_model)

    @staticmethod
    def _pad_rank(value: torch.Tensor, rank: int) -> torch.Tensor:
        """Append trailing singleton dims until ``value`` has ``rank`` dims."""
        while value.dim() < rank:
            value = value.unsqueeze(-1)
        return value

    def forward(
        self,
        lat: torch.Tensor,
        lon: torch.Tensor,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute solar position features.

        Time inputs with fewer dims than the location tensors are treated as
        per-batch values: trailing singleton dims are appended so they
        broadcast against the location/sequence axis.

        Args:
            lat: Latitude in degrees, shape (batch, n_locations) or (batch,)
            lon: Longitude in degrees, shape (batch, n_locations) or (batch,)
            day_of_year: Day of year, shape (batch,) or (batch, seq_len)
            hour: Hour of day (0-24), shape (batch,) or (batch, seq_len)

        Returns:
            Solar position embedding of shape (*broadcast_shape, d_model),
            where broadcast_shape is the broadcast of all inputs after 1-D
            lat/lon are expanded to (batch, 1).
        """
        # Ensure proper broadcasting
        if lat.dim() == 1:
            lat = lat.unsqueeze(-1)
        if lon.dim() == 1:
            lon = lon.unsqueeze(-1)
        rank = max(lat.dim(), lon.dim())

        # as_tensor (rather than torch.tensor) avoids copying an existing
        # tensor, keeps it attached to the autograd graph, and still accepts
        # plain Python numbers.
        day_of_year = self._pad_rank(
            torch.as_tensor(day_of_year, device=lat.device), rank
        )

        # Solar declination (degrees); day 81 is the zero crossing
        declination = 23.45 * torch.sin(
            torch.deg2rad(360 / 365.25 * (day_of_year - 81))
        )

        # Convert to radians
        lat_rad = torch.deg2rad(lat)
        dec_rad = torch.deg2rad(declination)

        # Hour angle (if hour provided)
        if hour is not None:
            hour = self._pad_rank(
                torch.as_tensor(hour, device=lon.device), rank
            )
            # Solar noon at longitude 0 is at 12:00 UTC
            # Each 15 degrees of longitude = 1 hour offset
            solar_time = hour + lon / 15.0
            hour_angle = torch.deg2rad((solar_time - 12.0) * 15.0)
        else:
            # Without an hour the hour angle is zero (cos(hour_angle) == 1)
            hour_angle = torch.zeros_like(lat_rad)

        # Solar altitude (elevation angle)
        sin_alt = (
            torch.sin(lat_rad) * torch.sin(dec_rad)
            + torch.cos(lat_rad) * torch.cos(dec_rad) * torch.cos(hour_angle)
        )
        solar_altitude = torch.arcsin(torch.clamp(sin_alt, -1, 1))

        # Solar azimuth; the epsilon guards the division when the
        # denominator vanishes.
        # NOTE(review): arccos only returns [0, pi], so this value does not
        # distinguish positive from negative hour angles.
        cos_azimuth = (
            torch.sin(dec_rad) - torch.sin(lat_rad) * sin_alt
        ) / (torch.cos(lat_rad) * torch.cos(solar_altitude) + 1e-8)
        solar_azimuth = torch.arccos(torch.clamp(cos_azimuth, -1, 1))

        # Day length (hours); clamping maps out-of-range values to 0 h / 24 h
        cos_hour_angle = -torch.tan(lat_rad) * torch.tan(dec_rad)
        cos_hour_angle = torch.clamp(cos_hour_angle, -1, 1)
        day_length = 2 * torch.arccos(cos_hour_angle) / math.pi * 12.0

        # Normalize features and bring them to one common shape before
        # stacking: some depend only on the date, others also on location
        # and hour, so their shapes can differ.
        components = torch.broadcast_tensors(
            torch.sin(solar_altitude),
            torch.cos(solar_altitude),
            torch.sin(solar_azimuth),
            torch.cos(solar_azimuth),
            day_length / 24.0,
            declination / 23.45,
        )
        features = torch.stack(components, dim=-1)

        # Match the projection's parameter dtype (e.g. float64 coordinates)
        features = features.to(self.proj.weight.dtype)
        return self.norm(self.proj(features))
class ClimateEmbedding(nn.Module):
    """
    Combined climate embedding module.

    Integrates:
    1. Seasonal/cyclical time features
    2. Climate indices (ENSO, MJO, NAO, etc.)
    3. Solar position

    The enabled component embeddings are concatenated along the last axis
    and fused back to ``d_model`` by a two-layer MLP with layer norm.
    """

    def __init__(
        self,
        d_model: int,
        use_climate_indices: bool = True,
        use_solar_position: bool = True,
        max_harmonics: int = 4,
    ):
        """
        Initialize climate embedding.

        Args:
            d_model: Output dimension
            use_climate_indices: Include climate index embedding
            use_solar_position: Include solar position embedding
            max_harmonics: Number of harmonics for seasonal encoding
        """
        super().__init__()
        self.d_model = d_model
        self.use_climate_indices = use_climate_indices
        self.use_solar_position = use_solar_position

        # Component embeddings (disabled components are stored as None)
        self.seasonal = SeasonalEmbedding(d_model, max_harmonics)
        self.climate_indices = (
            ClimateIndexEmbedding(d_model) if use_climate_indices else None
        )
        self.solar = SolarPositionEmbedding(d_model) if use_solar_position else None

        # Fusion layer: input width is fixed by the number of enabled
        # components, so forward() must always supply that many tensors.
        n_components = 1 + int(use_climate_indices) + int(use_solar_position)
        self.fusion = nn.Sequential(
            nn.Linear(d_model * n_components, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
        lat: Optional[torch.Tensor] = None,
        lon: Optional[torch.Tensor] = None,
        climate_indices: Optional[torch.Tensor] = None,
        mjo_phase: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute climate embedding.

        An enabled component whose inputs are missing contributes an
        all-zero tensor shaped like the seasonal embedding, so the fusion
        layer always receives its expected input width.

        Args:
            day_of_year: Day of year (1-366)
            hour: Hour of day (0-23) - optional
            lat: Latitude in degrees - for solar position
            lon: Longitude in degrees - for solar position
            climate_indices: Climate index values - optional
            mjo_phase: MJO phase (1-8) - optional

        Returns:
            Combined climate embedding
        """
        # Seasonal embedding (always included)
        seasonal_emb = self.seasonal(day_of_year, hour)
        embeddings = [seasonal_emb]

        # Climate indices
        if self.climate_indices is not None:
            if climate_indices is not None:
                embeddings.append(self.climate_indices(climate_indices, mjo_phase))
            else:
                # zeros_like keeps dtype and device aligned with the other
                # components; a bare torch.zeros would always be float32.
                embeddings.append(torch.zeros_like(seasonal_emb))

        # Solar position (needs both latitude and longitude)
        if self.solar is not None:
            if lat is not None and lon is not None:
                embeddings.append(self.solar(lat, lon, day_of_year, hour))
            else:
                embeddings.append(torch.zeros_like(seasonal_emb))

        # Fuse embeddings
        combined = torch.cat(embeddings, dim=-1)
        return self.fusion(combined)
|