| import os |
| import tempfile |
| import pytest |
| import torch |
| import numpy as np |
| import pandas as pd |
| import xarray as xr |
| from pvnet.data import DataModule, SiteDataModule |
|
|
| @pytest.fixture |
| def temp_pt_sample_dir(): |
| """Create temporary directory with synthetic PT samples""" |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| |
| os.makedirs(f"{tmpdirname}/train", exist_ok=True) |
| os.makedirs(f"{tmpdirname}/val", exist_ok=True) |
|
|
| |
| for i in range(5): |
| sample = { |
| "gsp": torch.rand(21), |
| "gsp_time_utc": torch.tensor(list(range(21))), |
| "gsp_nominal_capacity_mwp": torch.tensor(100.0), |
| "gsp_id": 12 |
| } |
| torch.save(sample, f"{tmpdirname}/train/{i:08d}.pt") |
| torch.save(sample, f"{tmpdirname}/val/{i:08d}.pt") |
|
|
| yield tmpdirname |
|
|
|
|
| @pytest.fixture |
| def temp_nc_sample_dir(): |
| """Create temporary directory with synthetic NC site samples""" |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| |
| os.makedirs(f"{tmpdirname}/train", exist_ok=True) |
| os.makedirs(f"{tmpdirname}/val", exist_ok=True) |
|
|
| |
| config_path = f"{tmpdirname}/data_configuration.yaml" |
| with open(config_path, "w") as f: |
| f.write(f"sample_dir: {tmpdirname}\n") |
|
|
| |
| for i in range(5): |
| site_time = pd.date_range("2023-01-01", periods=10, freq="15min") |
| ds = xr.Dataset( |
| data_vars={ |
| "site": (["site__time_utc"], np.random.rand(10)), |
| }, |
| coords={ |
| "site__time_utc": site_time, |
| "site__site_id": np.int32(i % 3 + 1), |
| "site__latitude": 52.5, |
| "site__longitude": -1.5, |
| "site__capacity_kwp": 10000.0, |
| } |
| ) |
|
|
| ds.to_netcdf(f"{tmpdirname}/train/{i:08d}.nc", mode="w", engine="h5netcdf") |
| ds.to_netcdf(f"{tmpdirname}/val/{i:08d}.nc", mode="w", engine="h5netcdf") |
|
|
| yield tmpdirname |
|
|
|
|
| def test_init(temp_pt_sample_dir): |
| """Test DataModule initialization""" |
| dm = DataModule( |
| configuration=None, |
| sample_dir=temp_pt_sample_dir, |
| batch_size=2, |
| num_workers=0, |
| prefetch_factor=None, |
| train_period=[None, None], |
| val_period=[None, None], |
| ) |
|
|
| |
| assert dm is not None |
| assert hasattr(dm, "train_dataloader") |
|
|
|
|
| def test_iter(temp_pt_sample_dir): |
| """Test iteration through DataModule""" |
| dm = DataModule( |
| configuration=None, |
| sample_dir=temp_pt_sample_dir, |
| batch_size=2, |
| num_workers=0, |
| prefetch_factor=None, |
| train_period=[None, None], |
| val_period=[None, None], |
| ) |
|
|
| |
| batch = next(iter(dm.train_dataloader())) |
| assert batch is not None |
| assert "gsp" in batch |
|
|
|
|
| def test_iter_multiprocessing(temp_pt_sample_dir): |
| """Test DataModule with multiple workers""" |
| dm = DataModule( |
| configuration=None, |
| sample_dir=temp_pt_sample_dir, |
| batch_size=1, |
| num_workers=2, |
| prefetch_factor=1, |
| train_period=[None, None], |
| val_period=[None, None], |
| ) |
|
|
| served_batches = 0 |
| for batch in dm.train_dataloader(): |
| served_batches += 1 |
|
|
| if served_batches == 2: |
| break |
|
|
| |
| assert served_batches == 2 |
|
|
|
|
| def test_site_init_sample_dir(temp_nc_sample_dir): |
| """Test SiteDataModule initialization with sample dir""" |
| dm = SiteDataModule( |
| configuration=None, |
| sample_dir=temp_nc_sample_dir, |
| batch_size=2, |
| num_workers=0, |
| prefetch_factor=None, |
| train_period=[None, None], |
| val_period=[None, None], |
| ) |
|
|
| |
| assert dm is not None |
| assert hasattr(dm, "train_dataloader") |
|
|
|
|
| def test_site_init_config(temp_nc_sample_dir): |
| """Test SiteDataModule initialization with config file""" |
| config_path = f"{temp_nc_sample_dir}/data_configuration.yaml" |
|
|
| dm = SiteDataModule( |
| configuration=config_path, |
| batch_size=2, |
| num_workers=0, |
| prefetch_factor=None, |
| train_period=[None, None], |
| val_period=[None, None], |
| sample_dir=None, |
| ) |
|
|
| |
| assert dm is not None |
| assert hasattr(dm, "train_dataloader") |
|
|