import os import tempfile from datetime import timedelta import pytest import pandas as pd import numpy as np import xarray as xr import torch import hydra from pvnet.data import DataModule, SiteDataModule import pvnet.models.multimodal.encoders.encoders3d import pvnet.models.multimodal.linear_networks.networks import pvnet.models.multimodal.site_encoders.encoders from pvnet.models.multimodal.multimodal import Model xr.set_options(keep_attrs=True) def time_before_present(dt: timedelta): return pd.Timestamp.now(tz=None) - dt @pytest.fixture def nwp_data(): # Load dataset which only contains coordinates, but no data ds = xr.open_zarr( f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_data/nwp_shell.zarr" ) # Last init time was at least 2 hours ago and hour to 3-hour interval t0_datetime_utc = time_before_present(timedelta(hours=2)).floor(timedelta(hours=3)) ds.init_time.values[:] = pd.date_range( t0_datetime_utc - timedelta(hours=3 * (len(ds.init_time) - 1)), t0_datetime_utc, freq=timedelta(hours=3), ) # This is important to avoid saving errors for v in list(ds.coords.keys()): if ds.coords[v].dtype == object: ds[v].encoding.clear() for v in list(ds.variables.keys()): if ds[v].dtype == object: ds[v].encoding.clear() # Add data to dataset ds["UKV"] = xr.DataArray( np.zeros([len(ds[c]) for c in ds.coords]), coords=ds.coords, ) # Add stored attributes to DataArray ds.UKV.attrs = ds.attrs["_data_attrs"] del ds.attrs["_data_attrs"] return ds @pytest.fixture() def sat_data(): # Load dataset which only contains coordinates, but no data ds = xr.open_zarr( f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_data/non_hrv_shell.zarr" ) # Change times so they lead up to present. Delayed by at most 1 hour t0_datetime_utc = time_before_present(timedelta(minutes=0)).floor(timedelta(minutes=30)) t0_datetime_utc = t0_datetime_utc - timedelta(minutes=30) ds.time.values[:] = pd.date_range( t0_datetime_utc - timedelta(minutes=5 * (len(ds.time) - 1)), t0_datetime_utc, freq=timedelta(minutes=5), ) # Add data to dataset ds["data"] = xr.DataArray( np.zeros([len(ds[c]) for c in ds.coords]), coords=ds.coords, ) # Add stored attributes to DataArray ds.data.attrs = ds.attrs["_data_attrs"] del ds.attrs["_data_attrs"] return ds def generate_synthetic_sample(): """ Generate synthetic sample for testing """ now = pd.Timestamp.now(tz=None) sample = {} # NWP define sample["nwp"] = { "ukv": { "nwp": torch.rand(11, 11, 24, 24), "nwp_init_time_utc": torch.tensor( [(now - pd.Timedelta(hours=i)).timestamp() for i in range(11)] ), "nwp_step": torch.arange(11, dtype=torch.float32), "nwp_target_time_utc": torch.tensor( [(now + pd.Timedelta(hours=i)).timestamp() for i in range(11)] ), "nwp_y_osgb": torch.linspace(0, 100, 24), "nwp_x_osgb": torch.linspace(0, 100, 24), }, "ecmwf": { "nwp": torch.rand(11, 12, 12, 12), "nwp_init_time_utc": torch.tensor( [(now - pd.Timedelta(hours=i)).timestamp() for i in range(11)] ), "nwp_step": torch.arange(11, dtype=torch.float32), "nwp_target_time_utc": torch.tensor( [(now + pd.Timedelta(hours=i)).timestamp() for i in range(11)] ), }, "sat_pred": { "nwp": torch.rand(12, 11, 24, 24), "nwp_init_time_utc": torch.tensor( [(now - pd.Timedelta(hours=i)).timestamp() for i in range(12)] ), "nwp_step": torch.arange(12, dtype=torch.float32), "nwp_target_time_utc": torch.tensor( [(now + pd.Timedelta(hours=i)).timestamp() for i in range(12)] ), }, } # Satellite define sample["satellite_actual"] = torch.rand(7, 11, 24, 24) sample["satellite_time_utc"] = torch.tensor( [(now - pd.Timedelta(minutes=5*i)).timestamp() for i in range(7)] ) sample["satellite_x_geostationary"] = torch.linspace(0, 100, 24) sample["satellite_y_geostationary"] = torch.linspace(0, 100, 24) # GSP define sample["gsp"] = torch.rand(21) sample["gsp_nominal_capacity_mwp"] = torch.tensor(100.0) sample["gsp_effective_capacity_mwp"] = torch.tensor(85.0) sample["gsp_time_utc"] = torch.tensor( [(now + pd.Timedelta(minutes=30*i)).timestamp() for i in range(21)] ) sample["gsp_t0_idx"] = float(7) sample["gsp_id"] = 12 sample["gsp_x_osgb"] = 123456.0 sample["gsp_y_osgb"] = 654321.0 # Solar position define sample["solar_azimuth"] = torch.linspace(0, 180, 21) sample["solar_elevation"] = torch.linspace(-10, 60, 21) return sample def generate_synthetic_site_sample(site_id=1, variation_index=0, add_noise=True): """ Generate synthetic site sample that matches site sample structure Args: site_id: ID for the site variation_index: Index to use for coordinate variations add_noise: Whether to add random noise to data variables """ now = pd.Timestamp.now(tz=None) # Create time and space coordinates site_time_coords = pd.date_range(start=now - pd.Timedelta(hours=48), periods=197, freq="15min") nwp_time_coords = pd.date_range(start=now, periods=50, freq="1h") nwp_lat = np.linspace(50.0, 60.0, 24) nwp_lon = np.linspace(-10.0, 2.0, 24) nwp_channels = np.array(['t2m', 'ssrd', 'ssr', 'sp', 'r', 'tcc', 'u10', 'v10'], dtype='