File size: 11,675 Bytes
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import json
import numpy as np
from PIL import Image
import torchvision.transforms as T
from dataset.base_dataset import BaseDataset
import random
from tqdm import tqdm
import imageio
import torch

def make_white_background(src_img):
    """Composite an RGBA image onto an opaque white canvas and return the RGB result."""
    src_img.load()
    alpha = src_img.split()[3]  # channel 3 is the alpha mask
    canvas = Image.new("RGB", src_img.size, (255, 255, 255))
    canvas.paste(src_img, mask=alpha)
    return canvas

class MyDataset(BaseDataset):
    """
    Dataset for training and testing on the PartNet-Mobility and ACD datasets
    (with our preprocessing). The GT graph is given.

    Each sample pairs two frames of one rendered view: the first (rest-state)
    frame and a second frame chosen by ``hparams.frame_mode``:
      - 'last_frame'         -> entry [-2] of the cached feature stack
      - 'random_state_frame' -> entry [-1] of the cached feature stack
    (NOTE(review): this index layout mirrors how the feature files were
    exported — confirm against the preprocessing script.)
    """

    def __init__(self, hparams, model_ids, mode="train", json_name="object.json"):
        """
        Args:
            hparams: config object exposing attribute access plus a dict-like
                ``get(key, default)``; must provide ``root``, ``json_root``,
                ``n_views_per_model`` and ``frame_mode`` (and ``test_which`` /
                ``G_dir`` at test time).
            model_ids: candidate model ids; narrowed by ``_filter_models``.
            mode: one of 'train' | 'val' | 'test'.
            json_name: filename of the per-model GT json.
        """
        self.hparams = hparams
        self.json_name = json_name
        self.model_ids = self._filter_models(model_ids)
        self.mode = mode
        self.map_cat = False
        self.get_acd_mapping()

        # GT files are skipped only when both test flags are enabled.
        self.no_GT = bool(
            self.hparams.get("test_no_GT", False)
            and self.hparams.get("test_pred_G", False)
        )
        # Predicted graphs are only ever consumed at test time.
        self.pred_G = (
            False
            if mode in ("train", "val")
            else self.hparams.get("test_pred_G", False)
        )

        if mode == "test" and "acd" in hparams.test_which:
            # ACD categories must be remapped to our label space.
            self.map_cat = True

        self.files = self._cache_data()
        print(f"[INFO] {mode} dataset: {len(self)} data samples loaded.")

    def _cache_data_train(self):
        """
        Cache DINOv2 patch features for all training views into RAM.

        Returns:
            dict with keys:
                len: total number of views,
                gt_files: parsed GT json per model,
                features: (n_views, 512, 768) float16 — rows [:256] are the
                    first-frame patches, rows [256:] the second-frame patches,
                obj_masks: None (object masks are not used in training),
                model_mappings: view index -> index into gt_files,
                imgs: None (raw images are not required in training).
        """
        json_data_root = self.hparams.json_root
        data_root = self.hparams.root
        # number of views per model and in total
        n_views_per_model = self.hparams.n_views_per_model
        n_views = n_views_per_model * len(self.model_ids)
        json_files = []      # GT json file per model
        model_mappings = []  # view index -> index of the model in json_files
        # space for dinov2 patch features (two 256-patch frames stacked)
        feats = np.empty((n_views, 512, 768), dtype=np.float16)
        imgs = None  # input images are not required in training

        i = 0  # running view index
        desc = "Loading train data"
        for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
            # 3D GT data for this model
            with open(os.path.join(json_data_root, model_id, self.json_name), "r") as f:
                json_files.append(json.load(f))
            # NOTE(review): os.listdir order is arbitrary — sort before
            # truncating if a deterministic view selection matters.
            filenames = os.listdir(os.path.join(data_root, model_id, 'features'))
            filenames = [f for f in filenames if 'high_res' not in f]
            filenames = filenames[:n_views_per_model]
            for filename in filenames:
                view_feat = np.load(os.path.join(data_root, model_id, 'features', filename))
                first_frame_feat = view_feat[0]
                if self.hparams.frame_mode == 'last_frame':
                    second_frame_feat = view_feat[-2]
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_frame_feat = view_feat[-1]
                else:
                    raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
                feats[i, :256, :] = first_frame_feat.astype(np.float16)
                feats[i, 256:, :] = second_frame_feat.astype(np.float16)
                i += 1
            # NOTE(review): assumes every model yields exactly
            # n_views_per_model feature files; fewer would desync the running
            # index from model_mappings/feats — confirm upstream guarantees.
            model_mappings += [j] * n_views_per_model
        return {
            "len": n_views,
            "gt_files": json_files,
            "features": feats,
            "obj_masks": None,
            "model_mappings": model_mappings,
            "imgs": imgs,
        }

    def _cache_data_non_train(self):
        """
        Cache features AND visualization images for val/test.

        Only the two held-out views per model ('18.npy' / '19.npy') are used.

        Returns:
            dict with keys:
                len, gt_files, pred_files, features, model_mappings as in
                ``_cache_data_train``, plus
                imgs: [first_imgs, second_imgs], each (n_views, 128, 128, 3)
                    uint8, for visualization only.
        """
        # number of views per model and in total
        n_views_per_model = 2
        n_views = n_views_per_model * len(self.model_ids)
        gt_files = []
        pred_files = []  # for predicted graphs
        model_mappings = []  # view index -> index of the model in gt_files
        # space for dinov2 patch features
        feats = np.empty((n_views, 512, 768), dtype=np.float16)
        # space for input images (visualization)
        first_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
        second_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
        # transformation for input images: resize -> center crop -> downsample
        transform = T.Compose(
            [
                T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
                T.CenterCrop(224),
                T.Resize(128, interpolation=T.InterpolationMode.BICUBIC),
            ]
        )

        i = 0  # index for views
        desc = f'Loading {self.mode} data'
        for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
            with open(os.path.join(self.hparams.json_root, model_id, self.json_name), "r") as f:
                gt_files.append(json.load(f))
            for filename in ['18.npy', '19.npy']:
                view_feat = np.load(os.path.join(self.hparams.root, model_id, 'features', filename))
                first_frame_feat = view_feat[0]
                if self.hparams.frame_mode == 'last_frame':
                    second_frame_feat = view_feat[-2]
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_frame_feat = view_feat[-1]
                else:
                    raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
                feats[i, :256, :] = first_frame_feat.astype(np.float16)
                feats[i, 256:, :] = second_frame_feat.astype(np.float16)

                # read all animation frames of this view
                video_path = os.path.join(
                    self.hparams.root, model_id, 'imgs',
                    'animation_' + filename.replace('.npy', '.mp4')
                )
                reader = imageio.get_reader(video_path)
                frames = [frame for frame in reader]
                reader.close()

                first_img = Image.fromarray(frames[0])
                if first_img.mode == 'RGBA':
                    first_img = make_white_background(first_img)
                # BUG FIX: was dtype=np.int8 — pixel values > 127 wrapped to
                # negatives and only round-tripped back by accident when
                # assigned into the uint8 buffer; uint8 is the intended dtype.
                first_imgs[i] = np.asarray(transform(first_img), dtype=np.uint8)

                if self.hparams.frame_mode == 'last_frame':
                    second_img = Image.fromarray(frames[-1])
                elif self.hparams.frame_mode == 'random_state_frame':
                    # the random-state frame is stored as a separate png
                    second_img_path = video_path.replace('animation', 'random').replace('.mp4', '.png')
                    second_img = Image.open(second_img_path)
                if second_img.mode == 'RGBA':
                    second_img = make_white_background(second_img)
                # BUG FIX: np.int8 -> np.uint8 (see above)
                second_imgs[i] = np.asarray(transform(second_img), dtype=np.uint8)

                i += 1
            # mapping to json file
            model_mappings += [j] * n_views_per_model

        return {
            "len": n_views,
            "gt_files": gt_files,
            "pred_files": pred_files,
            "features": feats,
            "model_mappings": model_mappings,
            "imgs": [first_imgs, second_imgs],
        }

    def _cache_data(self):
        """Cache data from disk; dispatch on the dataset mode."""
        if self.mode == "train":
            return self._cache_data_train()
        else:
            return self._cache_data_non_train()

    def _get_item_train_val(self, index):
        """Prepare one train/val sample from the GT graph of the owning model."""
        model_i = self.files["model_mappings"][index]
        gt_file = self.files["gt_files"][model_i]
        data, cond = self._prepare_input_GT(
            file=gt_file, model_id=self.model_ids[model_i]
        )
        if self.mode == "val":
            # attach the two input frames side by side for visualization
            img_first = self.files["imgs"][0][index]
            img_last = self.files["imgs"][1][index]
            cond["img"] = np.concatenate([img_first, img_last], axis=1)
        return data, cond

    def _get_item_test(self, index):
        """Prepare one test sample; uses predicted graphs when `G_dir` is set."""
        model_i = self.files["model_mappings"][index]

        gt_file = None if self.no_GT else self.files["gt_files"][model_i]

        if self.hparams.get('G_dir', None) is None:
            data, cond = self._prepare_input_GT(file=gt_file, model_id=self.model_ids[model_i])
        else:
            # views alternate 18/19 per model, so even indices map to view 18
            filename = '18.json' if index % 2 == 0 else '19.json'
            pred_file_path = os.path.join(self.hparams.G_dir, self.model_ids[model_i], filename)
            with open(pred_file_path, "r") as f:
                pred_file = json.load(f)
            data, cond = self._prepare_input(model_id=self.model_ids[model_i], pred_file=pred_file, gt_file=gt_file)
        # attach the two input frames side by side for visualization
        img_first = self.files["imgs"][0][index]
        img_last = self.files["imgs"][1][index]
        cond["img"] = np.concatenate([img_first, img_last], axis=1)
        return data, cond

    def __getitem__(self, index):
        """Return (data, cond, feat) for one view, where feat is (512, 768) float16."""
        # input image features
        feat = self.files["features"][index]

        # prepare input, GT data and other auxiliary info
        if self.mode == "test":
            data, cond = self._get_item_test(index)
        else:
            data, cond = self._get_item_train_val(index)

        return data, cond, feat

    def __len__(self):
        """Total number of cached views."""
        return self.files["len"]

