| import xarray as xr |
| from datetime import datetime, timedelta |
| import pandas as pd |
| import numpy as np |
| import sys |
| import os |
|
|
| import torch |
| import random |
| from torch.utils import data |
| import torch.nn.functional as F |
|
|
| import matplotlib.pyplot as plt |
|
|
|
|
# Short variable names in the two conventions used by this module. The
# upper-case names are the defaults of ``Dataset_Postprocessing``; the
# lower-case names are what ``var_map`` translates them to when indexing the
# GRIB ensemble datasets. The lists are paired positionally.
sfc_vars_1 = ['SSTK', 'TCW', 'TCWV', 'CP', 'MSL', 'TCC', 'U10M', 'V10M', 'T2M', 'TP', 'SKT']
sfc_vars_2 = ['sst', 'tcw', 'tcwv', 'cp', 'msl', 'tcc', 'u10', 'v10', 't2m', 'tp', 'skt']
pl_vars_1 = ["Z", "T", "Q", "W", "D", "U", "V"]
pl_vars_2 = ["z", "t", "q", "w", "d", "u", "v"]

# Upper-case name -> lower-case name: surface variables first, then
# pressure-level variables.
var_map = dict(zip(sfc_vars_1 + pl_vars_1, sfc_vars_2 + pl_vars_2))
|
|
class Aurora_CDF_Dataset_china(data.Dataset):
    """Dataset class for the era5 upper and surface variables.

    A sample is a sequence of ``num_points`` snapshots spaced ``horizon``
    hours apart, starting at one of the timestamps in ``self.keys``. Each
    snapshot is read from ``<nc_path>/<YYYYmmddHH>.npy``.

    NOTE: ``__len__`` counts every start time, but a sample reads files up to
    ``(num_points - 1) * horizon`` hours after its start. Those later files
    must exist on disk, including ones that fall after ``endDate``.
    """

    def __init__(
        self,
        nc_path='',
        seed=1234,
        startDate='2010',
        endDate='2020',
        freq='12h',
        horizon=12,
        surface=["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind", "mean_sea_level_pressure", "total_precipitation_6hr"],
        upper=["temperature", "u_component_of_wind", "v_component_of_wind", "relative_humidity", "geopotential"],
        level=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000],
        h=240,
        w=240,
        degree=0.25,
        num_points=5,
        evaluate=False,
    ):
        """Initialize.

        Args:
            nc_path: Directory holding the per-timestamp ``.npy`` files.
            seed: Passed to ``random.seed`` (reseeds the global RNG).
            startDate: Start bound handed to ``pandas.date_range``.
            endDate: End bound handed to ``pandas.date_range``; the last
                generated timestamp is dropped from ``self.keys``.
            freq: Step between sample start times, e.g. ``'12h'``. Must be an
                integer followed by exactly one unit character.
            horizon: Hours between consecutive snapshots within a sample.
            surface: Surface variable names (stored, not used when reading).
            upper: Upper-air variable names (stored, not used when reading).
            level: Pressure levels (stored, not used when reading).
            h: Stored as ``self.h``; not used when reading.
            w: Stored as ``self.w``; not used when reading.
            degree: Stored as ``self.degree``; not used when reading.
            num_points: Number of snapshots per sample.
            evaluate: Currently unused.
        """
        self.nc_path = nc_path
        # TODO: validate that start/end are valid dates, that every required
        # date can be found in the downloaded files, and that length >= 0.

        self.freq = int(freq[:-1])
        # Copy the sequences so instances never share (or mutate) the list
        # objects used as default arguments.
        self.surface_variables = list(surface)
        self.upper_variables = list(upper)
        self.levels = list(level)
        self.horizon = horizon
        self.h = h
        self.w = w
        self.degree = degree
        self.num_points = num_points

        self.keys = list(pd.date_range(start=startDate, end=endDate, freq=freq))[:-1]

        random.seed(seed)

    def __getitem__(self, index):
        """Return input frames, target frames, and its corresponding time steps.

        Returns:
            A tuple ``(surface, upper, time_points)``. ``surface`` stacks
            index 0 of each file array's second axis and ``upper`` stacks
            indices ``1:`` of that axis, both with a new leading axis of
            length ``num_points``. ``time_points`` holds the POSIX timestamp
            of every snapshot.
        """
        stamp = self.keys[index]

        surface_frames = []
        upper_frames = []
        time_points = []

        for _ in range(self.num_points):
            time_points.append(stamp.timestamp())
            file_name = f"{stamp.strftime('%Y%m%d%H')}.npy"
            fields = np.load(os.path.join(self.nc_path, file_name)).astype(np.float32)

            surface_frames.append(fields[:, 0])
            upper_frames.append(fields[:, 1:])
            stamp = stamp + timedelta(hours=self.horizon)

        return np.stack(surface_frames, axis=0), np.stack(upper_frames, axis=0), time_points

    def __len__(self):
        """Return the number of sample start times."""
        return len(self.keys)

    def __repr__(self):
        """Return the class name."""
        return self.__class__.__name__
