|
|
import os |
|
|
import tempfile |
|
|
from datetime import timedelta |
|
|
|
|
|
import pytest |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import xarray as xr |
|
|
import torch |
|
|
import hydra |
|
|
|
|
|
from pvnet.data import DataModule, SiteDataModule |
|
|
import pvnet.models.multimodal.encoders.encoders3d |
|
|
import pvnet.models.multimodal.linear_networks.networks |
|
|
import pvnet.models.multimodal.site_encoders.encoders |
|
|
from pvnet.models.multimodal.multimodal import Model |
|
|
|
|
|
|
|
|
# Keep xarray attrs through operations; the fixtures below move "_data_attrs"
# between dataset- and variable-level attrs and rely on attrs surviving.
xr.set_options(keep_attrs=True)
|
|
|
|
|
|
|
|
def time_before_present(dt: timedelta):
    """Return the timezone-naive timestamp that is `dt` before the current time."""
    current_time = pd.Timestamp.now(tz=None)
    return current_time - dt
|
|
|
|
|
|
|
|
@pytest.fixture
def nwp_data():
    """Load the UKV NWP shell zarr and fill it with synthetic, recent data.

    Returns:
        xr.Dataset: dataset with a zero-filled "UKV" variable whose init times
        end at the most recent 3-hourly interval at least 2 hours before now.
    """
    # Shell zarr holds coordinates/attrs only; data is filled in below.
    ds = xr.open_zarr(
        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_data/nwp_shell.zarr"
    )

    # Shift init times so the latest one is >= 2 hours old, floored to the
    # 3-hourly NWP cadence, with earlier init times spaced 3 hours apart.
    t0_datetime_utc = time_before_present(timedelta(hours=2)).floor(timedelta(hours=3))
    ds.init_time.values[:] = pd.date_range(
        t0_datetime_utc - timedelta(hours=3 * (len(ds.init_time) - 1)),
        t0_datetime_utc,
        freq=timedelta(hours=3),
    )

    # Clear stale serialization encodings on object-dtype coordinates.
    for v in list(ds.coords.keys()):
        if ds.coords[v].dtype == object:
            ds[v].encoding.clear()

    # NOTE(review): ds.variables includes the coords, so this loop also covers
    # the coordinates cleared above — the first loop looks redundant; confirm.
    for v in list(ds.variables.keys()):
        if ds[v].dtype == object:
            ds[v].encoding.clear()

    # Add a zero-valued data variable spanning every coordinate dimension.
    # NOTE(review): assumes the iteration order of ds.coords matches the
    # intended dim order of "UKV" — confirm against the shell zarr.
    ds["UKV"] = xr.DataArray(
        np.zeros([len(ds[c]) for c in ds.coords]),
        coords=ds.coords,
    )

    # Move the stashed per-variable attrs onto the new variable.
    ds.UKV.attrs = ds.attrs["_data_attrs"]
    del ds.attrs["_data_attrs"]

    return ds
|
|
|
|
|
|
|
|
@pytest.fixture()
def sat_data():
    """Load the non-HRV satellite shell zarr and fill it with synthetic data.

    Returns:
        xr.Dataset: dataset with a zero-filled "data" variable whose time axis
        ends 30 minutes before the most recent half-hour boundary, at 5-minute
        resolution.
    """
    # Shell zarr holds coordinates/attrs only; data is filled in below.
    ds = xr.open_zarr(
        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_data/non_hrv_shell.zarr"
    )

    # Floor "now" to the half-hour, then step back one more 30-minute interval
    # so the latest satellite frame is slightly delayed relative to t0.
    t0_datetime_utc = time_before_present(timedelta(minutes=0)).floor(timedelta(minutes=30))
    t0_datetime_utc = t0_datetime_utc - timedelta(minutes=30)
    ds.time.values[:] = pd.date_range(
        t0_datetime_utc - timedelta(minutes=5 * (len(ds.time) - 1)),
        t0_datetime_utc,
        freq=timedelta(minutes=5),
    )

    # Add a zero-valued data variable spanning every coordinate dimension.
    # NOTE(review): assumes the iteration order of ds.coords matches the
    # intended dim order of "data" — confirm against the shell zarr.
    ds["data"] = xr.DataArray(
        np.zeros([len(ds[c]) for c in ds.coords]),
        coords=ds.coords,
    )

    # Move the stashed per-variable attrs onto the new variable.
    ds.data.attrs = ds.attrs["_data_attrs"]
    del ds.attrs["_data_attrs"]

    return ds
