File size: 12,147 Bytes
8bcb60f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
"""
Climate Embedding Module for LILITH.

Encodes climate indices and large-scale patterns that influence
long-range weather predictability:
- ENSO (El Nino Southern Oscillation)
- MJO (Madden-Julian Oscillation)
- NAO (North Atlantic Oscillation)
- Seasonal cycles
- Solar position
"""

import math
from typing import Optional, Dict

import torch
import torch.nn as nn


class SeasonalEmbedding(nn.Module):
    """
    Encodes seasonal and cyclical time features.

    Uses sinusoidal encoding to capture annual, monthly, and daily cycles.
    """

    def __init__(self, d_model: int, max_harmonics: int = 4):
        """
        Initialize seasonal embedding.

        Args:
            d_model: Embedding dimension
            max_harmonics: Number of harmonic frequencies
        """
        super().__init__()

        self.d_model = d_model
        self.max_harmonics = max_harmonics

        # Number of raw features:
        # - day_of_year (sin/cos * harmonics)
        # - hour_of_day (sin/cos * harmonics) - if needed
        # - solar declination
        # - equation of time
        n_features = 4 * max_harmonics + 2

        self.proj = nn.Linear(n_features, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
        year: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode time features.

        Args:
            day_of_year: Day of year (1-366), shape (batch, seq_len)
            hour: Hour of day (0-23), shape (batch, seq_len) - optional
            year: Year, shape (batch, seq_len) - optional

        Returns:
            Seasonal embedding of shape (batch, seq_len, d_model)
        """
        features = []

        # Day of year harmonics
        for k in range(1, self.max_harmonics + 1):
            freq = 2 * math.pi * k * day_of_year / 365.25
            features.extend([torch.sin(freq), torch.cos(freq)])

        # Hour of day harmonics (if provided)
        if hour is not None:
            for k in range(1, self.max_harmonics + 1):
                freq = 2 * math.pi * k * hour / 24.0
                features.extend([torch.sin(freq), torch.cos(freq)])
        else:
            # Pad with zeros
            features.extend([torch.zeros_like(day_of_year)] * (2 * self.max_harmonics))

        # Solar declination (approximate)
        # Maximum ~23.45 degrees on summer solstice
        declination = 23.45 * torch.sin(2 * math.pi * (day_of_year - 81) / 365.25)
        features.append(declination / 23.45)  # Normalize

        # Equation of time (minutes, approximate)
        B = 2 * math.pi * (day_of_year - 81) / 365.25
        eot = 9.87 * torch.sin(2 * B) - 7.53 * torch.cos(B) - 1.5 * torch.sin(B)
        features.append(eot / 15.0)  # Normalize by max (~15 minutes)

        # Stack features
        encoding = torch.stack(features, dim=-1)

        # Project to model dimension
        return self.norm(self.proj(encoding))


class ClimateIndexEmbedding(nn.Module):
    """
    Embeds climate indices (ENSO, NAO, MJO, etc.).

    These indices capture large-scale climate patterns that influence
    weather predictability on subseasonal to seasonal timescales.
    """

    def __init__(
        self,
        d_model: int,
        n_indices: int = 10,
        hidden_dim: int = 64,
    ):
        """
        Initialize climate index embedding.

        Args:
            d_model: Output embedding dimension
            n_indices: Maximum number of climate indices
            hidden_dim: Hidden dimension
        """
        super().__init__()

        self.d_model = d_model
        self.n_indices = n_indices

        # Index names and their typical value ranges
        self.index_names = [
            "nino34",      # ENSO: Nino 3.4 SST anomaly
            "nino12",      # ENSO: Nino 1+2 region
            "soi",         # Southern Oscillation Index
            "mjo_amp",     # MJO amplitude
            "mjo_phase",   # MJO phase (1-8, encoded as sin/cos)
            "nao",         # North Atlantic Oscillation
            "ao",          # Arctic Oscillation
            "pdo",         # Pacific Decadal Oscillation
            "amo",         # Atlantic Multidecadal Oscillation
            "qbo",         # Quasi-Biennial Oscillation
        ]

        # Embedding network
        self.mlp = nn.Sequential(
            nn.Linear(n_indices + 2, hidden_dim),  # +2 for MJO phase sin/cos
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        indices: torch.Tensor,
        mjo_phase: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Embed climate indices.

        Args:
            indices: Climate index values of shape (batch, seq_len, n_indices)
                    Values should be normalized/standardized
            mjo_phase: MJO phase (1-8) of shape (batch, seq_len)

        Returns:
            Embedding of shape (batch, seq_len, d_model)
        """
        # Handle MJO phase specially (cyclical)
        if mjo_phase is not None:
            phase_rad = 2 * math.pi * mjo_phase / 8.0
            mjo_sin = torch.sin(phase_rad).unsqueeze(-1)
            mjo_cos = torch.cos(phase_rad).unsqueeze(-1)
            x = torch.cat([indices, mjo_sin, mjo_cos], dim=-1)
        else:
            # Pad if MJO phase not provided
            batch, seq_len, _ = indices.shape
            padding = torch.zeros(batch, seq_len, 2, device=indices.device)
            x = torch.cat([indices, padding], dim=-1)

        return self.mlp(x)


class SolarPositionEmbedding(nn.Module):
    """
    Encodes solar position for each location and time.

    Critical for capturing diurnal cycles and their geographic variation.
    """

    def __init__(self, d_model: int):
        super().__init__()

        # Features: solar altitude, azimuth, day length, sunrise/sunset
        n_features = 6
        self.proj = nn.Linear(n_features, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        lat: torch.Tensor,
        lon: torch.Tensor,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute solar position features.

        Args:
            lat: Latitude in degrees, shape (batch, n_locations) or (batch,)
            lon: Longitude in degrees, shape (batch, n_locations) or (batch,)
            day_of_year: Day of year, shape (batch,) or (batch, seq_len)
            hour: Hour of day (0-24), shape (batch,) or (batch, seq_len)

        Returns:
            Solar position embedding
        """
        # Ensure proper broadcasting
        if lat.dim() == 1:
            lat = lat.unsqueeze(-1)
        if lon.dim() == 1:
            lon = lon.unsqueeze(-1)

        # Solar declination (degrees)
        declination = 23.45 * torch.sin(
            torch.deg2rad(torch.tensor(360 / 365.25 * (day_of_year - 81)))
        )

        # Convert to radians
        lat_rad = torch.deg2rad(lat)
        dec_rad = torch.deg2rad(declination)

        if dec_rad.dim() < lat_rad.dim():
            dec_rad = dec_rad.unsqueeze(-1)

        # Hour angle (if hour provided)
        if hour is not None:
            # Solar noon at longitude 0 is at 12:00 UTC
            # Each 15 degrees of longitude = 1 hour offset
            solar_time = hour + lon / 15.0
            hour_angle = torch.deg2rad((solar_time - 12.0) * 15.0)
        else:
            hour_angle = torch.zeros_like(lat_rad)

        # Solar altitude (elevation angle)
        sin_alt = (
            torch.sin(lat_rad) * torch.sin(dec_rad) +
            torch.cos(lat_rad) * torch.cos(dec_rad) * torch.cos(hour_angle)
        )
        solar_altitude = torch.arcsin(torch.clamp(sin_alt, -1, 1))

        # Solar azimuth
        cos_azimuth = (
            torch.sin(dec_rad) - torch.sin(lat_rad) * sin_alt
        ) / (torch.cos(lat_rad) * torch.cos(solar_altitude) + 1e-8)
        solar_azimuth = torch.arccos(torch.clamp(cos_azimuth, -1, 1))

        # Day length (hours)
        cos_hour_angle = -torch.tan(lat_rad) * torch.tan(dec_rad)
        cos_hour_angle = torch.clamp(cos_hour_angle, -1, 1)
        day_length = 2 * torch.arccos(cos_hour_angle) / math.pi * 12.0

        # Normalize features
        features = torch.stack([
            torch.sin(solar_altitude),
            torch.cos(solar_altitude),
            torch.sin(solar_azimuth),
            torch.cos(solar_azimuth),
            day_length / 24.0,
            declination / 23.45,
        ], dim=-1)

        return self.norm(self.proj(features))


class ClimateEmbedding(nn.Module):
    """
    Combined climate embedding module.

    Integrates:
    1. Seasonal/cyclical time features
    2. Climate indices (ENSO, MJO, NAO, etc.)
    3. Solar position
    """

    def __init__(
        self,
        d_model: int,
        use_climate_indices: bool = True,
        use_solar_position: bool = True,
        max_harmonics: int = 4,
    ):
        """
        Initialize climate embedding.

        Args:
            d_model: Output dimension
            use_climate_indices: Include climate index embedding
            use_solar_position: Include solar position embedding
            max_harmonics: Number of harmonics for seasonal encoding
        """
        super().__init__()

        self.d_model = d_model
        self.use_climate_indices = use_climate_indices
        self.use_solar_position = use_solar_position

        # Component embeddings
        self.seasonal = SeasonalEmbedding(d_model, max_harmonics)

        if use_climate_indices:
            self.climate_indices = ClimateIndexEmbedding(d_model)
        else:
            self.climate_indices = None

        if use_solar_position:
            self.solar = SolarPositionEmbedding(d_model)
        else:
            self.solar = None

        # Fusion layer
        n_components = 1 + int(use_climate_indices) + int(use_solar_position)
        self.fusion = nn.Sequential(
            nn.Linear(d_model * n_components, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(
        self,
        day_of_year: torch.Tensor,
        hour: Optional[torch.Tensor] = None,
        lat: Optional[torch.Tensor] = None,
        lon: Optional[torch.Tensor] = None,
        climate_indices: Optional[torch.Tensor] = None,
        mjo_phase: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute climate embedding.

        Args:
            day_of_year: Day of year (1-366)
            hour: Hour of day (0-23) - optional
            lat: Latitude in degrees - for solar position
            lon: Longitude in degrees - for solar position
            climate_indices: Climate index values - optional
            mjo_phase: MJO phase (1-8) - optional

        Returns:
            Combined climate embedding
        """
        embeddings = []

        # Seasonal embedding (always included)
        seasonal_emb = self.seasonal(day_of_year, hour)
        embeddings.append(seasonal_emb)

        # Climate indices
        if self.climate_indices is not None and climate_indices is not None:
            climate_emb = self.climate_indices(climate_indices, mjo_phase)
            embeddings.append(climate_emb)
        elif self.climate_indices is not None:
            # Use zeros if no indices provided
            shape = list(seasonal_emb.shape)
            embeddings.append(torch.zeros(shape, device=seasonal_emb.device))

        # Solar position
        if self.solar is not None and lat is not None and lon is not None:
            solar_emb = self.solar(lat, lon, day_of_year, hour)
            embeddings.append(solar_emb)
        elif self.solar is not None:
            shape = list(seasonal_emb.shape)
            embeddings.append(torch.zeros(shape, device=seasonal_emb.device))

        # Fuse embeddings
        combined = torch.cat(embeddings, dim=-1)
        return self.fusion(combined)