import typing
from pathlib import Path
from PIL import Image
import random

import torch
from easydict import EasyDict
import torchvision.transforms.functional as tf
from torch.utils.data import Dataset

from ..utils.log import get_matlist
from . import augment as Aug


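class StableDiffusion(Dataset):
    """Dataset of Stable Diffusion output images; each item carries only an `input` image.

    Files come either from scanning `dir` for numbered PNGs in `outputs/` or
    `out_renorm/` folders (when `matlist` is None) or from `get_matlist`.
    Each item is an EasyDict with `dir`, `name`, `path` and the transformed
    `input` tensor. `pseudo_labels`, `renderer` and `use_ref` are stored but
    not used by this class; extra keyword arguments are ignored.
    """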
    def __init__(
        self,
        split,
        pseudo_labels,
        transform,
        renderer,
        matlist,
        use_ref=False,
        dir: typing.Optional[Path] = None,
        **kwargs
    ):
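        # `dir` defaults to None, so check for it before calling Path methods.
        assert dir is not None, 'StableDiffusion requires a dataset root `dir`'
        dir = Path(dir)
        assert dir.is_dir(), f'dataset root is not a directory: {dir}'
        assert split in ['train', 'valid', 'all'], f'unknown split: {split}'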

        self.split = split
        self.renderer = renderer
        self.pseudo_labels = pseudo_labels
        self.use_ref = use_ref
        # self.pl_dir = pl_dir

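        if matlist is None:
            # Scan for numbered renders (e.g. `0003.png`) in every `outputs/` and
            # `out_renorm/` folder below `dir`; rglob already recurses, so the
            # pattern needs no leading `**/`.
            files = sorted(dir.rglob('outputs/*[0-9].png'))
            files += sorted(dir.rglob('out_renorm/*[0-9].png'))
            print(f'total={len(files)}')
            # Keep only renders that do not yet have a `<stem>_roughness.png` map.
            files = [x for x in files if not (x.parent/f'{x.stem}_roughness.png').is_file()]
            print(f'after={len(files)}')
        else:
            # Otherwise resolve the explicit material list through get_matlist.
            files = get_matlist(matlist, dir)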

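        ### Train/Validation Split
        # The first 98% of the list is used for training and the last 2% for
        # validation; 'all' keeps every file. The cut happens before shuffling,
        # so the partition is reproducible for a fixed file-list order.
        k = int(len(files)*.98)
        if split == 'train':
            self.files = files[:k]
        elif split == 'valid':
            self.files = files[k:]
        elif split == 'all':
            self.files = list(files)  # copy, so the shuffle below leaves `files` intact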

        random.shuffle(self.files)

        print(f'StableDiffusion list={matlist}:{self.split}=[{len(self.files)}/{len(files)}]')

        # Only the 'input' image is passed through the augmentation pipeline.
        dtypes = ['input']
        self.tf = Aug.Pipeline(*transform, dtypes=dtypes)

    def __getitem__(self, index):
        path = self.files[index]
        name = path.stem

        o = EasyDict(dir=str(path.parent), name=name)

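        # RGB float tensor of shape (3, H, W) with values in [0, 1]; the `with`
        # block closes the file handle that Image.open would otherwise keep open.
        with Image.open(path) as img:
            I = tf.to_tensor(img.convert('RGB'))

        o.path = str(path)
        # The pipeline is applied to a one-element list; its first output is the input.
        o.input, *_ = self.tf([I])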
        return o

    def __len__(self):
        return len(self.files)
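

# Illustrative sketch (hypothetical helper, not used by StableDiffusion above):
# the matlist=None discovery and the 98/2 split from __init__ written as a pure
# function, so the file selection can be checked without a renderer or an
# augmentation pipeline. The name `split_unlabelled_pngs` and the `ratio`
# parameter are assumptions made for this example; unlike __init__, it does
# not shuffle the result.
def split_unlabelled_pngs(root: Path, split: str, ratio: float = 0.98) -> typing.List[Path]:
    # Numbered renders from any `outputs/` folder, then any `out_renorm/` folder.
    files = sorted(root.rglob('outputs/*[0-9].png'))
    files += sorted(root.rglob('out_renorm/*[0-9].png'))
    # Drop renders that already have a `<stem>_roughness.png` map next to them.
    files = [x for x in files if not (x.parent / f'{x.stem}_roughness.png').is_file()]
    # Deterministic cut: the first `ratio` of the list trains, the rest validates.
    k = int(len(files) * ratio)
    splits = {'train': files[:k], 'valid': files[k:], 'all': files}
    return splits[split]

# Example with a hypothetical root: split_unlabelled_pngs(Path('sd_outputs'), 'valid')
# returns the last ~2% of the unlabelled renders, in the order __init__ sees them
# before shuffling.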