File size: 4,315 Bytes
cbe6208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
""" Data module for pytorch lightning """

from glob import glob

from lightning.pytorch import LightningDataModule
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_data_sampler.torch_datasets.sample.base import (
    NumpyBatch,
    SampleBase,
    TensorBatch,
    batch_to_tensor,
)
from torch.utils.data import DataLoader, Dataset


def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
    """Collate individual numpy samples into a single batch of tensors.

    The samples are first stacked into one numpy batch, which is then converted to tensors.
    """
    stacked_batch = stack_np_samples_into_batch(samples)
    tensor_batch = batch_to_tensor(stacked_batch)
    return tensor_batch


class PremadeSamplesDataset(Dataset):
    """Dataset to load pre-saved samples from a directory.

    Every entry matched by ``{sample_dir}/*`` is treated as one sample. The paths are sorted so
    that a given index always refers to the same file, regardless of the filesystem's listing
    order.

    Args:
        sample_dir: Path to the directory of pre-saved samples.
        sample_class: sample class type to use for save/load/to_numpy
    """

    def __init__(self, sample_dir: str, sample_class: "type[SampleBase]"):
        """Initialise PremadeSamplesDataset"""
        # glob() makes no ordering guarantee (it depends on the filesystem); sort so that
        # indices, shuffling seeds and train/val runs are reproducible across machines.
        self.sample_paths = sorted(glob(f"{sample_dir}/*"))
        self.sample_class = sample_class

    def __len__(self) -> int:
        """Return the number of sample paths found in `sample_dir`."""
        return len(self.sample_paths)

    def __getitem__(self, idx: int):
        """Load the sample at position `idx` and return the result of its `to_numpy()`."""
        sample = self.sample_class.load(self.sample_paths[idx])
        return sample.to_numpy()


class BaseDataModule(LightningDataModule):
    """Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler."""

    def __init__(
        self,
        configuration: str | None = None,
        sample_dir: str | None = None,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        train_period: list[str | None] | None = None,
        val_period: list[str | None] | None = None,
    ):
        """Base Datamodule for training pvnet architecture.

        Can also be used with pre-made batches if `sample_dir` is set.

        Args:
            configuration: Path to ocf-data-sampler configuration file.
            sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
                `configuration` or '[train/val]_period'.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker process. Passed
                straight through to `torch.utils.data.DataLoader`.
            train_period: Date range filter `[start, end]` for the train dataloader. `None`
                (the default) is equivalent to `[None, None]`, i.e. no filter.
            val_period: Date range filter `[start, end]` for the val dataloader. `None`
                (the default) is equivalent to `[None, None]`, i.e. no filter.

        Raises:
            ValueError: If neither or both of `sample_dir` and `configuration` are set, or if a
                period filter is given together with `sample_dir`.
        """
        super().__init__()

        # Exactly one data source is allowed: both-None and both-set are rejected alike.
        if (sample_dir is None) == (configuration is None):
            raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")

        # Build fresh lists rather than using shared mutable defaults. Converting with list()
        # also lets tuples / other sequences compare equal to [None, None] below.
        train_period = [None, None] if train_period is None else list(train_period)
        val_period = [None, None] if val_period is None else list(val_period)

        if sample_dir is not None:
            if any(period != [None, None] for period in (train_period, val_period)):
                raise ValueError("Cannot set `(train/val)_period` with presaved samples")

        self.configuration = configuration
        self.sample_dir = sample_dir
        self.train_period = train_period
        self.val_period = val_period

        # Shared between the train and val dataloaders; only `shuffle` differs between them.
        self._common_dataloader_kwargs = dict(
            batch_size=batch_size,
            sampler=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
        )

    def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
        """Build a dataset that streams samples between `start_time` and `end_time`.

        Must be implemented by subclasses; used when `configuration` is set.
        """
        raise NotImplementedError

    def _get_premade_samples_dataset(self, subdir) -> Dataset:
        """Build a dataset over the pre-saved samples for `subdir` ("train" or "val").

        Must be implemented by subclasses; used when `sample_dir` is set.
        """
        raise NotImplementedError

    def train_dataloader(self) -> DataLoader:
        """Construct train dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("train")
        else:
            dataset = self._get_streamed_samples_dataset(*self.train_period)
        return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        """Construct val dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("val")
        else:
            dataset = self._get_streamed_samples_dataset(*self.val_period)
        return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)