| |
class Dataset_Postprocessing(data.Dataset):
    """Dataset pairing ensemble-forecast inputs with ERA5 targets.

    Two input modes are supported:

    * ``full=False``: inputs are the pre-computed, normalized ensemble std
      (input index 0) and mean (input index 1) read from NetCDF files, so
      ``self.T == 2``.
    * ``full=True``: inputs are read per sample from GRIB files under
      ``<data_path>/ensemble`` and min-max scaled with the hard-coded value
      ranges; ``self.T == 10``.

    In both modes the un-normalized ensemble mean/std and the ERA5 fields are
    read from NetCDF files to build the scale and target arrays.
    """

    def __init__(
        self,
        data_path='',
        seed=1234,
        start_date='1998-01-01',
        end_date='2015-12-31',
        val=False,
        surface=['SSTK', 'TCW', 'TCWV', 'CP', 'MSL', 'TCC', 'U10M', 'V10M', 'T2M', 'TP', 'SKT'],
        upper=["Z", "T", "Q", "W", "D", "U", "V"],
        levels=[500, 850],
        target_surface=["T2M", 'U10M', 'V10M'],
        target_upper=["Z", "T"],
        target_level=[500, 850],
        H=361,
        W=720,
        full=False,
    ):
        """Initialize.

        Args:
            data_path: Directory containing the ``ENS10_*.nc`` and ``ERA5.nc``
                files (and the ``ensemble`` sub-directory in full mode).
            seed: Passed to ``random.seed`` (reseeds the global RNG).
            start_date: Start of the ``time`` slice applied to every dataset.
            end_date: End of the ``time`` slice applied to every dataset.
            val: When true, the sample ``2017-01-02`` is removed from the
                keys. ``drop_sel`` requires that date to be present in the
                selected time range.
            surface: Input surface variable names.
            upper: Input pressure-level variable names.
            levels: Input pressure levels; also selects the per-level files.
            target_surface: Target surface variable names.
            target_upper: Target pressure-level variable names.
            target_level: Pressure level of each entry in ``target_upper``
                (index-aligned); every value must occur in ``levels``.
            H: Size of the second-to-last (row) axis of the allocated arrays.
            W: Size of the last (column) axis of the allocated arrays.
            full: Select the GRIB-based input mode described on the class.
        """
        self.data_path = data_path
        # Copy the sequences so instances never share the default list objects.
        self.surface_variables = list(surface)
        self.upper_variables = list(upper)
        self.levels = list(levels)
        self.target_sfc_variables = list(target_surface)
        self.target_pl_variables = list(target_upper)
        self.target_pl_level = list(target_level)
        # Honour the constructor arguments (previously hard-coded to 361/720).
        self.H = H
        self.W = W
        self.full = full

        time_range = slice(start_date, end_date)
        if self.full:
            self.ensemble_path = os.path.join(data_path, "ensemble")
            self.T = 10
            # (min, max) pairs used for min-max scaling of the GRIB inputs.
            self.sfc_value_range = {"sst": (260., 305.), "tcw": (0., 60.), "tcwv": (0., 60.), "cp": (0., 0.04), "msl": (97000., 1.1e5), "tcc": (0., 1.0), "u10": (-13., 11.), "v10": (-10., 15.), "t2m": (218, 304), "tp": (0., 0.07), "skt": (210., 310.)}
            # One dict per entry of ``levels`` (indexed by level position);
            # only two entries are defined here.
            self.pl_value_range = [
                {"z": (48200, 58000), "t": (230, 269), "q": (0., 4e-3), "w": (-0.7, 1.4), "d": (-5e-5, 8e-5), "u": (-7., 27.), "v": (-7., 7.)},
                {"z": (10000, 15500), "t": (240, 299), "q": (0., 1.5e-2), "w": (-1.2, 1.8), "d": (-1.9e-4, 1.6e-4), "u": (-16., 17.5), "v": (-10., 16.)},
            ]
        else:
            self.T = 2
            # The normalized statistics are only read by ``__getitem__`` when
            # ``full`` is False, so they are only opened in this mode.
            self.sfc_ens_mean_normalized = self._open_nc("ENS10_sfc_mean_normalized.nc", time_range)
            self.sfc_ens_std_normalized = self._open_nc("ENS10_sfc_std_normalized.nc", time_range)
            self.pl_ens_mean_normalized = [self._open_nc(f"ENS10_pl_mean_{level}_normalized.nc", time_range) for level in self.levels]
            self.pl_ens_std_normalized = [self._open_nc(f"ENS10_pl_std_{level}_normalized.nc", time_range) for level in self.levels]

        # Needed by ``__getitem__`` in both modes (scales, targets, keys).
        self.sfc_ens_mean = self._open_nc("ENS10_sfc_mean.nc", time_range)
        self.sfc_ens_std = self._open_nc("ENS10_sfc_std.nc", time_range)
        self.pl_ens_mean = [self._open_nc(f"ENS10_pl_mean_{level}.nc", time_range) for level in self.levels]
        self.pl_ens_std = [self._open_nc(f"ENS10_pl_std_{level}.nc", time_range) for level in self.levels]
        self.era5 = self._open_nc("ERA5.nc", time_range)

        if val:
            self.keys = self.sfc_ens_mean.drop_sel(time="2017-01-02").time.values
        else:
            self.keys = self.sfc_ens_mean.time.values
        self.era5_sfc_scale = {}

        random.seed(seed)

    def _open_nc(self, file_name, time_range):
        """Open ``<data_path>/<file_name>`` with h5netcdf, restricted to ``time_range``."""
        path = os.path.join(self.data_path, file_name)
        return xr.open_dataset(path, engine="h5netcdf").sel(time=time_range)

    def __getitem__(self, index):
        """Return input frames, target frames, and its corresponding time steps.

        Returns:
            ``(sfc_inputs, pl_inputs, sfc_scale, sfc_targets, pl_scale,
            pl_targets, time_points)``. ``sfc_inputs`` is allocated as
            ``(T, n_surface, H, W)`` and ``pl_inputs`` as
            ``(T, n_upper, n_levels, H, W)``. The scale arrays hold the
            ensemble mean at index 0 and the ensemble std at index 1. The
            last index along the ``H`` axis is dropped from the returned
            scales and targets. ``time_points`` contains ``key.item()``.
        """
        key = self.keys[index]
        time_points = [key.item()]

        sfc_inputs = np.zeros((self.T, len(self.surface_variables), self.H, self.W), dtype=np.float32)
        pl_inputs = np.zeros((self.T, len(self.upper_variables), len(self.levels), self.H, self.W), dtype=np.float32)

        ds_targets = self.era5.sel(time=key)

        sfc_targets = np.zeros((len(self.target_sfc_variables), self.H, self.W), dtype=np.float32)
        sfc_scale = np.zeros((2, len(self.target_sfc_variables), self.H, self.W), dtype=np.float32)
        for i, var in enumerate(self.target_sfc_variables):
            sfc_targets[i] = ds_targets[var].values.astype(np.float32)
            sfc_scale[0, i] = self.sfc_ens_mean[var].sel(time=key).values.astype(np.float32)
            sfc_scale[1, i] = self.sfc_ens_std[var].sel(time=key).values.astype(np.float32)

        pl_targets = np.zeros((len(self.target_pl_variables), self.H, self.W), dtype=np.float32)
        pl_scale = np.zeros((2, len(self.target_pl_variables), self.H, self.W), dtype=np.float32)
        for i, var in enumerate(self.target_pl_variables):
            level = self.target_pl_level[i]
            level_pos = self.levels.index(level)
            # The ERA5 target always uses index 0 of its leading axis,
            # regardless of ``level``.
            pl_targets[i] = ds_targets[var].values.astype(np.float32)[0]
            # ``plev`` is queried as level * 100 (presumably hPa -> Pa; confirm
            # against the files).
            pl_scale[0, i] = self.pl_ens_mean[level_pos][var].sel(time=key, plev=level * 1e2).values.astype(np.float32)
            pl_scale[1, i] = self.pl_ens_std[level_pos][var].sel(time=key, plev=level * 1e2).values.astype(np.float32)

        if self.full:
            # The GRIB files are named after the date two days before ``key``.
            time_str = (pd.to_datetime(key) - timedelta(days=2)).strftime("%Y%m%d")
            sfc_path = os.path.join(self.ensemble_path, f"output.sfc.{time_str}.grib")
            pl_path = os.path.join(self.ensemble_path, f"output.pl.{time_str}.grib")

            # Use context managers so the per-sample file handles are released
            # instead of accumulating over an epoch.
            with xr.open_dataset(sfc_path, backend_kwargs={"indexpath": ""}) as raw_sfc:
                sfc_ds = raw_sfc.fillna(9999.0)
                for i, var in enumerate(self.surface_variables):
                    short = var_map[var]
                    value = sfc_ds[short].values.astype(np.float32)
                    minval, maxval = self.sfc_value_range[short]
                    sfc_inputs[:, i] = (value - minval) / (maxval - minval)

            with xr.open_dataset(pl_path, backend_kwargs={"indexpath": ""}) as raw_pl:
                pl_ds = raw_pl.sel(isobaricInhPa=self.levels).fillna(9999.0)
                for i, var in enumerate(self.upper_variables):
                    short = var_map[var]
                    value = pl_ds[short].values.astype(np.float32)
                    for j in range(len(self.levels)):
                        minval, maxval = self.pl_value_range[j][short]
                        pl_inputs[:, i, j] = (value[:, j] - minval) / (maxval - minval)
        else:
            for i, var in enumerate(self.surface_variables):
                sfc_inputs[1, i] = self.sfc_ens_mean_normalized[var].sel(time=key).values.astype(np.float32)
                sfc_inputs[0, i] = self.sfc_ens_std_normalized[var].sel(time=key).values.astype(np.float32)

            for i, var in enumerate(self.upper_variables):
                for j in range(len(self.levels)):
                    pl_inputs[1, i, j] = self.pl_ens_mean_normalized[j][var].sel(time=key).values[0].astype(np.float32)
                    pl_inputs[0, i, j] = self.pl_ens_std_normalized[j][var].sel(time=key).values[0].astype(np.float32)

        return sfc_inputs, pl_inputs, sfc_scale[:, :, :-1], sfc_targets[:, :-1], pl_scale[:, :, :-1], pl_targets[:, :-1], time_points

    def __len__(self):
        """Return the number of available time steps."""
        return len(self.keys)

    def __repr__(self):
        """Return the class name."""
        return self.__class__.__name__
