|
|
import os |
|
|
import tempfile |
|
|
import pytest |
|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import xarray as xr |
|
|
from pvnet.data import DataModule, SiteDataModule |
|
|
|
|
|
@pytest.fixture
def temp_pt_sample_dir():
    """Create a temporary directory of synthetic ``.pt`` GSP samples.

    Layout::

        <tmpdir>/train/00000000.pt ... 00000004.pt
        <tmpdir>/val/00000000.pt   ... 00000004.pt

    Each file is a ``torch.save``-d dict with keys:

    - ``gsp``: 21 random float values (seeded, so reproducible across runs)
    - ``gsp_time_utc``: integer tensor ``0..20``
    - ``gsp_nominal_capacity_mwp``: scalar tensor ``100.0``
    - ``gsp_id``: the plain int ``12``

    The same sample object is written to both the train and val split.

    Yields:
        Path (str) of the temporary directory; it is removed after the test.
    """
    # A dedicated seeded generator makes the fixture deterministic without
    # mutating torch's global RNG state, which other tests may depend on.
    generator = torch.Generator().manual_seed(0)

    with tempfile.TemporaryDirectory() as tmpdirname:
        train_dir = os.path.join(tmpdirname, "train")
        val_dir = os.path.join(tmpdirname, "val")
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        for i in range(5):
            sample = {
                "gsp": torch.rand(21, generator=generator),
                "gsp_time_utc": torch.arange(21),
                "gsp_nominal_capacity_mwp": torch.tensor(100.0),
                "gsp_id": 12,
            }
            # Zero-padded names keep lexicographic order equal to index order.
            filename = f"{i:08d}.pt"
            torch.save(sample, os.path.join(train_dir, filename))
            torch.save(sample, os.path.join(val_dir, filename))

        yield tmpdirname
|
|
|
|
|
|
|
|
@pytest.fixture
def temp_nc_sample_dir():
    """Create a temporary directory of synthetic NetCDF site samples.

    Layout::

        <tmpdir>/data_configuration.yaml
        <tmpdir>/train/00000000.nc ... 00000004.nc
        <tmpdir>/val/00000000.nc   ... 00000004.nc

    Each ``.nc`` file is an ``xarray.Dataset`` with one data variable
    ``site`` (10 random values, seeded for reproducibility) on the
    ``site__time_utc`` dimension (10 steps at 15-minute spacing from
    2023-01-01), plus scalar coords ``site__site_id`` (cycling 1, 2, 3),
    ``site__latitude``, ``site__longitude`` and ``site__capacity_kwp``.
    Files are written with the ``h5netcdf`` engine.

    ``data_configuration.yaml`` contains only a ``sample_dir`` entry pointing
    at the temporary directory.
    # NOTE(review): that is likely not a complete pvnet data configuration;
    # it only needs to be loadable by the SiteDataModule constructor. Confirm.

    Yields:
        Path (str) of the temporary directory; it is removed after the test.
    """
    # Seeded local RNG: reproducible data without touching numpy's global state.
    rng = np.random.default_rng(0)
    # Every sample shares the same time axis, so build it once outside the loop.
    site_time = pd.date_range("2023-01-01", periods=10, freq="15min")

    with tempfile.TemporaryDirectory() as tmpdirname:
        train_dir = os.path.join(tmpdirname, "train")
        val_dir = os.path.join(tmpdirname, "val")
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        config_path = os.path.join(tmpdirname, "data_configuration.yaml")
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(f"sample_dir: {tmpdirname}\n")

        for i in range(5):
            ds = xr.Dataset(
                data_vars={
                    "site": (["site__time_utc"], rng.random(10)),
                },
                coords={
                    "site__time_utc": site_time,
                    "site__site_id": np.int32(i % 3 + 1),
                    "site__latitude": 52.5,
                    "site__longitude": -1.5,
                    "site__capacity_kwp": 10000.0,
                },
            )

            filename = f"{i:08d}.nc"
            ds.to_netcdf(os.path.join(train_dir, filename), mode="w", engine="h5netcdf")
            ds.to_netcdf(os.path.join(val_dir, filename), mode="w", engine="h5netcdf")

        yield tmpdirname
|
|
|
|
|
|
|
|
def test_init(temp_pt_sample_dir):
    """Test DataModule initialization.

    A DataModule pointed at a directory of ``.pt`` samples (no configuration
    file, unrestricted train/val periods) should construct successfully and
    expose a ``train_dataloader`` attribute.
    """
    module_kwargs = dict(
        configuration=None,
        sample_dir=temp_pt_sample_dir,
        batch_size=2,
        num_workers=0,
        prefetch_factor=None,
        train_period=[None, None],
        val_period=[None, None],
    )
    datamodule = DataModule(**module_kwargs)

    assert datamodule is not None
    assert hasattr(datamodule, "train_dataloader")
|
|
|
|
|
|
|
|
def test_iter(temp_pt_sample_dir):
    """Test iteration through DataModule.

    Builds a single-process DataModule over the ``.pt`` samples, pulls the
    first batch from its training dataloader and checks it carries ``gsp``.
    """
    datamodule = DataModule(
        configuration=None,
        sample_dir=temp_pt_sample_dir,
        batch_size=2,
        num_workers=0,
        prefetch_factor=None,
        train_period=[None, None],
        val_period=[None, None],
    )

    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))

    assert first_batch is not None
    assert "gsp" in first_batch
|
|
|
|
|
|
|
|
def test_iter_multiprocessing(temp_pt_sample_dir):
    """Test DataModule with multiple workers.

    Uses two worker processes with batch size 1 and checks that at least two
    batches can be drawn from the training dataloader.
    """
    datamodule = DataModule(
        configuration=None,
        sample_dir=temp_pt_sample_dir,
        batch_size=1,
        num_workers=2,
        prefetch_factor=1,
        train_period=[None, None],
        val_period=[None, None],
    )

    target_batches = 2
    served_batches = 0
    for _ in datamodule.train_dataloader():
        served_batches += 1
        if served_batches == target_batches:
            # Stop as soon as enough batches have been produced.
            break

    assert served_batches == target_batches
|
|
|
|
|
|
|
|
def test_site_init_sample_dir(temp_nc_sample_dir):
    """Test SiteDataModule initialization with sample dir.

    A SiteDataModule pointed directly at the NetCDF sample directory (no
    configuration file) should construct and expose ``train_dataloader``.
    """
    module_kwargs = dict(
        configuration=None,
        sample_dir=temp_nc_sample_dir,
        batch_size=2,
        num_workers=0,
        prefetch_factor=None,
        train_period=[None, None],
        val_period=[None, None],
    )
    site_datamodule = SiteDataModule(**module_kwargs)

    assert site_datamodule is not None
    assert hasattr(site_datamodule, "train_dataloader")
|
|
|
|
|
|
|
|
def test_site_init_config(temp_nc_sample_dir):
    """Test SiteDataModule initialization with config file.

    Passes the ``data_configuration.yaml`` written by the fixture as the
    configuration (with ``sample_dir=None``) and checks the module constructs
    and exposes ``train_dataloader``.
    """
    config_file = f"{temp_nc_sample_dir}/data_configuration.yaml"

    site_datamodule = SiteDataModule(
        configuration=config_file,
        sample_dir=None,
        batch_size=2,
        num_workers=0,
        prefetch_factor=None,
        train_period=[None, None],
        val_period=[None, None],
    )

    assert site_datamodule is not None
    assert hasattr(site_datamodule, "train_dataloader")
|
|
|