| |
| |
| """ |
| Dataset class for image-caption |
| @author: Tu Bui @University of Surrey |
| """ |
| import json |
| from PIL import Image |
| import numpy as np |
| from pathlib import Path |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from functools import partial |
| import pytorch_lightning as pl |
| from ldm.util import instantiate_from_config |
| import pandas as pd |
|
|
|
|
def worker_init_fn(_):
    """Reseed the global NumPy RNG of a DataLoader worker using its worker id.

    The new seed is the first word of the current global NumPy RNG state plus
    the id reported by ``torch.utils.data.get_worker_info()``, so workers that
    start from the same NumPy state end up with different seeds.

    Args:
        _: Ignored. Present because ``DataLoader`` calls ``worker_init_fn``
            with one positional argument.

    Returns:
        None.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Not running inside a DataLoader worker: there is no worker id to
        # mix in, so leave the global NumPy RNG untouched instead of crashing
        # on ``None.id``.
        return None
    # Convert to a plain int so the addition cannot overflow a fixed-width
    # NumPy scalar.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts seeds in [0, 2**32 - 1]; wrap the sum so a
    # large state word plus the worker id stays in range.
    return np.random.seed((base_seed + worker_info.id) % 2**32)
|
|
|
|
class WrappedDataset(Dataset):
    """Adapter that exposes any sized, indexable object as a torch ``Dataset``."""

    def __init__(self, dataset):
        # Only a reference to the wrapped object is stored; nothing is copied.
        self.data = dataset

    def __len__(self):
        # Length is delegated to the wrapped object.
        wrapped = self.data
        return len(wrapped)

    def __getitem__(self, idx):
        # Indexing is delegated to the wrapped object unchanged.
        wrapped = self.data
        return wrapped[idx]
|
|
|
|
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose splits are instantiated from config objects.

    A ``*_dataloader`` hook is attached to the instance only for the splits
    (``train``, ``validation``, ``test``, ``predict``) whose config is not
    ``None``.

    Args:
        batch_size: Batch size used by every dataloader.
        train: Config for the training dataset, or ``None`` to skip the split.
        validation: Config for the validation dataset, or ``None``.
        test: Config for the test dataset, or ``None``.
        predict: Config for the prediction dataset, or ``None``.
        wrap: If true, each instantiated dataset is wrapped in
            ``WrappedDataset`` during ``setup``.
        num_workers: Number of DataLoader workers; defaults to
            ``batch_size * 2`` when ``None``.
        shuffle_test_loader: Whether the test dataloader shuffles.
        use_worker_init_fn: If true, dataloaders are given the module-level
            ``worker_init_fn``; otherwise no worker init function is used.
        shuffle_val_dataloader: Whether the validation dataloader shuffles.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # The dataloader hooks are bound per instance so that only the
        # configured splits expose one.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result.

        NOTE(review): presumably this exists to trigger one-time side effects
        of dataset construction (e.g. downloads) — confirm against the
        dataset classes used in the configs.
        """
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Build ``self.datasets`` (split name -> dataset) from the configs.

        Args:
            stage: Accepted for the Lightning hook signature; not used here.
        """
        self.datasets = {
            split: instantiate_from_config(cfg)
            for split, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for split in self.datasets:
                self.datasets[split] = WrappedDataset(self.datasets[split])

    def _make_loader(self, split, shuffle):
        """Build the DataLoader for ``split`` with the shared settings."""
        init_fn = worker_init_fn if self.use_worker_init_fn else None
        return DataLoader(self.datasets[split], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=shuffle,
                          worker_init_fn=init_fn)

    def _train_dataloader(self):
        """Return the training dataloader; it always shuffles."""
        return self._make_loader("train", shuffle=True)

    def _val_dataloader(self, shuffle=False):
        """Return the validation dataloader, shuffled only if requested."""
        return self._make_loader("validation", shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Return the test dataloader, shuffled only if requested."""
        return self._make_loader("test", shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Return the prediction dataloader.

        ``shuffle`` is now forwarded to the DataLoader (it used to be accepted
        but ignored); the default of ``False`` keeps the previous behaviour.
        """
        return self._make_loader("predict", shuffle=shuffle)
|
|
|
|
class ImageCaptionRaw(Dataset):
    """Image/caption dataset read from a JSON-lines caption file.

    Each non-blank line of ``caption_file`` is parsed as one JSON record.
    ``__getitem__`` reads the record's ``'image'`` entry (a path relative to
    ``image_dir``) and its ``'captions'`` entry (an indexable collection from
    which one caption is drawn at random).

    Args:
        image_dir: Directory that the records' image paths are relative to.
        caption_file: Path to the JSON-lines caption file (read as UTF-8).
        secret_len: Length of the random bit vector returned with each item.
        transform: Optional callable applied to the resized PIL image.
    """

    def __init__(self, image_dir, caption_file, secret_len=100, transform=None):
        super().__init__()
        self.image_dir = Path(image_dir)
        # Read as UTF-8 explicitly so non-ASCII captions do not depend on the
        # platform's locale encoding; skip blank lines (e.g. a trailing
        # newline), which json.loads would otherwise reject.
        with open(caption_file, 'rt', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]
        self.secret_len = secret_len
        self.transform = transform

    def __len__(self):
        """Return the number of caption records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys:
                ``image``: float32 array, pixel values divided by 255.
                ``caption``: one caption picked uniformly at random.
                ``target``: ``image * 2 - 1``.
                ``secret``: float tensor of length ``secret_len`` holding
                    random 0/1 values.
        """
        item = self.data[idx]
        # The context manager releases the file handle as soon as the pixels
        # have been converted into a new in-memory image.
        with Image.open(self.image_dir/item['image']) as raw:
            image = raw.convert('RGB').resize((512,512))
        captions = item['captions']
        cid = torch.randint(0, len(captions), (1,)).item()
        caption = captions[cid]
        if self.transform is not None:
            image = self.transform(image)

        image = np.array(image, dtype=np.float32)/ 255.0
        target = image * 2.0 - 1.0
        secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
        return dict(image=image, caption=caption, target=target, secret=secret)
|
|
|
|
class BAMFG(Dataset):
    """Paired ground-truth/style image dataset driven by a CSV listing.

    ``data_list`` is read with ``pandas.read_csv``; ``__getitem__`` uses the
    columns ``'gt_img'`` and ``'style_img'`` (paths relative to ``gt_dir`` and
    ``style_dir`` respectively) and ``'prompt'``.

    Args:
        style_dir: Directory that the ``style_img`` paths are relative to.
        gt_dir: Directory that the ``gt_img`` paths are relative to.
        data_list: Path to the CSV file listing the samples.
        transform: Optional callable applied to both resized PIL images.
    """

    def __init__(self, style_dir, gt_dir, data_list, transform=None):
        super().__init__()
        self.style_dir = Path(style_dir)
        self.gt_dir = Path(gt_dir)
        self.data = pd.read_csv(data_list)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV listing."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys:
                ``image``: ground-truth image as a float32 array, pixel values
                    divided by 255.
                ``txt``: the row's ``'prompt'`` value.
                ``hint``: style image as a float32 array, pixel values divided
                    by 255.
        """
        item = self.data.iloc[idx]
        # The context managers release the file handles once the pixels have
        # been converted into new in-memory images.
        with Image.open(self.gt_dir/item['gt_img']) as raw_gt:
            gt_img = raw_gt.convert('RGB').resize((512,512))
        with Image.open(self.style_dir/item['style_img']) as raw_style:
            style_img = raw_style.convert('RGB').resize((512,512))
        txt = item['prompt']
        if self.transform is not None:
            gt_img = self.transform(gt_img)
            style_img = self.transform(style_img)

        gt_img = np.array(gt_img, dtype=np.float32)/ 255.0
        style_img = np.array(style_img, dtype=np.float32)/ 255.0
        # NOTE(review): a rescaled copy (gt_img * 2.0 - 1.0) used to be
        # computed here but was never returned; the dead computation has been
        # dropped and the returned dict is unchanged. Confirm whether callers
        # expect that rescaled image under some key.
        return dict(image=gt_img, txt=txt, hint=style_img)