|
|
|
|
|
|
|
|
def generate_synthetic_sample():
    """
    Generate synthetic sample for testing.

    Builds a dict mirroring the structure of a real GSP sample: per-provider
    NWP sub-dicts, satellite imagery, GSP values/metadata and solar features.
    """
    t_ref = pd.Timestamp.now(tz=None)

    def _timestamps(n_steps, step, future):
        # Sequence of epoch timestamps offset from t_ref by i*step, i=0..n-1,
        # going forward (future=True) or backward in time.
        if future:
            values = [(t_ref + step * i).timestamp() for i in range(n_steps)]
        else:
            values = [(t_ref - step * i).timestamp() for i in range(n_steps)]
        return torch.tensor(values)

    def _nwp_source(data_shape, n_steps):
        # One NWP provider entry: random data cube plus hourly init/target
        # times and float step indices.
        hour = pd.Timedelta(hours=1)
        return {
            "nwp": torch.rand(*data_shape),
            "nwp_init_time_utc": _timestamps(n_steps, hour, future=False),
            "nwp_step": torch.arange(n_steps, dtype=torch.float32),
            "nwp_target_time_utc": _timestamps(n_steps, hour, future=True),
        }

    # UKV additionally carries OSGB spatial coordinates.
    ukv_entry = _nwp_source((11, 11, 24, 24), 11)
    ukv_entry["nwp_y_osgb"] = torch.linspace(0, 100, 24)
    ukv_entry["nwp_x_osgb"] = torch.linspace(0, 100, 24)

    sample = {
        "nwp": {
            "ukv": ukv_entry,
            "ecmwf": _nwp_source((11, 12, 12, 12), 11),
            "sat_pred": _nwp_source((12, 11, 24, 24), 12),
        },
        # Satellite history: 7 frames at 5-minute spacing, 11 channels, 24x24.
        "satellite_actual": torch.rand(7, 11, 24, 24),
        "satellite_time_utc": _timestamps(7, pd.Timedelta(minutes=5), future=False),
        "satellite_x_geostationary": torch.linspace(0, 100, 24),
        "satellite_y_geostationary": torch.linspace(0, 100, 24),
        # GSP target series: 21 half-hourly values plus capacities and location.
        "gsp": torch.rand(21),
        "gsp_nominal_capacity_mwp": torch.tensor(100.0),
        "gsp_effective_capacity_mwp": torch.tensor(85.0),
        "gsp_time_utc": _timestamps(21, pd.Timedelta(minutes=30), future=True),
        "gsp_t0_idx": float(7),
        "gsp_id": 12,
        "gsp_x_osgb": 123456.0,
        "gsp_y_osgb": 654321.0,
        # Simple synthetic solar-position features aligned with the GSP series.
        "solar_azimuth": torch.linspace(0, 180, 21),
        "solar_elevation": torch.linspace(-10, 60, 21),
    }
    return sample