| |
class Aurora_downscale(data.Dataset):
    """Dataset class for the era5 upper and surface variables.

    Each sample holds ``num_points`` consecutive low-resolution (LR)
    snapshots, 6 hours apart, together with the high-resolution (HR) snapshot
    at the time of the last LR snapshot. Data is read from one NetCDF file
    per variable and year:
    ``<nc_path>/<dir>/<prefix>/<prefix>_<year>_<degree>deg.nc``.
    """

    def __init__(
        self,
        nc_path='',
        seed=1234,
        startDate='1999',
        endDate='2020',
        freq='6h',
        surface_prefix=["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"],
        upper_prefix=["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity", "geopotential"],
        surface=["t2m", "u10", "v10"],
        upper=["t", "u", "v", "q", "z"],
        levels=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000],
        num_points=2,
        degree=("5.625", "1.40625"),
        lr_dir=None,
        hr_dir=None,
        lr_degree=None,
        hr_degree=None,
        spatial_multiple=4,
    ):
        """Initialize.

        Args:
            nc_path: Root directory of the NetCDF tree.
            seed: Passed to ``random.seed`` (reseeds the global RNG).
            startDate: Start bound handed to ``pandas.date_range``.
            endDate: End bound handed to ``pandas.date_range``; the last two
                generated timestamps are dropped from ``self.keys``.
            freq: Step between sample start times, e.g. ``'6h'``. The step
                between snapshots inside a sample is always 6 hours.
            surface_prefix: Directory/file prefix of each surface variable
                (index-aligned with ``surface``).
            upper_prefix: Directory/file prefix of each upper-air variable
                (index-aligned with ``upper``).
            surface: Surface variable names inside the files.
            upper: Upper-air variable names inside the files.
            levels: Stored as ``self.levels``; not used when reading.
            num_points: Number of LR snapshots per sample.
            degree: ``(lr, hr)`` degree tags used when ``lr_degree`` /
                ``hr_degree`` are not given.
            lr_dir: Explicit LR directory name, or ``None`` to auto-resolve.
            hr_dir: Explicit HR directory name, or ``None`` to auto-resolve.
            lr_degree: Explicit LR degree tag, or ``None`` for ``degree[0]``.
            hr_degree: Explicit HR degree tag, or ``None`` for ``degree[1]``.
            spatial_multiple: The last axis of every array is trimmed to a
                multiple of this value; values <= 0 disable trimming.

        Raises:
            ValueError: If the date range yields no sample keys.
            FileNotFoundError: If the LR or HR layout cannot be resolved.
        """
        self.nc_path = nc_path
        # TODO: validate that start/end are valid dates, that every required
        # date can be found in the downloaded files, and that length >= 0.

        self.freq = int(freq[:-1])
        # Copy the sequences so instances never share the default list objects.
        self.surface_variables = list(surface)
        self.upper_variables = list(upper)
        self.levels = list(levels)
        self.spatial_multiple = int(spatial_multiple)
        self.num_points = num_points

        self.keys = list(pd.date_range(start=startDate, end=endDate, freq=freq))[:-2]

        self.lr_datasets = {}
        self.hr_datasets = {}

        # Open exactly the years covered by the keys. Deriving them from
        # ``startDate[:4]``/``endDate[:4]`` skipped the final year whenever
        # ``endDate`` was not a year boundary.
        years = self._years_spanned(self.keys)
        if not years:
            raise ValueError("No years were resolved from startDate/endDate.")

        if lr_degree is None:
            lr_degree = degree[0]
        if hr_degree is None:
            hr_degree = degree[1]

        # Probe the layout with the first surface variable of the first year.
        sample_prefix = surface_prefix[0]
        sample_year = years[0]
        lr_dir, lr_degree = self._resolve_dataset_layout(
            nc_root=self.nc_path,
            prefix=sample_prefix,
            year=sample_year,
            requested_dir=lr_dir,
            requested_degree=lr_degree,
            candidates=(("5.625", "5.625"), ("5.625_nc", "5.625")),
            kind="LR",
        )
        hr_dir, hr_degree = self._resolve_dataset_layout(
            nc_root=self.nc_path,
            prefix=sample_prefix,
            year=sample_year,
            requested_dir=hr_dir,
            requested_degree=hr_degree,
            candidates=(
                ("1.40625", "1.40625"),
                ("1.40625_nc", "1.40625"),
                ("1.5_nc", "1.5"),
                ("1.5", "1.5"),
            ),
            kind="HR",
        )

        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_degree = lr_degree
        self.hr_degree = hr_degree
        print(f"[Aurora_downscale] LR={self.lr_dir} ({self.lr_degree}deg), HR={self.hr_dir} ({self.hr_degree}deg)")

        groups = (
            (self.surface_variables, surface_prefix),
            (self.upper_variables, upper_prefix),
        )
        for variables, prefixes in groups:
            for i, var in enumerate(variables):
                prefix = prefixes[i]
                for year in years:
                    self.lr_datasets[f"{var}_{year}"] = xr.open_dataset(
                        os.path.join(self.nc_path, self.lr_dir, prefix, f'{prefix}_{year}_{self.lr_degree}deg.nc')
                    )
                    self.hr_datasets[f"{var}_{year}"] = xr.open_dataset(
                        os.path.join(self.nc_path, self.hr_dir, prefix, f'{prefix}_{year}_{self.hr_degree}deg.nc')
                    )

        random.seed(seed)

    @staticmethod
    def _years_spanned(keys):
        """Return the years from the first to the last key (inclusive) as strings."""
        if not keys:
            return []
        return [str(year) for year in range(keys[0].year, keys[-1].year + 1)]

    def _trim_width_to_multiple(self, value):
        """Trim the last longitude columns if width is not divisible by `self.spatial_multiple`."""
        multiple = self.spatial_multiple
        if multiple <= 0:
            return value
        width = value.shape[-1]
        trimmed = width - width % multiple
        # Leave the array untouched when nothing needs trimming, or when
        # trimming would remove every column.
        if trimmed in (0, width):
            return value
        return value[..., :trimmed]

    @staticmethod
    def _lat_name(da):
        """Return the latitude coordinate name (``latitude`` or ``lat``)."""
        for name in ("latitude", "lat"):
            if name in da.coords:
                return name
        raise KeyError("Latitude coordinate not found.")

    @staticmethod
    def _lon_name(da):
        """Return the longitude coordinate name (``longitude`` or ``lon``)."""
        for name in ("longitude", "lon"):
            if name in da.coords:
                return name
        raise KeyError("Longitude coordinate not found.")

    @staticmethod
    def _level_name(da):
        """Return the vertical-level dimension/coordinate name."""
        for name in ("level", "pressure_level", "plev"):
            if name in da.dims or name in da.coords:
                return name
        raise KeyError("Level coordinate not found for upper variable.")

    def _north_first(self, da, lat_name):
        """Flip ``da`` along latitude when its latitude values are ascending."""
        lat = da[lat_name].values
        if lat[0] < lat[-1]:
            da = da.isel({lat_name: slice(None, None, -1)})
        return da

    def _read_surface(self, ds, var, key):
        """Read ``var`` at time ``key`` as a float32 ``(lat, lon)`` array.

        Latitude is ordered descending and the width is trimmed with
        ``_trim_width_to_multiple``.
        """
        da = ds[var].sel(time=key)
        lat_name = self._lat_name(da)
        lon_name = self._lon_name(da)
        da = self._north_first(da, lat_name).transpose(lat_name, lon_name)
        return self._trim_width_to_multiple(da.values.astype(np.float32))

    def _read_upper(self, ds, var, key):
        """Read ``var`` at time ``key`` as a float32 ``(level, lat, lon)`` array.

        Latitude is ordered descending and the width is trimmed with
        ``_trim_width_to_multiple``. All levels present in the file are kept.
        """
        da = ds[var].sel(time=key)
        lev_name = self._level_name(da)
        lat_name = self._lat_name(da)
        lon_name = self._lon_name(da)
        da = self._north_first(da, lat_name).transpose(lev_name, lat_name, lon_name)
        return self._trim_width_to_multiple(da.values.astype(np.float32))

    @staticmethod
    def _resolve_dataset_layout(
        nc_root,
        prefix,
        year,
        requested_dir,
        requested_degree,
        candidates,
        kind,
    ):
        """Find the ``(directory, degree_tag)`` pair under which sample data exists.

        A pair is accepted when
        ``<nc_root>/<dir>/<prefix>/<prefix>_<year>_<degree>deg.nc`` exists.

        * Both ``requested_dir`` and ``requested_degree`` given: that exact
          pair must exist, otherwise ``FileNotFoundError`` is raised.
        * Only one of them given: candidates matching the given value are
          tried first.
        * If nothing matched so far, every candidate is tried in order.

        Raises:
            FileNotFoundError: If no existing layout is found.
        """
        def sample_file(directory, degree_tag):
            return os.path.join(nc_root, directory, prefix, f"{prefix}_{year}_{degree_tag}deg.nc")

        have_dir = requested_dir is not None
        have_degree = requested_degree is not None

        if have_dir and have_degree:
            expected = sample_file(requested_dir, requested_degree)
            if os.path.exists(expected):
                return requested_dir, requested_degree
            raise FileNotFoundError(
                f"[{kind}] Requested path not found: {expected}. "
                f"Check --{kind.lower()}_data_dir/--{kind.lower()}_degree_tag."
            )

        if have_dir:
            for cand_dir, cand_degree in candidates:
                if cand_dir == requested_dir and os.path.exists(sample_file(requested_dir, cand_degree)):
                    return requested_dir, cand_degree
        elif have_degree:
            for cand_dir, cand_degree in candidates:
                if cand_degree == requested_degree and os.path.exists(sample_file(cand_dir, requested_degree)):
                    return cand_dir, requested_degree

        # Fall back to any known layout, even one that differs from the hint.
        for cand_dir, cand_degree in candidates:
            if os.path.exists(sample_file(cand_dir, cand_degree)):
                return cand_dir, cand_degree

        raise FileNotFoundError(
            f"[{kind}] Could not auto-resolve dataset layout for prefix={prefix}, year={year}. "
            f"Tried: {[(d, deg) for d, deg in candidates]}"
        )

    def __getitem__(self, index):
        """Return input frames, target frames, and its corresponding time steps.

        Returns:
            ``(input_surfaces, input_uppers, target_surfaces, target_uppers,
            time_points)``. The inputs stack the LR snapshots along a new
            leading axis of length ``num_points``, with variables on the next
            axis. The targets are the HR fields (variables on the leading
            axis) at the time of the last LR snapshot. ``time_points`` holds
            the POSIX timestamp of every LR snapshot.
        """
        key = self.keys[index]
        input_surfaces = []
        input_uppers = []
        time_points = []

        for _ in range(self.num_points):
            time_points.append(key.timestamp())
            year = key.strftime('%Y')
            input_surfaces.append(np.stack(
                [self._read_surface(self.lr_datasets[f"{var}_{year}"], var, key) for var in self.surface_variables],
                axis=0,
            ))
            input_uppers.append(np.stack(
                [self._read_upper(self.lr_datasets[f"{var}_{year}"], var, key) for var in self.upper_variables],
                axis=0,
            ))
            key = key + timedelta(hours=6)

        # Step back to the time of the last LR snapshot for the HR target.
        key = key - timedelta(hours=6)
        year = key.strftime('%Y')
        target_surfaces = np.stack(
            [self._read_surface(self.hr_datasets[f"{var}_{year}"], var, key) for var in self.surface_variables],
            axis=0,
        )
        target_uppers = np.stack(
            [self._read_upper(self.hr_datasets[f"{var}_{year}"], var, key) for var in self.upper_variables],
            axis=0,
        )

        return np.stack(input_surfaces, axis=0), np.stack(input_uppers, axis=0), target_surfaces, target_uppers, time_points

    def __len__(self):
        """Return the number of sample start times."""
        return len(self.keys)

    def __repr__(self):
        """Return the class name."""
        return self.__class__.__name__
