# Page-header residue from the hosting site ("Spaces: Running on Zero"); file size: 11,675 bytes, ref c28dddb.
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import json
import numpy as np
from PIL import Image
import torchvision.transforms as T
from dataset.base_dataset import BaseDataset
import random
from tqdm import tqdm
import imageio
import torch
def make_white_background(src_img):
    """Composite an RGBA image onto an opaque white canvas.

    Args:
        src_img: image whose fourth band is used as the paste mask
            (the alpha channel of an RGBA image).

    Returns:
        A new "RGB" image of the same size as ``src_img`` in which the
        pixels of ``src_img`` are pasted over pure white using that mask.
    """
    # Make sure the pixel data is actually loaded before the bands are split.
    src_img.load()
    white_canvas = Image.new("RGB", src_img.size, (255, 255, 255))
    # Band index 3 is the alpha channel.
    alpha_band = src_img.split()[3]
    white_canvas.paste(src_img, mask=alpha_band)
    return white_canvas
class MyDataset(BaseDataset):
    """
    Dataset for training and testing on the PartNet-Mobility and ACD datasets (with our preprocessing).
    The GT graph is given.

    One sample corresponds to one rendered view of one model. For each view the
    patch features of two frames are stacked along the token axis: the first
    frame fills tokens [0, 256) and the second frame fills tokens [256, 512).
    Everything is read from disk once, in the constructor, and kept in
    ``self.files``.

    ``_filter_models``, ``get_acd_mapping``, ``_prepare_input_GT`` and
    ``_prepare_input`` are not defined in this class; they are expected to be
    provided by ``BaseDataset``.
    """

    # Views used for validation / testing. The order matters: sample ``index``
    # of a non-train split belongs to view ``_EVAL_VIEW_IDS[index % len(...)]``.
    _EVAL_VIEW_IDS = ("18", "19")
    # Number of patch tokens stored per frame and their channel count.
    _N_PATCHES = 256
    _FEAT_DIM = 768
    # Side length of the cached visualization images.
    _IMG_SIZE = 128

    def __init__(self, hparams, model_ids, mode="train", json_name="object.json"):
        """Load and cache all samples of the requested split.

        Args:
            hparams: configuration object. It must support attribute access
                (``json_root``, ``root``, ``frame_mode``, ...) as well as
                ``get(key, default)``.
            model_ids: candidate model ids; they are passed through
                ``self._filter_models`` before use.
            mode: "train", "val" or "test". Every mode other than "train" is
                cached through the non-train path.
            json_name: name of the per-model JSON file holding the GT graph.
        """
        self.hparams = hparams
        self.json_name = json_name
        self.model_ids = self._filter_models(model_ids)
        self.mode = mode
        self.map_cat = False
        self.get_acd_mapping()
        # GT is only dropped when a predicted graph is used in its place.
        self.no_GT = bool(
            self.hparams.get("test_no_GT", False) and self.hparams.get("test_pred_G", False)
        )
        # Predicted graphs are never used for training / validation.
        self.pred_G = (
            False if mode in ["train", "val"] else self.hparams.get("test_pred_G", False)
        )
        if mode == 'test' and "acd" in hparams.test_which:
            self.map_cat = True

        self.files = self._cache_data()
        print(f"[INFO] {mode} dataset: {len(self)} data samples loaded.")

    def _select_frame_feats(self, view_feat):
        """Return the (first, second) frame features of one view.

        The first frame is always ``view_feat[0]``. The second one is
        ``view_feat[-2]`` for ``frame_mode == 'last_frame'`` and
        ``view_feat[-1]`` for ``'random_state_frame'``.
        NOTE(review): this presumably means the feature file stores the
        animation frames followed by one random-state entry -- confirm.

        Raises:
            NotImplementedError: for any other ``frame_mode``.
        """
        frame_mode = self.hparams.frame_mode
        if frame_mode == 'last_frame':
            second_frame_feat = view_feat[-2]
        elif frame_mode == 'random_state_frame':
            second_frame_feat = view_feat[-1]
        else:
            raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
        return view_feat[0], second_frame_feat

    def _store_view_feats(self, feats, row, view_feat):
        """Write the two selected frames of ``view_feat`` into ``feats[row]``."""
        first_frame_feat, second_frame_feat = self._select_frame_feats(view_feat)
        n_patches = self._N_PATCHES
        feats[row : row + 1, :n_patches, :] = first_frame_feat.astype(np.float16)
        feats[row : row + 1, n_patches:, :] = second_frame_feat.astype(np.float16)

    @staticmethod
    def _image_to_array(img, transform):
        """Apply ``transform`` to a PIL image and return it as a uint8 RGB array."""
        if img.mode == 'RGBA':
            img = make_white_background(img)
        elif img.mode != 'RGB':
            # Anything else (e.g. grayscale / palette) would not fit the
            # 3-channel image buffers.
            img = img.convert('RGB')
        return np.asarray(transform(img), dtype=np.uint8)

    def _load_view_images(self, video_path, random_img_path, transform):
        """Load the two visualization images that belong to one view.

        The first image is always the first frame of the video. The second is
        the last video frame for ``'last_frame'`` and the still image at
        ``random_img_path`` for ``'random_state_frame'``.

        Returns:
            Tuple ``(first_img, second_img)`` of uint8 arrays.

        Raises:
            IndexError: if no frame could be read from the video.
            NotImplementedError: for an unknown ``frame_mode``.
        """
        frame_mode = self.hparams.frame_mode
        if frame_mode not in ('last_frame', 'random_state_frame'):
            raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
        need_last_frame = frame_mode == 'last_frame'

        first_frame = None
        last_frame = None
        reader = imageio.get_reader(video_path)
        try:
            # Only keep the frames that are needed instead of the whole clip.
            for frame in reader:
                if first_frame is None:
                    first_frame = frame
                last_frame = frame
                if not need_last_frame:
                    break
        finally:
            # Release the decoder even if reading fails half-way.
            reader.close()
        if first_frame is None:
            raise IndexError(f"No frames could be read from {video_path}")

        first_img = self._image_to_array(Image.fromarray(first_frame), transform)
        if need_last_frame:
            second_img = self._image_to_array(Image.fromarray(last_frame), transform)
        else:
            with Image.open(random_img_path) as random_img:
                second_img = self._image_to_array(random_img, transform)
        return first_img, second_img

    def _cache_data_train(self):
        """Cache GT files and patch features for the training split.

        Up to ``hparams.n_views_per_model`` feature files are used per model
        (files with 'high_res' in their name are skipped). Files are taken in
        sorted order so that the selection is reproducible. Models that offer
        fewer views simply contribute fewer samples.

        Returns:
            dict with keys "len", "gt_files", "features", "obj_masks" (always
            None), "model_mappings" and "imgs" (always None; input images are
            not required in training).
        """
        json_data_root = self.hparams.json_root
        data_root = self.hparams.root
        # number of views per model and the resulting upper bound in total
        n_views_per_model = self.hparams.n_views_per_model
        max_views = n_views_per_model * len(self.model_ids)
        # json files for each model
        json_files = []
        # mapping to the index of the corresponding model in json_files
        model_mappings = []
        # space for dinov2 patch features
        feats = np.empty((max_views, 2 * self._N_PATCHES, self._FEAT_DIM), dtype=np.float16)
        # input images (not required in training)
        imgs = None

        i = 0  # index for views
        for j, model_id in enumerate(self.model_ids):
            print(model_id)
            # 3D data
            with open(os.path.join(json_data_root, model_id, self.json_name), "r") as f:
                json_files.append(json.load(f))

            feat_dir = os.path.join(data_root, model_id, 'features')
            filenames = sorted(f for f in os.listdir(feat_dir) if 'high_res' not in f)
            filenames = filenames[:n_views_per_model]
            if len(filenames) < n_views_per_model:
                print(
                    f"[WARN] {model_id}: found {len(filenames)} feature views, "
                    f"expected {n_views_per_model}."
                )
            for filename in filenames:
                view_feat = np.load(os.path.join(feat_dir, filename))
                self._store_view_feats(feats, i, view_feat)
                i += 1
            # One mapping entry per view that was really loaded keeps
            # features and GT files aligned.
            model_mappings += [j] * len(filenames)

        return {
            "len": i,
            "gt_files": json_files,
            # Drop the rows that were reserved but never filled.
            "features": feats[:i],
            "obj_masks": None,
            "model_mappings": model_mappings,
            "imgs": imgs,
        }

    def _cache_data_non_train(self):
        """Cache GT files, patch features and input images for val / test.

        Exactly the views listed in ``_EVAL_VIEW_IDS`` are loaded for every
        model, in that order.

        Returns:
            dict with keys "len", "gt_files", "pred_files" (always empty),
            "features", "model_mappings" and "imgs" (a list
            ``[first_imgs, second_imgs]`` of uint8 arrays).
        """
        view_ids = self._EVAL_VIEW_IDS
        # number of views per model and in total
        n_views_per_model = len(view_ids)
        n_views = n_views_per_model * len(self.model_ids)
        # json files for each model
        gt_files = []
        pred_files = []  # for predicted graphs
        # mapping to the index of the corresponding model in json_files
        model_mappings = []
        # space for dinov2 patch features
        feats = np.empty((n_views, 2 * self._N_PATCHES, self._FEAT_DIM), dtype=np.float16)
        # space for input images
        img_shape = (n_views, self._IMG_SIZE, self._IMG_SIZE, 3)
        first_imgs = np.empty(img_shape, dtype=np.uint8)
        second_imgs = np.empty(img_shape, dtype=np.uint8)
        # transformation for input images
        transform = T.Compose(
            [
                T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
                T.CenterCrop(224),
                T.Resize(self._IMG_SIZE, interpolation=T.InterpolationMode.BICUBIC),
            ]
        )

        i = 0  # index for views
        desc = f'Loading {self.mode} data'
        for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
            with open(os.path.join(self.hparams.json_root, model_id, self.json_name), "r") as f:
                gt_files.append(json.load(f))

            model_dir = os.path.join(self.hparams.root, model_id)
            for view_id in view_ids:
                view_feat = np.load(os.path.join(model_dir, 'features', view_id + '.npy'))
                self._store_view_feats(feats, i, view_feat)

                # Build both paths from their parts so that nothing in the
                # root directory name can be rewritten by accident.
                video_path = os.path.join(model_dir, 'imgs', 'animation_' + view_id + '.mp4')
                random_img_path = os.path.join(model_dir, 'imgs', 'random_' + view_id + '.png')
                first_imgs[i], second_imgs[i] = self._load_view_images(
                    video_path, random_img_path, transform
                )
                i += 1
            # mapping to json file
            model_mappings += [j] * n_views_per_model

        return {
            "len": n_views,
            "gt_files": gt_files,
            "pred_files": pred_files,
            "features": feats,
            "model_mappings": model_mappings,
            "imgs": [first_imgs, second_imgs],
        }

    def _cache_data(self):
        """
        Function to cache data from disk.
        """
        if self.mode == "train":
            return self._cache_data_train()
        return self._cache_data_non_train()

    def _attach_input_image(self, cond, index):
        """Store both cached input images of a sample, joined along axis 1, in ``cond["img"]``."""
        img_first = self.files["imgs"][0][index]
        img_last = self.files["imgs"][1][index]
        cond["img"] = np.concatenate([img_first, img_last], axis=1)

    def _get_item_train_val(self, index):
        """Build ``(data, cond)`` from the GT file for a train / val sample."""
        model_i = self.files["model_mappings"][index]
        gt_file = self.files["gt_files"][model_i]
        data, cond = self._prepare_input_GT(
            file=gt_file, model_id=self.model_ids[model_i]
        )
        if self.mode == "val":
            # input image for visualization
            self._attach_input_image(cond, index)
        return data, cond

    def _get_item_test(self, index):
        """Build ``(data, cond)`` for a test sample.

        Without ``hparams.G_dir`` the GT file is used. Otherwise the predicted
        graph ``<G_dir>/<model_id>/<view_id>.json`` of the view that belongs
        to ``index`` is loaded and passed on together with the GT file (which
        is None when ``self.no_GT`` is set).
        """
        model_i = self.files["model_mappings"][index]
        model_id = self.model_ids[model_i]
        gt_file = None if self.no_GT else self.files["gt_files"][model_i]

        if self.hparams.get('G_dir', None) is None:
            data, cond = self._prepare_input_GT(file=gt_file, model_id=model_id)
        else:
            # Samples are cached view by view, in _EVAL_VIEW_IDS order.
            view_id = self._EVAL_VIEW_IDS[index % len(self._EVAL_VIEW_IDS)]
            pred_file_path = os.path.join(self.hparams.G_dir, model_id, view_id + '.json')
            with open(pred_file_path, "r") as f:
                pred_file = json.load(f)
            data, cond = self._prepare_input(model_id=model_id, pred_file=pred_file, gt_file=gt_file)

        # input image for visualization
        self._attach_input_image(cond, index)
        return data, cond

    def __getitem__(self, index):
        """Return ``(data, cond, feat)`` for the sample at ``index``."""
        # input image features
        feat = self.files["features"][index]
        # prepare input, GT data and other auxiliary info
        if self.mode == "test":
            data, cond = self._get_item_test(index)
        else:
            data, cond = self._get_item_train_val(index)
        return data, cond, feat

    def __len__(self):
        """Return the number of cached samples (views)."""
        return self.files["len"]