|
|
|
|
|
|
|
|
def generate_synthetic_site_sample(site_id=1, variation_index=0, add_noise=True):
    """
    Generate synthetic site sample that matches site sample structure

    Args:
        site_id: ID for the site
        variation_index: Index to use for coordinate variations
        add_noise: Whether to add random noise to data variables

    Returns:
        xr.Dataset with "site" and "nwp-ecmwf" data variables using the
        "<source>__<name>" coordinate naming convention.
    """
    now = pd.Timestamp.now(tz=None)

    # Site history: 48h at 15-minute resolution (197 points). NWP: 50 hourly
    # target times on a 24x24 lat/lon grid with 8 ECMWF channels.
    site_time_coords = pd.date_range(start=now - pd.Timedelta(hours=48), periods=197, freq="15min")
    nwp_time_coords = pd.date_range(start=now, periods=50, freq="1h")
    nwp_lat = np.linspace(50.0, 60.0, 24)
    nwp_lon = np.linspace(-10.0, 2.0, 24)
    nwp_channels = np.array(['t2m', 'ssrd', 'ssr', 'sp', 'r', 'tcc', 'u10', 'v10'], dtype='<U5')

    # Single 12h-old init time repeated across all 50 target times, with
    # hourly forecast steps.
    nwp_init_time = pd.date_range(start=now - pd.Timedelta(hours=12), periods=1, freq="12h").repeat(50)
    nwp_steps = pd.timedelta_range(start=pd.Timedelta(hours=0), periods=50, freq="1h")
    nwp_data = np.random.randn(50, 8, 24, 24).astype(np.float32)

    # Per-site scalars offset by variation_index so different samples differ.
    site_data = np.random.rand(197)
    site_lat = 52.5 + variation_index * 0.1
    site_lon = -1.5 - variation_index * 0.05
    site_capacity = 10000.0 * (1.0 + variation_index * 0.01)

    # Fractional day-of-year / time-of-day, inputs to the cyclic features below.
    days_since_jan1 = (site_time_coords.dayofyear - 1) / 365.0
    hours_since_midnight = (site_time_coords.hour + site_time_coords.minute / 60.0) / 24.0

    # Simple synthetic solar geometry over the time window (not astronomically
    # accurate; shapes/ranges only).
    site_solar_azimuth = np.linspace(0, 360, 197)
    site_solar_elevation = 15 * np.sin(np.linspace(0, 2*np.pi, 197))
    trig_features = {
        "date_sin": np.sin(2 * np.pi * days_since_jan1),
        "date_cos": np.cos(2 * np.pi * days_since_jan1),
        "time_sin": np.sin(2 * np.pi * hours_since_midnight),
        "time_cos": np.cos(2 * np.pi * hours_since_midnight),
    }

    # Assemble the dataset; note nwp-ecmwf data dims are ordered
    # (target_time, channel, longitude, latitude).
    site_data_ds = xr.Dataset(
        data_vars={
            "nwp-ecmwf": (["nwp-ecmwf__target_time_utc", "nwp-ecmwf__channel",
                           "nwp-ecmwf__longitude", "nwp-ecmwf__latitude"], nwp_data),
            "site": (["site__time_utc"], site_data),
        },
        coords={
            # NWP coordinates
            "nwp-ecmwf__latitude": nwp_lat,
            "nwp-ecmwf__longitude": nwp_lon,
            "nwp-ecmwf__channel": nwp_channels,
            "nwp-ecmwf__target_time_utc": nwp_time_coords,
            "nwp-ecmwf__init_time_utc": (["nwp-ecmwf__target_time_utc"], nwp_init_time),
            "nwp-ecmwf__step": (["nwp-ecmwf__target_time_utc"], nwp_steps),

            # Site metadata (scalars) and time-dependent features
            "site__site_id": np.int32(site_id),
            "site__latitude": site_lat,
            "site__longitude": site_lon,
            "site__capacity_kwp": site_capacity,
            "site__time_utc": site_time_coords,
            "site__solar_azimuth": (["site__time_utc"], site_solar_azimuth),
            "site__solar_elevation": (["site__time_utc"], site_solar_elevation),
            **{f"site__{k}": (["site__time_utc"], v) for k, v in trig_features.items()}
        }
    )

    # ECMWF-style GRIB metadata on the NWP variable.
    site_data_ds["nwp-ecmwf"].attrs = {
        "Conventions": "CF-1.7",
        "GRIB_centre": "ecmf",
        "GRIB_centreDescription": "European Centre for Medium-Range Weather Forecasts",
        "GRIB_subCentre": "0",
        "institution": "European Centre for Medium-Range Weather Forecasts"
    }

    # Optionally perturb both data variables with small gaussian noise,
    # preserving each variable's dtype.
    if add_noise:
        for var in ["site", "nwp-ecmwf"]:
            noise_shape = site_data_ds[var].shape
            noise = np.random.randn(*noise_shape).astype(site_data_ds[var].dtype) * 0.01
            site_data_ds[var] = site_data_ds[var] + noise

    return site_data_ds
