File size: 13,720 Bytes
8bcb60f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
"""
Forecast Dataset for LILITH.

Provides data loading optimized for multi-station forecasting
with graph-based models.
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from loguru import logger


class ForecastDataset(Dataset):
    """
    Dataset for graph-based multi-station forecasting.

    Instead of loading single stations, this dataset loads data for
    multiple stations simultaneously, suitable for GNN-based models.

    Each sample contains:
    - Observations from N stations for the input period
    - Targets for N stations for the forecast period
    - Station coordinates and connectivity graph

    Samples are indexed date-major: index ``i`` maps to start date
    ``i // n_clusters`` and spatial cluster ``i % n_clusters``.
    """

    # Grid cells holding fewer stations than this are not kept as clusters.
    _MIN_CLUSTER_SIZE = 10
    # Neighbours requested per station (the station itself included) by the
    # k-nearest-neighbour fallback used when the radius graph has no edges.
    _FALLBACK_NEIGHBORS = 5

    def __init__(
        self,
        data_dir: Union[str, Path],
        sequence_length: int = 30,
        forecast_length: int = 14,
        max_stations: int = 500,
        spatial_radius: float = 5.0,  # degrees
        target_variables: Optional[List[str]] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        seed: int = 42,
    ):
        """
        Initialize the forecast dataset.

        Reads ``stations.parquet`` from ``data_dir``; yearly observation
        files are loaded lazily when samples are requested.

        Args:
            data_dir: Directory with processed Parquet files
            sequence_length: Days of input history
            forecast_length: Days to forecast
            max_stations: Maximum stations per sample
            spatial_radius: Radius in degrees for station sampling
            target_variables: Variables to forecast
                (defaults to TMAX, TMIN and PRCP)
            start_date: Start date for data (YYYY-MM-DD),
                defaults to 2000-01-01
            end_date: End date for data (YYYY-MM-DD), inclusive,
                defaults to 2023-12-31
            seed: Random seed for reproducibility

        Raises:
            ValueError: If ``spatial_radius`` is not positive.
        """
        # Validate before any I/O: a zero or negative radius cannot define
        # grid cells (np.arange would fail or yield no cells at all).
        if spatial_radius <= 0:
            raise ValueError(f"spatial_radius must be positive, got {spatial_radius}")

        self.data_dir = Path(data_dir)
        self.sequence_length = sequence_length
        self.forecast_length = forecast_length
        self.total_length = sequence_length + forecast_length
        self.max_stations = max_stations
        self.spatial_radius = spatial_radius
        self.target_variables = target_variables or ["TMAX", "TMIN", "PRCP"]
        self.seed = seed

        self.rng = np.random.default_rng(seed)

        # Load station metadata
        self.stations = pd.read_parquet(self.data_dir / "stations.parquet")

        # Parse date range
        self.start_date = pd.Timestamp(start_date) if start_date else pd.Timestamp("2000-01-01")
        self.end_date = pd.Timestamp(end_date) if end_date else pd.Timestamp("2023-12-31")

        # Build the index of sample start dates. A window starting on day d
        # covers d .. d + total_length - 1, so the last usable start date is
        # end_date - (total_length - 1).
        last_start = self.end_date - pd.Timedelta(days=max(self.total_length - 1, 0))
        self.dates = pd.date_range(self.start_date, last_start, freq="D")

        # Build spatial clusters for efficient sampling
        self._build_spatial_clusters()

        # Cache of whole-year observation frames keyed by year. It is never
        # evicted, so memory grows with the number of distinct years touched.
        self._data_cache: Dict[int, pd.DataFrame] = {}

        # Lazily built station_id -> (latitude, longitude, elevation) table.
        self._coord_table: Optional[pd.DataFrame] = None

        logger.info(
            f"ForecastDataset: {len(self.dates)} dates, "
            f"{len(self.stations)} stations, {len(self.clusters)} clusters"
        )

    def _build_spatial_clusters(self) -> None:
        """
        Build spatial clusters of stations for efficient sampling.

        The globe is tiled with square cells of side ``2 * spatial_radius``
        degrees. Each cell collects the stations inside it plus a margin of
        ``spatial_radius`` on every side, so neighbouring clusters overlap.
        Cells with fewer than ``_MIN_CLUSTER_SIZE`` stations are dropped.

        Sets ``self.clusters`` to a list of dicts with keys ``center_lat``,
        ``center_lon`` and ``station_ids``. Longitude wrap-around at +/-180
        is not handled.
        """
        radius = self.spatial_radius
        cell = radius * 2
        min_size = self._MIN_CLUSTER_SIZE

        # Pull the columns out once; comparing plain arrays per cell is much
        # cheaper than re-indexing the DataFrame for every grid cell.
        lats = self.stations["latitude"].to_numpy(dtype=np.float64, na_value=np.nan)
        lons = self.stations["longitude"].to_numpy(dtype=np.float64, na_value=np.nan)
        ids = self.stations["station_id"].to_numpy()

        clusters = []
        for lat in np.arange(-90, 90, cell):
            in_band = (lats >= lat - radius) & (lats < lat + radius * 3)
            # A latitude band that is already too small cannot contain a
            # large-enough cell, so skip the whole longitude sweep.
            if np.count_nonzero(in_band) < min_size:
                continue
            band_lons = lons[in_band]
            band_ids = ids[in_band]

            for lon in np.arange(-180, 180, cell):
                in_cell = (band_lons >= lon - radius) & (band_lons < lon + radius * 3)
                members = band_ids[in_cell].tolist()
                if len(members) >= min_size:
                    clusters.append({
                        "center_lat": float(lat + radius),
                        "center_lon": float(lon + radius),
                        "station_ids": members,
                    })

        self.clusters = clusters

    def _load_data_for_date(self, date: pd.Timestamp) -> pd.DataFrame:
        """
        Load the yearly observation frames covering one sample window.

        The window spans ``date`` .. ``date + total_length - 1``. Each year
        is read from ``observations_<year>.parquet`` on first use and kept
        in ``self._data_cache``; years without a file are skipped.

        Args:
            date: First day of the sample window.

        Returns:
            The observations of all available years in the window (not yet
            filtered to the window), or an empty DataFrame when no file
            exists. When a single year suffices the cached frame itself is
            returned, so callers must not modify it in place.
        """
        last_day = date + pd.Timedelta(days=max(self.total_length - 1, 0))

        frames = []
        for year in range(date.year, last_day.year + 1):
            frame = self._data_cache.get(year)
            if frame is None:
                year_file = self.data_dir / f"observations_{year}.parquet"
                if not year_file.exists():
                    continue
                frame = pd.read_parquet(year_file)
                self._data_cache[year] = frame
            frames.append(frame)

        if not frames:
            return pd.DataFrame()
        if len(frames) == 1:
            # Avoid copying a whole year of data on every sample.
            return frames[0]
        return pd.concat(frames)

    def _build_station_graph(
        self,
        station_coords: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build adjacency information for stations.

        Stations closer than ``spatial_radius`` (Euclidean distance in
        lat/lon degrees) are joined by a pair of opposite edges. If no pair
        is that close, each station is instead linked to its nearest
        neighbours (directed edges, up to ``_FALLBACK_NEIGHBORS - 1`` per
        station). Fewer than two stations yield an empty graph.

        Returns edge_index and edge_attr for PyTorch Geometric.

        Args:
            station_coords: (N, 3) array of [lat, lon, elev]

        Returns:
            edge_index: (2, E) int64 source and target node indices
            edge_attr: (E, 1) float32 edge distances
        """
        coords = np.asarray(station_coords)
        n_stations = len(coords)

        # No pair exists, so neither the radius test nor the k-NN fallback
        # can produce an edge.
        if n_stations < 2:
            return np.zeros((2, 0), dtype=np.int64), np.zeros((0, 1), dtype=np.float32)

        lat = coords[:, 0]
        lon = coords[:, 1]

        # All unordered pairs (i < j) in row-major order, evaluated at once.
        first, second = np.triu_indices(n_stations, k=1)
        dist = np.sqrt((lat[first] - lat[second]) ** 2 + (lon[first] - lon[second]) ** 2)
        near = dist < self.spatial_radius

        if near.any():
            a, b, d = first[near], second[near], dist[near]
            # Interleave so each pair contributes (a, b) followed by (b, a).
            edges_src = np.column_stack((a, b)).ravel()
            edges_dst = np.column_stack((b, a)).ravel()
            edge_weights = np.repeat(d, 2)
        else:
            # Fallback: connect to k nearest neighbors
            from scipy.spatial import KDTree

            k = min(self._FALLBACK_NEIGHBORS, n_stations)
            latlon = coords[:, :2]
            knn_dist, knn_idx = KDTree(latlon).query(latlon, k=k)

            edges_src = np.repeat(np.arange(n_stations), k)
            edges_dst = np.asarray(knn_idx).reshape(-1)
            edge_weights = np.asarray(knn_dist).reshape(-1)

            # Each station is normally its own nearest neighbour; drop
            # those self-loops.
            not_self = edges_src != edges_dst
            edges_src = edges_src[not_self]
            edges_dst = edges_dst[not_self]
            edge_weights = edge_weights[not_self]

        edge_index = np.stack([edges_src, edges_dst]).astype(np.int64)
        edge_attr = np.asarray(edge_weights, dtype=np.float32).reshape(-1, 1)

        return edge_index, edge_attr

    def _lookup_coords(self, station_ids: List) -> np.ndarray:
        """
        Return an (N, 3) float32 array of [lat, lon, elev] for the stations.

        Stations absent from the metadata, a missing ``elevation`` column
        and missing coordinate values all map to 0 so that no NaN reaches
        the model inputs. The lookup table is built from ``self.stations``
        on first use and then reused.
        """
        table = getattr(self, "_coord_table", None)
        if table is None:
            meta = self.stations.drop_duplicates(subset="station_id", keep="first")
            meta = meta.set_index("station_id")
            table = pd.DataFrame(
                {
                    "latitude": meta["latitude"],
                    "longitude": meta["longitude"],
                    "elevation": meta["elevation"] if "elevation" in meta.columns else 0.0,
                },
                index=meta.index,
            )
            self._coord_table = table

        coords = table.reindex(station_ids).to_numpy(dtype=np.float32, na_value=np.nan)
        return np.nan_to_num(coords, nan=0.0)

    def _window_values(
        self,
        data: pd.DataFrame,
        station_ids: List,
        date: pd.Timestamp,
    ) -> np.ndarray:
        """
        Arrange observations into an (N, total_length, F) float32 array.

        Axis 2 has one slot per entry of ``self.target_variables`` in that
        order. Days, stations or variables without an observation are NaN.
        When a station has several rows for the same day the first one is
        used.

        Args:
            data: Observations already restricted to the window and to
                ``station_ids``; indexed by date with a ``station_id``
                column.
            station_ids: Stations of the sample, in output row order.
            date: First day of the window.
        """
        n_features = len(self.target_variables)
        values = np.full(
            (len(station_ids), self.total_length, n_features), np.nan, dtype=np.float32
        )

        present = [
            (slot, col)
            for slot, col in enumerate(self.target_variables)
            if col in data.columns
        ]
        if not present or data.empty:
            return values

        slots = [slot for slot, _ in present]
        columns = [col for _, col in present]
        window = pd.date_range(date, periods=self.total_length, freq="D")

        # Split the frame by station once instead of re-scanning it for
        # every station.
        by_station = {sid: frame for sid, frame in data.groupby("station_id", sort=False)}

        for i, station_id in enumerate(station_ids):
            frame = by_station.get(station_id)
            if frame is None:
                continue
            frame = frame.loc[~frame.index.duplicated(keep="first"), columns]
            # Reindexing onto the daily window aligns every observation
            # with its day offset and leaves NaN where a day is missing.
            values[i][:, slots] = frame.reindex(window).to_numpy(
                dtype=np.float32, na_value=np.nan
            )

        return values

    def __len__(self) -> int:
        """Return the number of (start date, cluster) combinations."""
        return len(self.dates) * len(self.clusters)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int, str]]:
        """
        Get a multi-station sample.

        The feature axis always has one slot per entry of
        ``self.target_variables`` (zero-filled when the variable is absent
        from the data), so every sample of a dataset has the same feature
        width, including the empty samples returned when no observation
        file is available.

        Args:
            idx: Sample index; decoded into a start date and a cluster.

        Returns:
            Dict with keys:
            - node_features: (N, seq_len, F) station observations
            - node_coords: (N, 3) lat/lon/elev
            - edge_index: (2, E) graph connectivity
            - edge_attr: (E, 1) edge weights
            - target_features: (N, forecast_len, F) targets
            - mask: (N, seq_len + forecast_len) valid mask; True where at
              least one variable was observed on that day
            - n_stations: number of stations N
            - date: start date of the window as YYYY-MM-DD

        Raises:
            IndexError: If the dataset has no clusters or ``idx`` is past
                the last start date.
        """
        n_clusters = len(self.clusters)
        if n_clusters == 0:
            raise IndexError("ForecastDataset has no spatial clusters to sample from")

        # Decode index
        date_idx, cluster_idx = divmod(idx, n_clusters)
        date = self.dates[date_idx]
        cluster = self.clusters[cluster_idx]

        # Sample stations from cluster
        station_ids = cluster["station_ids"]
        if len(station_ids) > self.max_stations:
            station_ids = self.rng.choice(station_ids, self.max_stations, replace=False).tolist()

        n_stations = len(station_ids)

        # Load data
        data = self._load_data_for_date(date)
        if data.empty:
            return self._empty_sample(n_stations)

        # Filter to selected stations and date range.
        # NOTE(review): assumes the observation frames are indexed by date
        # (DatetimeIndex) - confirm against the preprocessing step.
        end_date = date + pd.Timedelta(days=self.total_length - 1)
        in_window = (
            data["station_id"].isin(station_ids) &
            (data.index >= date) &
            (data.index <= end_date)
        )
        data = data[in_window]

        values = self._window_values(data, station_ids, date)

        # A day counts as valid when any variable was observed; missing
        # values are then zeroed (the mask indicates valid days).
        valid_mask = (~np.isnan(values)).any(axis=2)
        values = np.nan_to_num(values, nan=0.0)

        # Copy the two halves so the tensors do not share one buffer.
        node_features = np.ascontiguousarray(values[:, : self.sequence_length])
        target_features = np.ascontiguousarray(values[:, self.sequence_length :])

        node_coords = self._lookup_coords(station_ids)

        # Build graph
        edge_index, edge_attr = self._build_station_graph(node_coords)

        return {
            "node_features": torch.from_numpy(node_features),
            "node_coords": torch.from_numpy(node_coords),
            "edge_index": torch.from_numpy(edge_index),
            "edge_attr": torch.from_numpy(edge_attr),
            "target_features": torch.from_numpy(target_features),
            "mask": torch.from_numpy(valid_mask),
            "n_stations": n_stations,
            "date": str(date.date()),
        }

    def _empty_sample(self, n_stations: int) -> Dict[str, Union[torch.Tensor, int, str]]:
        """
        Return an all-zero sample for windows without any observation data.

        The tensors have the same shapes and dtypes as a regular sample
        with ``n_stations`` nodes, the graph has no edges, the mask is all
        False and ``date`` is the empty string.
        """
        n_features = len(self.target_variables)
        return {
            "node_features": torch.zeros(n_stations, self.sequence_length, n_features),
            "node_coords": torch.zeros(n_stations, 3),
            "edge_index": torch.zeros(2, 0, dtype=torch.long),
            "edge_attr": torch.zeros(0, 1),
            "target_features": torch.zeros(n_stations, self.forecast_length, n_features),
            "mask": torch.zeros(n_stations, self.total_length, dtype=torch.bool),
            "n_stations": n_stations,
            "date": "",
        }


