|
|
import random |
|
|
from dataclasses import dataclass |
|
|
from typing import Callable |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from pytorch_lightning import LightningDataModule |
|
|
from torch import Generator, nn |
|
|
from torch.utils.data import DataLoader, Dataset, IterableDataset |
|
|
|
|
|
from ..misc.step_tracker import StepTracker |
|
|
from . import DatasetCfg, get_dataset |
|
|
from .types import DataShim, Stage |
|
|
from .validation_wrapper import ValidationWrapper |
|
|
|
|
|
|
|
|
def get_data_shim(encoder: nn.Module) -> DataShim:
    """Build a single callable that applies every batch shim in sequence.

    Some batch modifications cannot happen inside the data loader, either
    because they need GPU computation or because they depend on state that
    lives outside the loader. Such modifications are expressed as shims that
    are applied to the batch afterwards.

    Args:
        encoder: Module that may expose a ``get_data_shim()`` method. When it
            does, the shim returned by that method is included.

    Returns:
        A function that takes a batch, passes it through each collected shim
        in order, and returns the result. With no shims collected, the batch
        is returned as-is.
    """
    # The encoder's shim is fetched once, here, not on every batch.
    pipeline: list[DataShim] = (
        [encoder.get_data_shim()] if hasattr(encoder, "get_data_shim") else []
    )

    def combined_shim(batch):
        # Each shim receives the output of the previous one.
        for transform in pipeline:
            batch = transform(batch)
        return batch

    return combined_shim
|
|
|
|
|
|
|
|
@dataclass |
|
|
class DataLoaderStageCfg: |
|
|
batch_size: int |
|
|
num_workers: int |
|
|
persistent_workers: bool |
|
|
seed: int | None |
|
|
|
|
|
|
|
|
@dataclass
class DataLoaderCfg:
    """Per-stage data loader settings, one ``DataLoaderStageCfg`` per stage."""

    # Settings used when building the training data loader.
    train: DataLoaderStageCfg
    # Settings used when building the test data loader.
    test: DataLoaderStageCfg
    # Settings used when building the validation data loader.
    val: DataLoaderStageCfg
|
|
|
|
|
|
|
|
# Signature of a hook that receives a dataset together with the stage it was
# built for and returns the dataset that should actually be loaded.
DatasetShim = Callable[
    [Dataset, Stage],
    Dataset,
]
|
|
|
|
|
|
|
|
def worker_init_fn(worker_id: int) -> None:
    """Seed Python's ``random`` and NumPy inside a data loader worker.

    The seed is taken from ``torch.utils.data.get_worker_info()`` and reduced
    modulo ``2**32 - 1`` before being handed to both generators.

    Args:
        worker_id: Index of the worker; accepted to match the
            ``worker_init_fn`` callback signature but not used.
    """
    worker_seed = int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
|
|
|
|
|
|
|
|
class DataModule(LightningDataModule):
    """Lightning data module that builds the train, val and test data loaders.

    For each stage the dataset is obtained from ``get_dataset``, passed through
    ``dataset_shim``, and wrapped in a ``DataLoader`` configured from the
    matching entry of ``data_loader_cfg``.
    """

    dataset_cfg: DatasetCfg
    data_loader_cfg: DataLoaderCfg
    step_tracker: StepTracker | None
    dataset_shim: DatasetShim
    global_rank: int

    def __init__(
        self,
        dataset_cfg: DatasetCfg,
        data_loader_cfg: DataLoaderCfg,
        step_tracker: StepTracker | None = None,
        dataset_shim: DatasetShim = lambda dataset, _: dataset,
        global_rank: int = 0,
    ) -> None:
        """Store the configuration used later to build the data loaders.

        Args:
            dataset_cfg: Dataset configuration handed to ``get_dataset``.
            data_loader_cfg: Per-stage data loader settings.
            step_tracker: Optional tracker handed to ``get_dataset``.
            dataset_shim: Hook applied to every dataset after creation; the
                default returns the dataset unchanged.
            global_rank: Offset added to the configured seed when creating a
                loader's generator.
        """
        super().__init__()
        self.dataset_cfg = dataset_cfg
        self.data_loader_cfg = data_loader_cfg
        self.step_tracker = step_tracker
        self.dataset_shim = dataset_shim
        self.global_rank = global_rank

    def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None:
        """Return the ``persistent_workers`` value to give to the DataLoader.

        Yields ``None`` when the stage uses zero workers; otherwise the
        configured ``persistent_workers`` flag.
        """
        if loader_cfg.num_workers == 0:
            return None
        return loader_cfg.persistent_workers

    def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None:
        """Create the seeded generator for a stage, or ``None`` if unseeded.

        The generator is seeded with the stage's seed plus ``global_rank``.
        """
        base_seed = loader_cfg.seed
        if base_seed is None:
            return None
        # manual_seed returns the generator itself.
        return Generator().manual_seed(base_seed + self.global_rank)

    def _prepare_dataset(self, stage, dataset_cfg):
        """Create the dataset for ``stage`` and run it through the dataset shim."""
        return self.dataset_shim(
            get_dataset(dataset_cfg, stage, self.step_tracker),
            stage,
        )

    def _loader_options(self, loader_cfg: DataLoaderStageCfg) -> dict:
        """Collect the DataLoader keyword arguments shared by every stage."""
        return {
            "batch_size": loader_cfg.batch_size,
            "num_workers": loader_cfg.num_workers,
            "generator": self.get_generator(loader_cfg),
            "worker_init_fn": worker_init_fn,
            "persistent_workers": self.get_persistent(loader_cfg),
        }

    def train_dataloader(self):
        """Build the training loader; shuffles unless the dataset is iterable-style."""
        train_set = self._prepare_dataset("train", self.dataset_cfg)
        return DataLoader(
            train_set,
            shuffle=not isinstance(train_set, IterableDataset),
            **self._loader_options(self.data_loader_cfg.train),
        )

    def val_dataloader(self):
        """Build the validation loader over the dataset wrapped in ``ValidationWrapper``."""
        val_set = self._prepare_dataset("val", self.dataset_cfg)
        # No shuffle argument is passed, so the DataLoader default applies.
        return DataLoader(
            ValidationWrapper(val_set, 1),
            **self._loader_options(self.data_loader_cfg.val),
        )

    def test_dataloader(self, dataset_cfg=None):
        """Build the test loader without shuffling.

        Args:
            dataset_cfg: Optional dataset configuration to use in place of the
                one stored on the module.
        """
        chosen_cfg = dataset_cfg if dataset_cfg is not None else self.dataset_cfg
        test_set = self._prepare_dataset("test", chosen_cfg)
        return DataLoader(
            test_set,
            shuffle=False,
            **self._loader_options(self.data_loader_cfg.test),
        )
|
|
|