|
|
|
|
|
|
|
|
def generate_synthetic_pv_batch():
    """
    Generate a synthetic PV batch for SimpleLearnedAggregator tests.

    Returns a uniform-random tensor of shape (batch, sequence, sites):
    8 examples, 180 minutes of history at 5-minute resolution (inclusive
    of t0, so 37 steps) for 349 PV sites.
    """
    n_examples = 8
    n_timesteps = 180 // 5 + 1
    n_pv_sites = 349
    return torch.rand(n_examples, n_timesteps, n_pv_sites)
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_train_val_datamodule():
    """
    Create a DataModule with synthetic data files for training and validation
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # DataModule reads samples from <sample_dir>/train and <sample_dir>/val
        os.makedirs(f"{tmpdirname}/train", exist_ok=True)
        os.makedirs(f"{tmpdirname}/val", exist_ok=True)

        # Write 10 synthetic samples; the same sample goes to both splits
        for i in range(10):
            sample = generate_synthetic_sample()
            torch.save(sample, f"{tmpdirname}/train/{i:08d}.pt")
            torch.save(sample, f"{tmpdirname}/val/{i:08d}.pt")

        dm = DataModule(
            configuration=None,
            sample_dir=tmpdirname,
            batch_size=2,
            num_workers=0,
            prefetch_factor=None,
            train_period=[None, None],
            val_period=[None, None],
        )

        # Yield (not return) so the temporary directory survives for the test
        yield dm
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_datamodule(sample_train_val_datamodule):
    """Alias fixture for sample_train_val_datamodule under a shorter name."""
    yield sample_train_val_datamodule
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_site_datamodule():
    """
    Create a SiteDataModule with synthetic site data in netCDF format
    that matches the structure of the actual site samples
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # SiteDataModule reads from <sample_dir>/train and <sample_dir>/val
        os.makedirs(f"{tmpdirname}/train", exist_ok=True)
        os.makedirs(f"{tmpdirname}/val", exist_ok=True)

        # 10 samples cycling through site IDs 1-3, each with slightly varied
        # coordinates/capacity via variation_index
        for i in range(10):
            site_data = generate_synthetic_site_sample(
                site_id=i % 3 + 1,
                variation_index=i,
                add_noise=True
            )

            # Same sample written to both splits
            for subset in ["train", "val"]:
                file_path = f"{tmpdirname}/{subset}/{i:08d}.nc"
                site_data.to_netcdf(file_path, mode="w", engine="h5netcdf")

        dm = SiteDataModule(
            configuration=None,
            sample_dir=tmpdirname,
            batch_size=2,
            num_workers=0,
            prefetch_factor=None,
            train_period=[None, None],
            val_period=[None, None],
        )

        # Yield (not return) so the temporary directory survives for the test
        yield dm
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_batch(sample_datamodule):
    """First training batch (batch_size=2) from the synthetic datamodule."""
    batch = next(iter(sample_datamodule.train_dataloader()))
    return batch
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_satellite_batch(sample_batch):
    """Satellite tensor from the sample batch with axes 1 and 2 swapped.

    Presumably reorders (batch, time, channel, y, x) into the
    (batch, channel, time, y, x) layout the encoders expect — TODO confirm.
    """
    satellite_sequence = sample_batch["satellite_actual"]
    return satellite_sequence.transpose(1, 2)
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_pv_batch():
    """
    Create a batch of PV site data for testing site encoder models
    """
    pv_tensor = generate_synthetic_pv_batch()

    # One random GSP id (0-9) per example in the batch
    n_examples = pv_tensor.shape[0]
    gsp_ids = torch.randint(low=0, high=10, size=(n_examples,))

    return {"pv": pv_tensor, "gsp_id": gsp_ids}
|
|
|
|
|
|
|
|
@pytest.fixture()
def sample_site_batch(sample_site_datamodule):
    """First training batch (batch_size=2) from the synthetic site datamodule."""
    batch = next(iter(sample_site_datamodule.train_dataloader()))
    return batch
|
|
|
|
|
|
|
|
@pytest.fixture()
def model_minutes_kwargs():
    """Forecast/history horizon kwargs shared by the model fixtures."""
    return {
        "forecast_minutes": 480,
        "history_minutes": 120,
    }
|
|
|
|
|
|
|
|
@pytest.fixture()
def encoder_model_kwargs():
    """Kwargs for image-sequence encoder models.

    Dimensions match the synthetic satellite data: 7 timesteps, 11 channels,
    24x24 pixels.
    """
    return {
        "sequence_length": 7,
        "image_size_pixels": 24,
        "in_channels": 11,
        "out_features": 128,
    }
|
|
|
|
|
|
|
|
@pytest.fixture()
def site_encoder_model_kwargs():
    """Used to test site encoder model on PV data"""
    # Matches generate_synthetic_pv_batch: 37 timesteps (180min @ 5min,
    # inclusive of t0) across 349 sites.
    kwargs = {
        "sequence_length": 180 // 5 + 1,
        "num_sites": 349,
        "out_features": 128,
    }
    return kwargs
|
|
|
|
|
|
|
|
@pytest.fixture()
def site_encoder_model_kwargs_dsampler():
    """Used to test site encoder model on PV data with data sampler"""
    # Single site, 5 timesteps (60min @ 15min, inclusive of t0)
    kwargs = {
        "sequence_length": 60 // 15 + 1,
        "num_sites": 1,
        "out_features": 128,
        "target_key_to_use": "site",
    }
    return kwargs
|
|
|
|
|
|
|
|
@pytest.fixture()
def site_encoder_sensor_model_kwargs():
    """Used to test site encoder model for sensor data"""
    kwargs = {
        "sequence_length": 180 // 5 + 1,
        "num_sites": 26,
        "out_features": 128,
        "num_channels": 23,
        "target_key_to_use": "wind",
        "input_key_to_use": "sensor",
    }
    return kwargs
