File size: 5,619 Bytes
e7c18b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | from functools import partial
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import argparse, os, sys, glob
# Make the first sys.path entry (normally the directory of the launched
# script — verify for your entry point) the current working directory.
# NOTE(review): this changes the process-wide cwd whenever this module is
# imported; if sys.path[0] is '' (e.g. an interactive session) os.chdir
# fails — confirm this module is only imported from scripts.
os.chdir(sys.path[0])
# Relative sys.path entries are resolved against the cwd, so after the chdir
# above this adds the parent of that directory, from which the `utils`
# package below is imported.
sys.path.append("..")
from utils.common_utils import instantiate_from_config
def _base_dataset(dataset):
    """Return the dataset behind any (possibly nested) ``torch.utils.data.Subset``."""
    while isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset
    return dataset


def _worker_numpy_seed(base_seed, worker_id):
    """Combine a base seed and a worker id into a seed valid for ``np.random.seed``."""
    # np.random.seed only accepts 0 <= seed < 2**32; the raw sum of a uint32
    # state word and the worker id can exceed that, so wrap it around.
    return (int(base_seed) + int(worker_id)) % 2**32


def worker_init_fn(_):
    """Configure a DataLoader worker for joint image/video loading.

    Workers whose id is below 20% of ``num_workers`` put the dataset into
    "image" mode (worker 0 is therefore always an image worker); the others use
    "video" mode. The worker's dataset must implement ``set_mode(mode)``;
    ``torch.utils.data.Subset`` wrappers are looked through to reach it.
    Afterwards NumPy's global RNG is reseeded with a per-worker value.

    The positional argument (the worker id supplied by ``DataLoader``) is
    unused; the id is read from ``torch.utils.data.get_worker_info()``.
    Outside a worker process that returns None and this function does nothing.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return None
    worker_id = worker_info.id
    mode = "image" if worker_id < worker_info.num_workers * 0.2 else "video"
    print(f"Mode is {mode}")
    # Subset (used e.g. to cap the number of test samples) has no set_mode.
    _base_dataset(worker_info.dataset).set_mode(mode)
    np.random.seed(_worker_numpy_seed(np.random.get_state()[1][0], worker_id))
    return None
class WrappedDataset(Dataset):
    """Expose any object supporting ``len()`` and indexing as a torch ``Dataset``.

    Length queries and item lookups are forwarded unchanged to the wrapped
    object, which is kept (not copied) in the ``data`` attribute.
    """

    def __init__(self, dataset):
        self.data = dataset

    def __getitem__(self, idx):
        item = self.data[idx]
        return item

    def __len__(self):
        n_items = len(self.data)
        return n_items
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config dicts.

    Each of ``train`` / ``validation`` / ``test`` / ``predict`` that is not
    ``None`` is stored as a config, instantiated with
    ``instantiate_from_config`` in :meth:`setup`, and exposed through the
    matching ``train_dataloader`` / ``val_dataloader`` / ``test_dataloader`` /
    ``predict_dataloader`` attribute.

    Args:
        batch_size: Batch size of every loader built here. Also the basis of
            the default ``num_workers`` and of the ``train_img`` batch size
            when that is set to -1.
        train, validation, test, predict: Dataset configs (see above).
        wrap: If True, every instantiated dataset is wrapped in
            ``WrappedDataset``.
        num_workers: DataLoader worker count; defaults to ``batch_size * 2``.
        shuffle_test_loader: ``shuffle`` flag of the test loader.
        img_video_joint_train: If True, every loader uses ``worker_init_fn``,
            which puts each worker's dataset into "image" or "video" mode.
        shuffle_val_dataloader: ``shuffle`` flag of the validation loader.
        train_img: Optional config of a second object whose
            ``train_dataloader()`` is returned next to the video loader under
            the ``"loader_img"`` key. If ``train_img["params"]["batch_size"]``
            is -1 it is overwritten in place with
            ``batch_size * train["params"]["video_length"]``.
        test_max_n_samples: If set, the test loader only sees the first
            ``test_max_n_samples`` samples (capped at the dataset length).

    Raises:
        ValueError: ``train_img`` asks for a derived batch size (-1) but no
            ``train`` config was given.
    """

    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        img_video_joint_train=False,
        shuffle_val_dataloader=False,
        train_img=None,
        test_max_n_samples=None,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = img_video_joint_train
        # Expose a *_dataloader hook only for the splits whose config was given.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(
                self._val_dataloader, shuffle=shuffle_val_dataloader
            )
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(
                self._test_dataloader, shuffle=shuffle_test_loader
            )
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.img_loader = self._build_img_loader(train_img, train, batch_size)
        self.wrap = wrap
        self.test_max_n_samples = test_max_n_samples
        self.collate_fn = None

    @staticmethod
    def _build_img_loader(train_img, train, batch_size):
        """Instantiate the ``train_img`` config and return its train loader, or None."""
        if train_img is None:
            return None
        if train_img["params"]["batch_size"] == -1:
            if train is None:
                raise ValueError(
                    "train_img batch_size of -1 is derived from the train "
                    "config's video_length, but no train config was given"
                )
            train_img["params"]["batch_size"] = (
                batch_size * train["params"]["video_length"]
            )
            print(
                "Set train_img batch_size to {}".format(
                    train_img["params"]["batch_size"]
                )
            )
        img_data = instantiate_from_config(train_img)
        # NOTE(review): setup() is not called on img_data before
        # train_dataloader(); confirm the instantiated class does not need it.
        return img_data.train_dataloader()

    def prepare_data(self):
        """No-op: all datasets are instantiated in :meth:`setup` instead."""

    def setup(self, stage=None):
        """Instantiate every configured dataset (``stage`` is ignored)."""
        self.datasets = {
            split: instantiate_from_config(cfg)
            for split, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            self.datasets = {
                split: WrappedDataset(ds) for split, ds in self.datasets.items()
            }

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` with this module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            worker_init_fn=worker_init_fn if self.use_worker_init_fn else None,
            collate_fn=self.collate_fn,
        )

    def _train_dataloader(self):
        """Build the shuffled training loader.

        Returns the loader itself, or
        ``{"loader_video": loader, "loader_img": img_loader}`` when a
        ``train_img`` loader was configured.
        """
        loader = self._make_loader(self.datasets["train"], shuffle=True)
        if self.img_loader is None:
            return loader
        return {"loader_video": loader, "loader_img": self.img_loader}

    def _val_dataloader(self, shuffle=False):
        """Build the validation loader."""
        return self._make_loader(self.datasets["validation"], shuffle)

    def _test_dataloader(self, shuffle=False):
        """Build the test loader, limited to the first ``test_max_n_samples`` items if set."""
        dataset = self.datasets["test"]
        if self.test_max_n_samples is not None:
            # Clamp to the dataset size: Subset indices past the end would
            # only fail with IndexError once iteration reached them.
            n_samples = min(self.test_max_n_samples, len(dataset))
            dataset = torch.utils.data.Subset(dataset, list(range(n_samples)))
        return self._make_loader(dataset, shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Build the predict loader (not shuffled unless ``shuffle`` is True)."""
        return self._make_loader(self.datasets["predict"], shuffle)
|