# vrevar — "Add application file" (commit 04c78c7)
import typing
from pathlib import Path
from pprint import pprint
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
from . import augment as Aug
from ..render import Renderer
from .utils import MultiLoader, EmptyDataset, collate_fn
from .source import AmbientCG
from .source import FabricsDataset
from .target import StableDiffusion
def is_set(x):
    """Return ``True`` when *x* has been provided, i.e. it is not ``None``."""
    return not (x is None)
class DataModule(LightningDataModule):
    """Paired source/target data module for domain-adaptation-style training.

    A labelled *source* dataset and an (optionally pseudo-labelled) *target*
    dataset are loaded side by side; ``train_dataloader``/``val_dataloader``
    combine the two loaders through ``MultiLoader``.  Whichever side is not
    configured is padded with an ``EmptyDataset`` sized to match the other,
    so downstream code can always unpack both halves of a batch.
    """

    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = 1,
        # NOTE(review): original annotation was ``bool = None`` — the default is
        # None, so Optional[bool] is the honest type.  It is used only as a
        # truthy flag enabling the random train-time augmentations in setup().
        transform: typing.Optional[bool] = None,
        source_ds: str = '',
        target_ds: str = '',
        test_ds: str = '',
        predict_ds: str = '',
        source_list: typing.Optional[Path] = None,
        target_list: typing.Optional[Path] = None,
        target_val_list: typing.Optional[Path] = None,
        test_list: typing.Optional[Path] = None,
        predict_list: typing.Optional[Path] = None,
        source_dir: typing.Optional[Path] = None,
        target_dir: typing.Optional[Path] = None,
        predict_dir: typing.Optional[Path] = None,
        test_dir: typing.Optional[Path] = None,
        pseudo_labels: bool = False,
        input_size: int = 512,
    ):
        """Store configuration and validate that the requested datasets have
        the file lists they need.

        Args:
            batch_size: per-loader batch size for train/val/test loaders.
            num_workers: DataLoader worker count for train/val/test loaders.
            transform: truthy → use random crop/flip/rotate/jitter at train
                time; falsy → deterministic center-crop + normalize.
            source_ds / target_ds / test_ds / predict_ds: dataset selector
                strings; recognised values in this file are ``'acg'`` (source,
                test) and ``'sd'`` (target, test, predict).  At least one
                selector must be non-empty.
            source_list / target_list / target_val_list / test_list /
                predict_list: material-list files passed to the datasets as
                ``matlist``.
            source_dir / target_dir / predict_dir / test_dir: dataset root
                directories.
            pseudo_labels: forwarded to the *train* datasets only; validation,
                test and predict always run with ``pseudo_labels=False``.
            input_size: square side length used for resize/center-crop.

        Raises:
            AssertionError: when no dataset is selected, or a selected dataset
                is missing its list file.  NOTE(review): ``assert`` is stripped
                under ``python -O``; consider raising ``ValueError`` instead.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.pseudo_labels = pseudo_labels
        self.source_ds = source_ds
        self.target_ds = target_ds
        self.test_ds = test_ds
        self.predict_ds = predict_ds
        self.source_list = source_list
        self.target_list = target_list
        # NOTE(review): target_val_list is validated below but never read in
        # the visible setup() code — confirm whether the validation split is
        # supposed to use it instead of target_list.
        self.target_val_list = target_val_list
        self.predict_list = predict_list
        self.test_list = test_list
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.predict_dir = predict_dir
        self.test_dir = test_dir
        self.input_size = input_size
        # self.use_ref = use_ref
        # At least one of the four stages must have a dataset configured.
        assert self.source_ds or self.target_ds or self.test_ds or self.predict_ds
        if self.source_ds:
            assert is_set(source_list)
        if self.target_ds:
            assert is_set(target_list)
            # The 'sd' target ships its own validation data; every other
            # target name must supply an explicit validation list.
            if self.target_ds != 'sd':
                assert is_set(target_val_list)

    def setup(self, stage: str):
        """Build the datasets needed for *stage* ('fit', 'validate', 'test',
        'predict').

        Creates ``src_train``/``tgt_train`` for 'fit',
        ``src_valid``/``tgt_valid`` for 'fit' and 'validate',
        ``eval`` (a list) for 'test', and ``ds`` for 'predict'.
        """
        # return_params=True: the renderer hands back its render parameters
        # alongside the image (behaviour defined in ..render.Renderer).
        renderer = Renderer(return_params=True)
        # Deterministic eval-time pipeline: normalize then resize to a square.
        eval_tf = [
            Aug.NormalizeGeometry(),
            # Aug.CenterCrop((2048,2048)),
            Aug.Resize([self.input_size, self.input_size], antialias=True)]
        if stage == 'fit':
            if self.transform:
                # Random augmentation pipeline for training.
                # NOTE(review): crop size is hard-coded to 512 here rather
                # than using self.input_size — confirm this is intentional.
                train_tf = [
                    Aug.RandomResizedCrop((512,512), scale=(1/16, 1/4), ratio=(1.,1.)),
                    # Aug.RandomCrop(self.input_size),
                    Aug.NormalizeGeometry(),
                    Aug.RandomHorizontalFlip(),
                    Aug.RandomVerticalFlip(),
                    Aug.RandomIncrementRotate(p=1.),
                    Aug.ColorJitter(brightness=.2, hue=.05, contrast=0.1)
                ]
            else:
                # Deterministic fallback: center crop only, no randomness.
                train_tf = [
                    Aug.CenterCrop((self.input_size, self.input_size)),
                    Aug.NormalizeGeometry()]
            train_kwargs = dict(pseudo_labels=self.pseudo_labels,
                                renderer=renderer,
                                transform=train_tf)
            print('stage fit:')
            pprint(train_kwargs)
            ## SOURCE train dataset
            # NOTE(review): 'acg' selects FabricsDataset although AmbientCG is
            # imported above; any other non-empty source_ds leaves src_train
            # unset and train_dataloader() will fail — confirm intended.
            if self.source_ds == 'acg':
                self.src_train = FabricsDataset(split='train',
                                                dir=self.source_dir,
                                                matlist=self.source_list,
                                                **train_kwargs)
            ## TARGET train dataset
            if self.target_ds == 'sd':
                self.tgt_train = StableDiffusion(split='train',
                                                 # use_ref=self.use_ref,
                                                 dir=self.target_dir,
                                                 matlist=self.target_list,
                                                 **train_kwargs)
            # Pad the unconfigured side with an empty dataset of equal length
            # so MultiLoader can still zip the two loaders.
            if not self.source_ds:
                self.src_train = EmptyDataset(len(self.tgt_train))
            if not self.target_ds:
                self.tgt_train = EmptyDataset(len(self.src_train))
        if stage == 'fit' or stage == 'validate':
            # set_seed_render=True: validation renders are seeded for
            # reproducibility (semantics defined by the dataset classes).
            validate_kwargs = dict(transform=eval_tf,
                                   renderer=renderer,
                                   set_seed_render=True)
            ## SOURCE validation dataset
            if self.source_ds == 'acg':
                self.src_valid = FabricsDataset(split='valid',
                                                dir=self.source_dir,
                                                matlist=self.source_list,
                                                **validate_kwargs)
            ## TARGET validation dataset
            # NOTE(review): validation uses matlist=self.target_list, not
            # self.target_val_list — see the note in __init__.
            if self.target_ds == 'sd':
                self.tgt_valid = StableDiffusion(split='valid',
                                                 pseudo_labels=False,
                                                 dir=self.target_dir,
                                                 # use_ref=self.use_ref,
                                                 matlist=self.target_list,
                                                 **validate_kwargs)
            if not self.source_ds:
                self.src_valid = EmptyDataset(len(self.tgt_valid))
            if not self.target_ds:
                self.tgt_valid = EmptyDataset(len(self.src_valid))
        elif stage == 'test':
            assert self.test_ds
            # Test always disables pseudo labels and uses the full split.
            test_kwargs = dict(pseudo_labels=False,
                               matlist=self.test_list,
                               transform=eval_tf,
                               renderer=renderer,
                               dir=self.test_dir,
                               set_seed_render=True)
            # self.eval is a list so test_dataloader() can return one loader
            # per dataset.
            if self.test_ds == 'acg':
                self.eval = [FabricsDataset(split='all', **test_kwargs)]
            elif self.test_ds == 'sd':
                self.eval = [StableDiffusion(split='all', **test_kwargs)]
        elif stage == 'predict':
            # Prediction walks predict_dir directly (matlist=None).
            predict_kwargs = dict(split='all',
                                  pseudo_labels=False,
                                  dir=self.predict_dir,
                                  matlist=None,
                                  transform=eval_tf,
                                  renderer=renderer)
            if self.predict_ds == 'sd':
                self.ds = StableDiffusion(**predict_kwargs)

    def train_dataloader(self):
        """Return a ``MultiLoader`` zipping shuffled source and target train
        loaders (both drop the last partial batch)."""
        src_dl = DataLoader(dataset=self.src_train,
                            batch_size=self.batch_size,
                            drop_last=True,
                            shuffle=True,
                            num_workers=self.num_workers,
                            collate_fn=collate_fn)
        tgt_dl = DataLoader(dataset=self.tgt_train,
                            batch_size=self.batch_size,
                            drop_last=True,
                            shuffle=True,
                            num_workers=self.num_workers,
                            collate_fn=collate_fn)
        mix = MultiLoader(src_dl, tgt_dl)
        return mix

    def val_dataloader(self):
        """Return a ``MultiLoader`` zipping unshuffled source and target
        validation loaders (no batches dropped)."""
        src_dl = DataLoader(dataset=self.src_valid,
                            batch_size=self.batch_size,
                            drop_last=False,
                            shuffle=False,
                            num_workers=self.num_workers,
                            collate_fn=collate_fn)
        tgt_dl = DataLoader(dataset=self.tgt_valid,
                            batch_size=self.batch_size,
                            drop_last=False,
                            shuffle=False,
                            num_workers=self.num_workers,
                            collate_fn=collate_fn)
        mix = MultiLoader(src_dl, tgt_dl)
        return mix

    def test_dataloader(self):
        """Return one unshuffled DataLoader per dataset in ``self.eval``.

        NOTE(review): unlike train/val, no ``collate_fn`` is passed here —
        confirm the default collate is correct for the test datasets.
        """
        return [DataLoader(dataset=ds,
                           batch_size=self.batch_size,
                           drop_last=False,
                           shuffle=False,
                           num_workers=self.num_workers) for ds in self.eval]

    def predict_dataloader(self):
        """Return a single-sample, single-worker loader over ``self.ds``
        (batch_size/num_workers settings are intentionally not used here)."""
        return DataLoader(dataset=self.ds,
                          batch_size=1,
                          drop_last=False,
                          shuffle=False,
                          num_workers=1)