def collate_variable_graphs(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Merge samples with differing node counts into one disjoint graph.

    Per-node tensors are concatenated along the node axis. The node indices
    in every sample's ``edge_index`` are shifted by the number of nodes that
    precede the sample, and a ``batch`` vector records the sample each node
    came from.

    Args:
        batch: Samples carrying the keys ``node_features``, ``node_coords``,
            ``target_features``, ``mask``, ``edge_index``, ``edge_attr``
            and ``n_stations``.

    Returns:
        Dict with ``node_features``, ``node_coords``, ``edge_index``,
        ``edge_attr``, ``target_features``, ``mask`` and ``batch``.
    """

    def _join_nodes(key: str) -> torch.Tensor:
        # Concatenate one per-node field of all samples along the node axis.
        return torch.cat([sample[key] for sample in batch], dim=0)

    merged = {
        "node_features": _join_nodes("node_features"),
        "node_coords": _join_nodes("node_coords"),
        "target_features": _join_nodes("target_features"),
        "mask": _join_nodes("mask"),
    }

    shifted_edges = []
    kept_attrs = []
    membership = []
    node_offset = 0

    for graph_id, sample in enumerate(batch):
        edges = sample["edge_index"]
        # Edge-less samples add nothing to the edge lists but still
        # advance the node offset.
        if edges.size(1) > 0:
            shifted_edges.append(edges + node_offset)
            kept_attrs.append(sample["edge_attr"])

        node_count = sample["n_stations"]
        membership.append(torch.full((node_count,), graph_id, dtype=torch.long))
        node_offset += node_count

    if not shifted_edges:
        # No sample has edges: return correctly shaped empty tensors.
        all_edges = torch.zeros(2, 0, dtype=torch.long)
        all_attrs = torch.zeros(0, 1)
    else:
        all_edges = torch.cat(shifted_edges, dim=1)
        all_attrs = torch.cat(kept_attrs, dim=0)

    return {
        "node_features": merged["node_features"],
        "node_coords": merged["node_coords"],
        "edge_index": all_edges,
        "edge_attr": all_attrs,
        "target_features": merged["target_features"],
        "mask": merged["mask"],
        "batch": torch.cat(membership),
    }