if __name__ == '__main__':
    # Manual smoke test: build a small validation dataset and iterate over a
    # few samples. The paths below are specific to the original author's
    # machine and have to be adapted before running this elsewhere.
    from types import SimpleNamespace

    class EnhancedNamespace(SimpleNamespace):
        """SimpleNamespace that also offers the ``get(key, default)`` lookup used by MyDataset."""

        def get(self, key, default=None):
            return getattr(self, key, default)

    hparams = {
        "name": "dm_singapo",
        "json_root": "/home/users/ruiqi.wu/singapo/",  # root directory of the dataset
        "batch_size": 20,  # batch size for training
        "num_workers": 8,  # number of workers for data loading
        "K": 32,  # maximum number of nodes (parts) in the graph (object)
        "split_file": "/home/users/ruiqi.wu/singapo/data/data_split.json",
        "n_views_per_model": 5,
        "root": "/home/users/ruiqi.wu/manipulate_3d_generate/data/blender_version",
        "frame_mode": "last_frame",
    }
    hparams = EnhancedNamespace(**hparams)

    with open(hparams.split_file, "r") as f:
        splits = json.load(f)

    train_ids = splits["train"]
    val_ids = [i for i in train_ids if "augmented" not in i]
    # Keep only the models whose features exist on disk.
    val_ids = [val_id for val_id in val_ids if os.path.exists(os.path.join(hparams.root, val_id, "features"))]

    # "val" (not "valid") is the mode name the dataset checks for when it
    # attaches the visualization image.
    dataset = MyDataset(hparams, model_ids=val_ids[:20], mode="val")

    # Never index past the end when fewer samples are available.
    for i in range(min(20, len(dataset))):
        data, cond, feat = dataset[i]
        print(f"[INFO] sample {i}: feature shape {feat.shape}")