File size: 1,973 Bytes
04c78c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | import typing
from pathlib import Path
from PIL import Image
import random
import torch
from easydict import EasyDict
import torchvision.transforms.functional as tf
from torch.utils.data import Dataset
from ..utils.log import get_matlist
from . import augment as Aug
class StableDiffusion(Dataset):
    """Image-only dataset over Stable Diffusion output PNGs.

    Each item is one RGB image read from disk, converted with torchvision's
    ``to_tensor`` and passed through an ``Aug.Pipeline`` built from
    ``transform`` (registered with the single dtype ``'input'``).

    The candidate files come either from a material list (``matlist``,
    resolved by ``get_matlist``) or, when ``matlist`` is None, from a
    recursive scan of ``dir`` for ``outputs/*<digit>.png`` and
    ``out_renorm/*<digit>.png``. The scan skips any image that already has a
    sibling ``<stem>_roughness.png``. The first 98% of that (sorted /
    matlist-ordered) list is the 'train' split and the rest is 'valid'. The
    selected subset is shuffled once, with the global ``random`` RNG, at
    construction time.
    """

    _SPLITS = ('train', 'valid', 'all')
    # Fraction of the ordered file list assigned to 'train'; the remainder is 'valid'.
    _TRAIN_FRACTION = .98

    def __init__(
        self,
        split,
        pseudo_labels,
        transform,
        renderer,
        matlist,
        use_ref=False,
        dir: typing.Optional[typing.Union[str, Path]] = None,
        **kwargs
    ):
        """Build the file list for ``split`` and the augmentation pipeline.

        Args:
            split: one of 'train', 'valid' or 'all'.
            pseudo_labels: stored as ``self.pseudo_labels`` (not used by this class itself).
            transform: iterable of stages unpacked into ``Aug.Pipeline``
                (presumably augmentation ops — confirm against ``augment``).
            renderer: stored as ``self.renderer`` (not used by this class itself).
            matlist: None to scan ``dir``, otherwise forwarded to ``get_matlist``.
            use_ref: stored as ``self.use_ref`` (not used by this class itself).
            dir: dataset root directory (``str`` or ``Path``); required despite
                the None default.
            **kwargs: accepted and ignored so shared dataset configs can be passed.

        Raises:
            ValueError: if ``dir`` is missing or not a directory, or ``split``
                is not one of the supported values.
        """
        if dir is None:
            raise ValueError('StableDiffusion requires a dataset directory (dir=...)')
        root = Path(dir)
        if not root.is_dir():
            raise ValueError(f'dataset directory does not exist: {root}')
        if split not in self._SPLITS:
            raise ValueError(f'split must be one of {self._SPLITS}, got {split!r}')

        self.split = split
        self.renderer = renderer
        self.pseudo_labels = pseudo_labels
        self.use_ref = use_ref

        files = self._collect_files(root, matlist)
        # The train/valid boundary is taken on the deterministic order above;
        # only the selected subset (a fresh list) is shuffled afterwards.
        self.files = self._select_split(files, split)
        random.shuffle(self.files)
        print(f'StableDiffusion list={matlist}:{self.split}=[{len(self.files)}/{len(files)}]')

        dtypes = ['input']
        self.tf = Aug.Pipeline(*transform, dtypes=dtypes)

    @staticmethod
    def _collect_files(root, matlist):
        """Return the ordered list of candidate image paths for this dataset.

        With ``matlist`` None, recursively scan ``root``. Otherwise defer to
        ``get_matlist(matlist, root)`` and copy its result into a new list.
        """
        if matlist is not None:
            return list(get_matlist(matlist, root))
        # rglob already matches at any depth ('**/' is implied).
        files = sorted(root.rglob('outputs/*[0-9].png'))
        files += sorted(root.rglob('out_renorm/*[0-9].png'))
        print(f'total={len(files)}')
        # Drop images that already have a roughness map next to them.
        files = [p for p in files if not (p.parent / f'{p.stem}_roughness.png').is_file()]
        print(f'after={len(files)}')
        return files

    @classmethod
    def _select_split(cls, files, split):
        """Return a new list with the ``split`` portion of ``files``.

        Raises:
            ValueError: for an unknown ``split``.
        """
        k = int(len(files) * cls._TRAIN_FRACTION)
        if split == 'train':
            return list(files[:k])
        if split == 'valid':
            return list(files[k:])
        if split == 'all':
            return list(files)
        raise ValueError(f'split must be one of {cls._SPLITS}, got {split!r}')

    def __getitem__(self, index):
        """Load image ``index`` and return an EasyDict with dir, name, path and input.

        ``input`` is the first output of ``self.tf`` applied to the image tensor.
        """
        path = self.files[index]
        o = EasyDict(dir=str(path.parent), name=path.stem)
        # Close the file deterministically instead of relying on Pillow/GC.
        with Image.open(path) as img:
            image = tf.to_tensor(img.convert('RGB'))
        o.path = str(path)
        o.input, *_ = self.tf([image])
        return o

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.files)
|