Spaces:
Sleeping
Sleeping
File size: 4,499 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | import random
from dataclasses import dataclass
from typing import Callable
import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch import Generator, nn
from torch.utils.data import DataLoader, Dataset, IterableDataset
from . import DatasetCfg, get_dataset
from .data_types import DataShim, Stage
from .validation_wrapper import ValidationWrapper
from ..misc.step_tracker import StepTracker
def get_data_shim(encoder: nn.Module) -> DataShim:
    """Collect every batch-modifying shim and fuse them into one callable.

    Some batch modifications cannot live inside the data loader, for example because
    they need GPU computation or depend on state outside the loader. The only source of
    shims consulted here is the encoder: when it defines ``get_data_shim``, that method
    is called once, right now, and its result is included.

    Args:
        encoder: Module that may expose a ``get_data_shim()`` method.

    Returns:
        A function that threads a batch through each collected shim in order and
        returns the result. When no shim was collected it returns the batch unchanged.
    """
    shims: list[DataShim] = (
        [encoder.get_data_shim()] if hasattr(encoder, "get_data_shim") else []
    )

    def combined_shim(batch):
        for apply_shim in shims:
            batch = apply_shim(batch)
        return batch

    return combined_shim
@dataclass
class DataLoaderStageCfg:
    """DataLoader settings for a single stage (train, val or test)."""

    # Samples per batch; passed to DataLoader as its ``batch_size``.
    batch_size: int
    # Passed to DataLoader as ``num_workers`` (0 = load in the main process).
    num_workers: int
    # Keep worker processes alive between epochs; ignored when num_workers is 0.
    persistent_workers: bool
    # Base seed for the loader's torch.Generator, offset by the data module's
    # global rank. None means no explicit generator is created.
    seed: int | None
@dataclass
class DataLoaderCfg:
    """Per-stage DataLoader settings: one entry each for train, test and val."""

    train: DataLoaderStageCfg
    test: DataLoaderStageCfg
    val: DataLoaderStageCfg
# Hook applied to each split's dataset before it is wrapped in a DataLoader: it is
# called with the dataset and its stage name ("train", "val" or "test") and returns
# the dataset to use instead (DataModule's default returns it unchanged).
DatasetShim = Callable[[Dataset, Stage], Dataset]
def worker_init_fn(worker_id: int) -> None:
    """Reseed Python's ``random`` and NumPy's global RNG inside a DataLoader worker.

    Both generators are seeded from the seed PyTorch assigned to the current worker,
    so augmentations that use them are tied to the loader's own seeding.

    Args:
        worker_id: Index of the worker. It is unused, but the DataLoader hook
            signature requires it.

    Raises:
        AttributeError: When called outside a DataLoader worker, where
            ``torch.utils.data.get_worker_info()`` returns ``None``.
    """
    worker_info = torch.utils.data.get_worker_info()
    # ``np.random.seed`` only accepts values in [0, 2**32 - 1]; the torch worker seed
    # may be larger, so fold it into range first.
    folded_seed = int(worker_info.seed) % (2 ** 32 - 1)
    random.seed(folded_seed)
    np.random.seed(folded_seed)
class DataModule(LightningDataModule):
    """Lightning data module that builds the train/val/test DataLoaders.

    For each stage the dataset is created with ``get_dataset``, passed through
    ``dataset_shim`` together with the stage name, and wrapped in a ``DataLoader``
    configured by the matching entry of ``data_loader_cfg``.
    """

    dataset_cfg: DatasetCfg
    data_loader_cfg: DataLoaderCfg
    step_tracker: StepTracker | None
    dataset_shim: DatasetShim
    global_rank: int

    def __init__(
        self,
        dataset_cfg: DatasetCfg,
        data_loader_cfg: DataLoaderCfg,
        step_tracker: StepTracker | None = None,
        dataset_shim: DatasetShim = lambda dataset, _: dataset,
        global_rank: int = 0,
    ) -> None:
        """Store the configuration used to build the loaders.

        Args:
            dataset_cfg: Default dataset configuration passed to ``get_dataset``.
            data_loader_cfg: Per-stage DataLoader settings.
            step_tracker: Forwarded to ``get_dataset`` for every stage.
            dataset_shim: Called as ``dataset_shim(dataset, stage)`` on every dataset
                before it is wrapped in a loader; defaults to the identity.
            global_rank: Added to each loader's seed (see ``get_generator``).
        """
        super().__init__()
        self.dataset_cfg = dataset_cfg
        self.data_loader_cfg = data_loader_cfg
        self.step_tracker = step_tracker
        self.dataset_shim = dataset_shim
        self.global_rank = global_rank

    def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool:
        """Return the ``persistent_workers`` flag to pass to ``DataLoader``.

        PyTorch rejects ``persistent_workers=True`` when ``num_workers == 0``, so the
        configured flag is only honoured when worker processes are used. A plain
        ``False`` (rather than ``None``) is returned otherwise, matching the ``bool``
        type of ``DataLoader.persistent_workers``.
        """
        return loader_cfg.num_workers != 0 and loader_cfg.persistent_workers

    def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None:
        """Return a seeded ``torch.Generator`` for a loader, or ``None`` if unseeded.

        The seed is offset by ``global_rank`` so that different ranks draw different
        random streams from the same configured seed.
        """
        if loader_cfg.seed is None:
            return None
        generator = Generator()
        generator.manual_seed(loader_cfg.seed + self.global_rank)
        return generator

    def _build_dataloader(
        self,
        dataset: Dataset,
        loader_cfg: DataLoaderStageCfg,
        **loader_kwargs,
    ) -> DataLoader:
        """Create a ``DataLoader`` with the options shared by every stage.

        ``loader_kwargs`` carries stage-specific ``DataLoader`` options (e.g.
        ``shuffle``) and is forwarded unchanged.
        """
        return DataLoader(
            dataset,
            loader_cfg.batch_size,
            num_workers=loader_cfg.num_workers,
            generator=self.get_generator(loader_cfg),
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(loader_cfg),
            **loader_kwargs,
        )

    def train_dataloader(self):
        """Build the training loader; map-style datasets are shuffled."""
        loader_cfg = self.data_loader_cfg.train
        dataset = get_dataset(self.dataset_cfg, "train", self.step_tracker)
        dataset = self.dataset_shim(dataset, "train")
        # DataLoader rejects shuffle=True for an IterableDataset, so only map-style
        # datasets are shuffled.
        return self._build_dataloader(
            dataset,
            loader_cfg,
            shuffle=not isinstance(dataset, IterableDataset),
        )

    def val_dataloader(self):
        """Build the validation loader over the shimmed dataset wrapped in ``ValidationWrapper``."""
        loader_cfg = self.data_loader_cfg.val
        dataset = get_dataset(self.dataset_cfg, "val", self.step_tracker)
        dataset = self.dataset_shim(dataset, "val")
        # NOTE(review): the meaning of the ``1`` argument is defined by
        # ValidationWrapper; confirm against validation_wrapper.py.
        return self._build_dataloader(ValidationWrapper(dataset, 1), loader_cfg)

    def test_dataloader(self, dataset_cfg: DatasetCfg | None = None):
        """Build the (unshuffled) test loader.

        Args:
            dataset_cfg: Optional dataset configuration overriding the one given at
                construction time.
        """
        loader_cfg = self.data_loader_cfg.test
        dataset = get_dataset(
            self.dataset_cfg if dataset_cfg is None else dataset_cfg,
            "test",
            self.step_tracker,
        )
        dataset = self.dataset_shim(dataset, "test")
        return self._build_dataloader(dataset, loader_cfg, shuffle=False)
|