| | import typing |
| | from pathlib import Path |
| | from pprint import pprint |
| | from torch.utils.data import DataLoader |
| | from pytorch_lightning import LightningDataModule |
| |
|
| | from . import augment as Aug |
| | from ..render import Renderer |
| | from .utils import MultiLoader, EmptyDataset, collate_fn |
| | from .source import AmbientCG |
| | from .source import FabricsDataset |
| | from .target import StableDiffusion |
| |
|
| |
|
def is_set(x):
    """Return ``True`` when *x* is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as set; only
    ``None`` is treated as "not provided".
    """
    if x is None:
        return False
    return True
| |
|
class DataModule(LightningDataModule):
    """Lightning data module pairing a source dataset with a target dataset.

    Dataset names handled by :meth:`setup`:

    * ``source_ds``: ``'acg'`` (built as ``FabricsDataset``).
    * ``target_ds``: ``'sd'`` (built as ``StableDiffusion``).
    * ``test_ds``: ``'acg'`` or ``'sd'``.
    * ``predict_ds``: ``'sd'``.

    An empty string means "not used". During fit/validate, when only one of
    ``source_ds`` / ``target_ds`` is configured, the missing side is replaced
    by an ``EmptyDataset`` constructed with the length of the other one.

    Args:
        batch_size: Batch size for the train, validation and test loaders.
        num_workers: Worker count for the train, validation and test loaders.
        transform: When truthy, the training datasets receive the random
            augmentation pipeline; otherwise a centre crop is used.
        source_ds: Name of the source dataset.
        target_ds: Name of the target dataset.
        test_ds: Name of the test dataset.
        predict_ds: Name of the prediction dataset.
        source_list: Passed as ``matlist`` to the source datasets; required
            when ``source_ds`` is set.
        target_list: Passed as ``matlist`` to the target datasets; required
            when ``target_ds`` is set.
        target_val_list: Required when ``target_ds`` is set to something
            other than ``'sd'``; stored but not read elsewhere in this class.
        test_list: Passed as ``matlist`` to the test dataset.
        predict_list: Stored but not read elsewhere in this class (the
            prediction dataset is built with ``matlist=None``).
        source_dir: Passed as ``dir`` to the source datasets.
        target_dir: Passed as ``dir`` to the target datasets.
        predict_dir: Passed as ``dir`` to the prediction dataset.
        test_dir: Passed as ``dir`` to the test dataset.
        pseudo_labels: Forwarded to the training datasets only; validation,
            test and prediction always use ``pseudo_labels=False`` where the
            argument is passed.
        input_size: Side length used by the evaluation resize and by the
            non-augmented training centre crop.

    Raises:
        AssertionError: If no dataset name is given, or a required list is
            missing for a configured dataset.
    """

    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = 1,
        transform: typing.Optional[bool] = None,
        source_ds: str = '',
        target_ds: str = '',
        test_ds: str = '',
        predict_ds: str = '',
        source_list: typing.Optional[Path] = None,
        target_list: typing.Optional[Path] = None,
        target_val_list: typing.Optional[Path] = None,
        test_list: typing.Optional[Path] = None,
        predict_list: typing.Optional[Path] = None,
        source_dir: typing.Optional[Path] = None,
        target_dir: typing.Optional[Path] = None,
        predict_dir: typing.Optional[Path] = None,
        test_dir: typing.Optional[Path] = None,
        pseudo_labels: bool = False,
        input_size: int = 512,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.pseudo_labels = pseudo_labels

        self.source_ds = source_ds
        self.target_ds = target_ds
        self.test_ds = test_ds
        self.predict_ds = predict_ds

        self.source_list = source_list
        self.target_list = target_list
        self.target_val_list = target_val_list
        self.predict_list = predict_list
        self.test_list = test_list

        self.source_dir = source_dir
        self.target_dir = target_dir
        self.predict_dir = predict_dir
        self.test_dir = test_dir

        self.input_size = input_size

        # Configuration sanity checks. They stay as assertions so the
        # exception type seen by existing callers does not change.
        assert (self.source_ds or self.target_ds
                or self.test_ds or self.predict_ds), (
            'at least one of source_ds, target_ds, test_ds or predict_ds '
            'must be set')
        if self.source_ds:
            assert is_set(source_list), (
                'source_list is required when source_ds is set')
        if self.target_ds:
            assert is_set(target_list), (
                'target_list is required when target_ds is set')
            if self.target_ds != 'sd':
                assert is_set(target_val_list), (
                    "target_val_list is required when target_ds is not 'sd'")

    def setup(self, stage: str):
        """Build the datasets needed for *stage*.

        ``'fit'`` builds the training and the validation datasets,
        ``'validate'`` only the validation ones, ``'test'`` fills
        ``self.eval`` and ``'predict'`` fills ``self.ds``. Any other stage
        value builds nothing.

        Raises:
            ValueError: If the dataset names configured for the requested
                stage are missing or not supported.
            AssertionError: If ``stage == 'test'`` and ``test_ds`` is empty.
        """
        renderer = Renderer(return_params=True)
        eval_tf = [
            Aug.NormalizeGeometry(),
            Aug.Resize([self.input_size, self.input_size], antialias=True),
        ]

        if stage == 'fit':
            self._setup_fit(renderer)

        if stage in ('fit', 'validate'):
            self._setup_validate(stage, renderer, eval_tf)
        elif stage == 'test':
            self._setup_test(renderer, eval_tf)
        elif stage == 'predict':
            self._setup_predict(renderer, eval_tf)

    def _check_source_target(self, stage: str) -> None:
        """Raise ``ValueError`` if source/target names cannot serve *stage*."""
        if not (self.source_ds or self.target_ds):
            raise ValueError(
                f"stage '{stage}' requires source_ds and/or target_ds "
                "to be set")
        if self.source_ds and self.source_ds != 'acg':
            raise ValueError(f'unsupported source_ds: {self.source_ds!r}')
        if self.target_ds and self.target_ds != 'sd':
            raise ValueError(f'unsupported target_ds: {self.target_ds!r}')

    def _setup_fit(self, renderer) -> None:
        """Create ``self.src_train`` and ``self.tgt_train``."""
        if self.transform:
            train_tf = [
                # NOTE(review): the crop size is fixed at 512 here while the
                # other pipelines follow self.input_size -- confirm intended.
                Aug.RandomResizedCrop(
                    (512, 512), scale=(1/16, 1/4), ratio=(1., 1.)),
                Aug.NormalizeGeometry(),
                Aug.RandomHorizontalFlip(),
                Aug.RandomVerticalFlip(),
                Aug.RandomIncrementRotate(p=1.),
                Aug.ColorJitter(brightness=.2, hue=.05, contrast=0.1),
            ]
        else:
            train_tf = [
                Aug.CenterCrop((self.input_size, self.input_size)),
                Aug.NormalizeGeometry(),
            ]

        # The same kwargs (and transform list) are shared by both datasets.
        train_kwargs = dict(pseudo_labels=self.pseudo_labels,
                            renderer=renderer,
                            transform=train_tf)
        print('stage fit:')
        pprint(train_kwargs)

        # Fail early with a clear message instead of an AttributeError on a
        # dataset attribute that was never created.
        self._check_source_target('fit')

        if self.source_ds == 'acg':
            self.src_train = FabricsDataset(split='train',
                                            dir=self.source_dir,
                                            matlist=self.source_list,
                                            **train_kwargs)

        if self.target_ds == 'sd':
            self.tgt_train = StableDiffusion(split='train',
                                             dir=self.target_dir,
                                             matlist=self.target_list,
                                             **train_kwargs)

        # Stand in for whichever side is not configured.
        if not self.source_ds:
            self.src_train = EmptyDataset(len(self.tgt_train))
        if not self.target_ds:
            self.tgt_train = EmptyDataset(len(self.src_train))

    def _setup_validate(self, stage: str, renderer, eval_tf) -> None:
        """Create ``self.src_valid`` and ``self.tgt_valid``."""
        validate_kwargs = dict(transform=eval_tf,
                               renderer=renderer,
                               set_seed_render=True)

        self._check_source_target(stage)

        if self.source_ds == 'acg':
            self.src_valid = FabricsDataset(split='valid',
                                            dir=self.source_dir,
                                            matlist=self.source_list,
                                            **validate_kwargs)

        if self.target_ds == 'sd':
            self.tgt_valid = StableDiffusion(split='valid',
                                             pseudo_labels=False,
                                             dir=self.target_dir,
                                             matlist=self.target_list,
                                             **validate_kwargs)

        # Stand in for whichever side is not configured.
        if not self.source_ds:
            self.src_valid = EmptyDataset(len(self.tgt_valid))
        if not self.target_ds:
            self.tgt_valid = EmptyDataset(len(self.src_valid))

    def _setup_test(self, renderer, eval_tf) -> None:
        """Create ``self.eval``, a one-element list holding the test set."""
        assert self.test_ds, "test_ds must be set for stage 'test'"

        test_kwargs = dict(pseudo_labels=False,
                           matlist=self.test_list,
                           transform=eval_tf,
                           renderer=renderer,
                           dir=self.test_dir,
                           set_seed_render=True)

        if self.test_ds == 'acg':
            self.eval = [FabricsDataset(split='all', **test_kwargs)]
        elif self.test_ds == 'sd':
            self.eval = [StableDiffusion(split='all', **test_kwargs)]
        else:
            raise ValueError(f'unsupported test_ds: {self.test_ds!r}')

    def _setup_predict(self, renderer, eval_tf) -> None:
        """Create ``self.ds``, the prediction dataset."""
        predict_kwargs = dict(split='all',
                              pseudo_labels=False,
                              dir=self.predict_dir,
                              matlist=None,
                              transform=eval_tf,
                              renderer=renderer)

        if self.predict_ds == 'sd':
            self.ds = StableDiffusion(**predict_kwargs)
        else:
            raise ValueError(f'unsupported predict_ds: {self.predict_ds!r}')

    def _paired_loader(self, src_dataset, tgt_dataset, *, training: bool):
        """Wrap a source and a target ``DataLoader`` in a ``MultiLoader``.

        ``training`` switches both ``shuffle`` and ``drop_last`` on; the
        validation loaders use neither.
        """
        def make(dataset):
            return DataLoader(dataset=dataset,
                              batch_size=self.batch_size,
                              drop_last=training,
                              shuffle=training,
                              num_workers=self.num_workers,
                              collate_fn=collate_fn)

        src_dl = make(src_dataset)
        tgt_dl = make(tgt_dataset)
        return MultiLoader(src_dl, tgt_dl)

    def train_dataloader(self):
        """Return the shuffled source/target training loaders as a pair."""
        return self._paired_loader(self.src_train, self.tgt_train,
                                   training=True)

    def val_dataloader(self):
        """Return the ordered source/target validation loaders as a pair."""
        return self._paired_loader(self.src_valid, self.tgt_valid,
                                   training=False)

    def test_dataloader(self):
        """Return one ordered ``DataLoader`` per dataset in ``self.eval``."""
        # NOTE(review): unlike the train/val loaders, no collate_fn is passed
        # here -- confirm this is intentional.
        return [DataLoader(dataset=ds,
                           batch_size=self.batch_size,
                           drop_last=False,
                           shuffle=False,
                           num_workers=self.num_workers) for ds in self.eval]

    def predict_dataloader(self):
        """Return an ordered loader over ``self.ds``.

        Batch size and worker count are fixed at 1, independent of the
        module's ``batch_size`` / ``num_workers`` settings.
        """
        return DataLoader(dataset=self.ds,
                          batch_size=1,
                          drop_last=False,
                          shuffle=False,
                          num_workers=1)