if __name__ == '__main__':
    # Smoke test: load a handful of validation models and fetch samples.
    from types import SimpleNamespace

    class EnhancedNamespace(SimpleNamespace):
        """SimpleNamespace with a dict-like .get(), as MyDataset expects."""

        def get(self, key, default=None):
            return getattr(self, key, default)

    hparams = EnhancedNamespace(
        name="dm_singapo",
        json_root="/home/users/ruiqi.wu/singapo/",  # root directory of the dataset
        batch_size=20,   # batch size for training
        num_workers=8,   # number of workers for data loading
        K=32,            # maximum number of nodes (parts) in the graph (object)
        split_file="/home/users/ruiqi.wu/singapo/data/data_split.json",
        n_views_per_model=5,
        root="/home/users/ruiqi.wu/manipulate_3d_generate/data/blender_version",
        frame_mode="last_frame",
    )

    with open(hparams.split_file, "r") as f:
        splits = json.load(f)
    train_ids = splits["train"]
    val_ids = [i for i in train_ids if "augmented" not in i]
    # keep only models whose features have actually been extracted
    val_ids = [val_id for val_id in val_ids if os.path.exists(os.path.join(hparams.root, val_id, "features"))]

    # BUG FIX: was mode="valid", which MyDataset does not recognize ('val' is
    # the expected spelling), so validation images were never attached to cond.
    dataset = MyDataset(hparams, model_ids=val_ids[:20], mode="val")
    for i in range(20):
        data, cond, feat = dataset[i]
    import ipdb
    ipdb.set_trace()