|
|
|
|
class AuroraFactorizedST(data.Dataset):
    """Dataset for factorized spatial/temporal SR with commutativity training.

    A sample starts at ``t0 = keys[index]`` and covers the window
    ``[t0, t0 + 6h]``. It provides low-resolution 6-hourly inputs (``lr6h``),
    high-resolution 6-hourly targets (``hr6h``), high-resolution hourly
    targets (``hr1h``) and, when an ``lr1h_dir`` is supplied and
    ``derive_lr1h_from_hr1h`` is false, low-resolution hourly targets
    (``lr1h``). Data is read from one NetCDF file per variable and year:
    ``<nc_path>/<dir>/<prefix>/<prefix>_<year>_<degree>deg.nc``.
    """

    def __init__(
        self,
        nc_path="",
        seed=1234,
        startDate="1999",
        endDate="2020",
        freq="6h",
        surface_prefix=("2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"),
        upper_prefix=("temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity", "geopotential"),
        surface=("t2m", "u10", "v10"),
        upper=("t", "u", "v", "q", "z"),
        levels=(50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000),
        lr6h_dir="5.625",
        hr6h_dir="1.5_nc",
        hr1h_dir="1.5_1h_nc",
        lr1h_dir="",
        lr6h_degree="5.625",
        hr_degree="1.5",
        include_endpoints=True,
        derive_hr6h_from_hr1h=True,
        derive_lr1h_from_hr1h=True,
    ):
        """Initialize and open every per-variable, per-year dataset.

        Args:
            nc_path: Root directory of the NetCDF tree.
            seed: Passed to ``random.seed`` (reseeds the global RNG).
            startDate: Start bound handed to ``pandas.date_range``.
            endDate: End bound handed to ``pandas.date_range``; the last two
                generated timestamps are dropped from ``self.keys``.
            freq: Step between sample start times.
            surface_prefix: Directory/file prefix of each surface variable
                (index-aligned with ``surface``).
            upper_prefix: Directory/file prefix of each upper-air variable
                (index-aligned with ``upper``).
            surface: Surface variable names inside the files.
            upper: Upper-air variable names inside the files.
            levels: Levels selected from upper-air variables.
            lr6h_dir: Directory of the low-resolution 6-hourly files.
            hr6h_dir: Directory of the high-resolution 6-hourly files.
            hr1h_dir: Directory of the high-resolution hourly files.
            lr1h_dir: Directory of the low-resolution hourly files; an empty
                string means none are available.
            lr6h_degree: Degree tag in low-resolution file names (also used
                for ``lr1h`` files).
            hr_degree: Degree tag in high-resolution file names.
            include_endpoints: Hourly offsets are 0..6 when true, else 1..5.
            derive_hr6h_from_hr1h: Take the 6-hourly HR targets from the
                first and last hourly HR frames instead of separate files.
            derive_lr1h_from_hr1h: When true, ``lr1h`` files are not opened
                and no ``y_lr1h_*`` entries are returned.

        Raises:
            ValueError: If ``derive_hr6h_from_hr1h`` is true while
                ``include_endpoints`` is false.
        """
        self.nc_path = nc_path
        self.surface_prefix = list(surface_prefix)
        self.upper_prefix = list(upper_prefix)
        self.surface_variables = list(surface)
        self.upper_variables = list(upper)
        self.levels = tuple(levels)
        self.include_endpoints = include_endpoints
        self.derive_hr6h_from_hr1h = bool(derive_hr6h_from_hr1h)
        self.derive_lr1h_from_hr1h = bool(derive_lr1h_from_hr1h)
        self.lr1h_available = bool(lr1h_dir) and (not self.derive_lr1h_from_hr1h)

        if self.derive_hr6h_from_hr1h and (not self.include_endpoints):
            raise ValueError(
                "derive_hr6h_from_hr1h=True requires include_endpoints=True to access 6h endpoints."
            )

        self.keys = list(pd.date_range(start=startDate, end=endDate, freq=freq))[:-2]
        self.hr_offsets = list(range(0, 7)) if include_endpoints else list(range(1, 6))

        self.ds_lr6h = {}
        self.ds_hr6h = {}
        self.ds_hr1h = {}
        self.ds_lr1h = {}
        # id(dataset) -> {timestamp in ns: position along ``time``}
        self._time_pos_cache = {}

        # Only open the years that some sample actually reads, so a year
        # boundary ``endDate`` (e.g. the default "2020") does not require
        # files for that final, unused year.
        years = self._required_years(self.keys, pd.Timestamp(endDate).year)

        variable_prefixes = [
            (var, self.surface_prefix[i]) for i, var in enumerate(self.surface_variables)
        ] + [
            (var, self.upper_prefix[i]) for i, var in enumerate(self.upper_variables)
        ]
        for year in years:
            for var, pref in variable_prefixes:
                ds_key = f"{var}_{year}"
                self.ds_lr6h[ds_key] = self._open_year(lr6h_dir, pref, year, lr6h_degree)
                if not self.derive_hr6h_from_hr1h:
                    self.ds_hr6h[ds_key] = self._open_year(hr6h_dir, pref, year, hr_degree)
                self.ds_hr1h[ds_key] = self._open_year(hr1h_dir, pref, year, hr_degree)
                if self.lr1h_available:
                    self.ds_lr1h[ds_key] = self._open_year(lr1h_dir, pref, year, lr6h_degree)

        random.seed(seed)

    @staticmethod
    def _required_years(keys, end_year):
        """Return the years (as strings) touched by any sample window.

        The span runs from the first key's year to the year of the last
        key plus 6 hours, capped at ``end_year``.
        """
        if not keys:
            return []
        last_read = keys[-1] + timedelta(hours=6)
        final_year = min(last_read.year, end_year)
        return [str(year) for year in range(keys[0].year, final_year + 1)]

    def _open_year(self, directory, prefix, year, degree):
        """Open ``<nc_path>/<directory>/<prefix>/<prefix>_<year>_<degree>deg.nc``."""
        return xr.open_dataset(
            os.path.join(self.nc_path, directory, prefix, f"{prefix}_{year}_{degree}deg.nc")
        )

    @staticmethod
    def _lat_name(da):
        """Return the latitude coordinate name (``latitude`` or ``lat``)."""
        for name in ("latitude", "lat"):
            if name in da.coords:
                return name
        raise KeyError("Latitude coordinate not found.")

    @staticmethod
    def _lon_name(da):
        """Return the longitude coordinate name (``longitude`` or ``lon``)."""
        for name in ("longitude", "lon"):
            if name in da.coords:
                return name
        raise KeyError("Longitude coordinate not found.")

    @staticmethod
    def _level_name(da):
        """Return the vertical-level dimension/coordinate name."""
        for name in ("level", "pressure_level", "plev"):
            if name in da.dims or name in da.coords:
                return name
        raise KeyError("Level coordinate not found for upper variable.")

    @staticmethod
    def _year_str(key):
        """Return the four-digit year of ``key`` as a string."""
        return pd.Timestamp(key).strftime("%Y")

    def _dataset_for_time(self, ds_dict, var, key):
        """Return the dataset in ``ds_dict`` holding ``var`` for the year of ``key``.

        Raises:
            KeyError: If no dataset was opened for that variable and year.
        """
        ds_key = f"{var}_{self._year_str(key)}"
        if ds_key not in ds_dict:
            raise KeyError(f"Dataset key {ds_key} is not available.")
        return ds_dict[ds_key]

    def _time_pos(self, ds, key):
        """Return the integer position of ``key`` along the ``time`` index of ``ds``.

        The timestamp -> position table is built once per dataset object and
        cached under ``id(ds)``; this relies on the dataset staying alive.

        Raises:
            KeyError: If ``ds`` has no ``time`` index or ``key`` is not in it.
        """
        ds_id = id(ds)
        table = self._time_pos_cache.get(ds_id)
        if table is None:
            if "time" not in ds.indexes:
                raise KeyError("Dataset does not have `time` index.")
            time_index = pd.DatetimeIndex(ds.indexes["time"])
            table = {int(ts.value): i for i, ts in enumerate(time_index)}
            self._time_pos_cache[ds_id] = table

        pos = table.get(int(pd.Timestamp(key).value))
        if pos is None:
            raise KeyError(f"time={pd.Timestamp(key)} not found in dataset time index.")
        return pos

    def _standardize_lat(self, da):
        """Flip ``da`` along latitude when its latitude values are ascending."""
        lat_name = self._lat_name(da)
        lat = da[lat_name].values
        if lat[0] < lat[-1]:
            da = da.isel({lat_name: slice(None, None, -1)})
        return da

    def _read_surface(self, ds, var, key):
        """Read ``var`` at time ``key`` as a float32 ``(lat, lon)`` array."""
        da = self._standardize_lat(ds[var].isel(time=self._time_pos(ds, key)))
        da = da.transpose(self._lat_name(da), self._lon_name(da))
        return da.values.astype(np.float32)

    def _read_upper(self, ds, var, key):
        """Read ``var`` at time ``key`` as a float32 ``(level, lat, lon)`` array.

        When the level name is a coordinate, only ``self.levels`` are kept
        (in that order).
        """
        da = self._standardize_lat(ds[var].isel(time=self._time_pos(ds, key)))
        lev_name = self._level_name(da)
        if lev_name in da.coords:
            da = da.sel({lev_name: list(self.levels)})
        da = da.transpose(lev_name, self._lat_name(da), self._lon_name(da))
        return da.values.astype(np.float32)

    def _interp_to_lr(self, da_hr, lat_lr, lon_lr):
        """Linearly interpolate ``da_hr`` onto ``lat_lr`` x ``lon_lr``.

        Returns the interpolated DataArray with its latitude and longitude
        coordinate names.
        """
        da_hr = self._standardize_lat(da_hr)
        lat_name = self._lat_name(da_hr)
        lon_name = self._lon_name(da_hr)
        da_lr = da_hr.interp(
            {
                lat_name: xr.DataArray(lat_lr, dims=(lat_name,)),
                lon_name: xr.DataArray(lon_lr, dims=(lon_name,)),
            },
            method="linear",
        )
        return da_lr, lat_name, lon_name

    def _interp_surface_to_lr(self, da_hr, lat_lr, lon_lr):
        """Interpolate a surface field to the LR grid; returns float32 ``(lat, lon)``."""
        da_lr, lat_name, lon_name = self._interp_to_lr(da_hr, lat_lr, lon_lr)
        return da_lr.transpose(lat_name, lon_name).values.astype(np.float32)

    def _interp_upper_to_lr(self, da_hr, lat_lr, lon_lr):
        """Interpolate an upper-air field to the LR grid; returns float32 ``(level, lat, lon)``."""
        lev_name = self._level_name(da_hr)
        da_lr, lat_name, lon_name = self._interp_to_lr(da_hr, lat_lr, lon_lr)
        return da_lr.transpose(lev_name, lat_name, lon_name).values.astype(np.float32)

    def _collect(self, variables, reader, times_6h, times_1h):
        """Read every stream for ``variables`` and stack to ``(time, variable, ...)``.

        ``reader`` is ``_read_surface`` or ``_read_upper``. Returns
        ``(x_lr6h, y_hr6h, y_hr1h, y_lr1h)`` where ``y_lr1h`` is ``None``
        when low-resolution hourly data is not available.
        """
        def frames(ds_dict, var, times):
            return [reader(self._dataset_for_time(ds_dict, var, t), var, t) for t in times]

        x_lr6h = []
        y_hr6h = []
        y_hr1h = []
        y_lr1h = [] if self.lr1h_available else None

        for var in variables:
            lr_stack = frames(self.ds_lr6h, var, times_6h)
            hr1_stack = frames(self.ds_hr1h, var, times_1h)
            if self.derive_hr6h_from_hr1h:
                # Offsets 0 and 6 are the 6-hourly endpoints (guaranteed by
                # the include_endpoints check in __init__).
                hr6_stack = [hr1_stack[0], hr1_stack[-1]]
            else:
                hr6_stack = frames(self.ds_hr6h, var, times_6h)

            x_lr6h.append(np.stack(lr_stack, axis=0))
            y_hr6h.append(np.stack(hr6_stack, axis=0))
            y_hr1h.append(np.stack(hr1_stack, axis=0))
            if y_lr1h is not None:
                y_lr1h.append(np.stack(frames(self.ds_lr1h, var, times_1h), axis=0))

        # The readers already return float32, so stacking keeps the dtype.
        x_lr6h = np.stack(x_lr6h, axis=1)
        y_hr6h = np.stack(y_hr6h, axis=1)
        y_hr1h = np.stack(y_hr1h, axis=1)
        if y_lr1h is not None:
            y_lr1h = np.stack(y_lr1h, axis=1)
        return x_lr6h, y_hr6h, y_hr1h, y_lr1h

    def __getitem__(self, index):
        """Return one sample as a dict of arrays.

        Keys: ``x_lr6h_surface``, ``x_lr6h_upper``, ``y_hr6h_surface``,
        ``y_hr6h_upper``, ``y_hr1h_surface``, ``y_hr1h_upper`` and
        ``time_unix`` (int64 POSIX timestamps of the hourly frames). When
        low-resolution hourly data is available, ``y_lr1h_surface`` and
        ``y_lr1h_upper`` are included as well. Arrays are laid out as
        ``(time, variable, ...)``.
        """
        t0 = self.keys[index]
        times_6h = [t0, t0 + timedelta(hours=6)]
        times_1h = [t0 + timedelta(hours=h) for h in self.hr_offsets]

        x_lr6h_surface, y_hr6h_surface, y_hr1h_surface, y_lr1h_surface = self._collect(
            self.surface_variables, self._read_surface, times_6h, times_1h
        )
        x_lr6h_upper, y_hr6h_upper, y_hr1h_upper, y_lr1h_upper = self._collect(
            self.upper_variables, self._read_upper, times_6h, times_1h
        )

        time_unix = np.asarray([t.timestamp() for t in times_1h], dtype=np.int64)

        out = {
            "x_lr6h_surface": x_lr6h_surface,
            "x_lr6h_upper": x_lr6h_upper,
            "y_hr6h_surface": y_hr6h_surface,
            "y_hr6h_upper": y_hr6h_upper,
            "y_hr1h_surface": y_hr1h_surface,
            "y_hr1h_upper": y_hr1h_upper,
            "time_unix": time_unix,
        }
        if y_lr1h_surface is not None:
            out["y_lr1h_surface"] = y_lr1h_surface
        if y_lr1h_upper is not None:
            out["y_lr1h_upper"] = y_lr1h_upper
        return out

    def __len__(self):
        """Return the number of sample start times."""
        return len(self.keys)

    def __repr__(self):
        """Return the class name."""
        return self.__class__.__name__
|
|