|
|
|
|
|
|
|
|
@pytest.fixture()
def raw_multimodal_model_kwargs(model_minutes_kwargs):
    """Un-instantiated (hydra-style dict) kwargs for the multimodal Model.

    The ``_target_``/``_partial_`` entries are left as dicts here and are
    resolved into callables with ``hydra.utils.instantiate`` elsewhere.
    """
    kwargs = dict(
        # Satellite encoder sized for the synthetic data: 11 channels, 24x24
        sat_encoder=dict(
            _target_="pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet",
            _partial_=True,
            in_channels=11,
            out_features=128,
            number_of_conv3d_layers=6,
            conv3d_channels=32,
            image_size_pixels=24,
        ),
        # One NWP encoder per provider; only UKV is configured here
        nwp_encoders_dict={
            "ukv": dict(
                _target_="pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet",
                _partial_=True,
                in_channels=11,
                out_features=128,
                number_of_conv3d_layers=6,
                conv3d_channels=32,
                image_size_pixels=24,
            ),
        },
        add_image_embedding_channel=True,

        # No PV-site encoder in this configuration
        pv_encoder=None,
        output_network=dict(
            _target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
            _partial_=True,
            fc_hidden_features=128,
            n_res_blocks=6,
            res_block_layers=2,
            dropout_frac=0.0,
        ),
        # Identity mapping over location IDs 1..317
        location_id_mapping={i:i for i in range(1, 318)},
        embedding_dim=16,
        include_sun=True,
        include_gsp_yield_history=True,
        sat_history_minutes=30,
        nwp_history_minutes={"ukv": 120},
        nwp_forecast_minutes={"ukv": 480},
        min_sat_delay_minutes=0,
    )

    # Merge in forecast_minutes / history_minutes
    kwargs.update(model_minutes_kwargs)

    return kwargs
|
|
|
|
|
|
|
|
@pytest.fixture()
def multimodal_model_kwargs(raw_multimodal_model_kwargs):
    """Resolve the hydra ``_target_`` entries of the raw kwargs into objects."""
    return hydra.utils.instantiate(raw_multimodal_model_kwargs)
|
|
|
|
|
|
|
|
@pytest.fixture()
def multimodal_model(multimodal_model_kwargs):
    """Multimodal Model built from the instantiated kwargs."""
    return Model(**multimodal_model_kwargs)
|
|
|
|
|
@pytest.fixture()
def raw_multimodal_model_kwargs_site_history(model_minutes_kwargs):
    """Un-instantiated kwargs for a minimal Model using only site yield history.

    All image/NWP/PV encoders and the sun/GSP-history inputs are disabled;
    only the output network remains.
    """
    kwargs = dict(
        # No satellite or NWP inputs in this configuration
        sat_encoder=None,
        nwp_encoders_dict=None,
        add_image_embedding_channel=False,
        pv_encoder=None,
        output_network=dict(
            _target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
            _partial_=True,
            fc_hidden_features=128,
            n_res_blocks=6,
            res_block_layers=2,
            dropout_frac=0.0,
        ),
        location_id_mapping=None,
        embedding_dim=None,
        include_sun=False,
        include_gsp_yield_history=False,
        # Feed the site's own generation history as an input instead
        include_site_yield_history=True
    )

    # Merge in forecast_minutes / history_minutes
    kwargs.update(model_minutes_kwargs)

    return kwargs
|
|
|
|
|
|
|
|
@pytest.fixture()
def multimodal_model_kwargs_site_history(raw_multimodal_model_kwargs_site_history):
    """Resolve the hydra ``_target_`` entries of the site-history kwargs."""
    return hydra.utils.instantiate(raw_multimodal_model_kwargs_site_history)
|
|
|
|
|
|
|
|
@pytest.fixture()
def multimodal_model_site_history(multimodal_model_kwargs_site_history):
    """Minimal Model using only site yield history as input."""
    return Model(**multimodal_model_kwargs_site_history)
|
|
|
|
|
|
|
|
@pytest.fixture()
def multimodal_quantile_model(multimodal_model_kwargs):
    """Multimodal model configured to output the 10/50/90 percent quantiles."""
    quantiles = [0.1, 0.5, 0.9]
    return Model(output_quantiles=quantiles, **multimodal_model_kwargs)
|
|
|
|
|
|
|
|
@pytest.fixture()
def multimodal_quantile_model_ignore_minutes(multimodal_model_kwargs):
    """Only forecast second half of the 8 hours.

    Quantile model that ignores the first 240 minutes of the forecast horizon.
    """
    model = Model(
        output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, forecast_minutes_ignore=240
    )
    return model
|
|
|