Spaces:
Runtime error
Runtime error
| import sys | |
| sys.path.insert(1, '.') | |
| import numpy as np | |
| from omegaconf import DictConfig | |
| import torch | |
| from PIL import Image | |
| import torchvision | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from ldm.util import instantiate_from_config | |
| import os | |
| import io | |
| import pickle | |
| import webdataset as wds | |
| import imageio | |
| import time | |
| from torch import distributed as dist | |
| from itertools import chain | |
| class ObjaverseDataDecoder: | |
| def __init__(self, | |
| target_name="albedo", | |
| image_transforms=[], | |
| default_trans=torch.zeros(3), | |
| postprocess=None, | |
| return_paths=False, | |
| mask_name="alpha", | |
| test=False, | |
| condition_name=None, | |
| bg_color="white", | |
| target_name_pool=None, | |
| **kargs | |
| ) -> None: | |
| """Create a dataset from blender rendering results. | |
| If you pass in a root directory it will be searched for images | |
| ending in ext (ext can be a list) | |
| """ | |
| # testing behaves differently | |
| self.test = test | |
| self.target_name = target_name | |
| self.mask_name = mask_name | |
| self.default_trans = default_trans | |
| self.return_paths = return_paths | |
| if isinstance(postprocess, DictConfig): | |
| postprocess = instantiate_from_config(postprocess) | |
| self.postprocess = postprocess | |
| # extra condition | |
| self.condition_name = condition_name | |
| self.target_name_pool = target_name_pool if not target_name_pool is None else [target_name] | |
| self.counter = 0 | |
| self.tform = image_transforms["totensor"] | |
| self.img_size = image_transforms["size"] | |
| self.tsize = torchvision.transforms.Compose([torchvision.transforms.Resize(self.img_size)]) | |
| if bg_color == "white": | |
| self.bg_color = [1., 1., 1., 1.] | |
| elif bg_color == "noise": | |
| self.bg_color = "noise" | |
| else: | |
| raise NotImplementedError | |
| def path_parsing(self, filename, cond_name=None): | |
| # cached path loads albedo | |
| if 'albedo' in filename: | |
| filename = filename.replace('albedo', self.target_name) | |
| if self.target_name=="gloss_shaded": | |
| filename = filename.replace('gloss_direct', self.target_name).replace("exr", "jpg") | |
| filename_targets = [filename.replace(self.target_name, "gloss_direct").replace("jpg", "exr"), | |
| filename.replace(self.target_name, "gloss_color")] | |
| elif self.target_name=="diffuse_shaded": | |
| filename = filename.replace('diffuse_direct', self.target_name).replace("exr", "jpg") | |
| filename_targets = [filename.replace(self.target_name, "diffuse_direct").replace("jpg", "exr"), | |
| filename.replace(self.target_name, "albedo")] | |
| else: | |
| filename_targets = None | |
| normal_condition_filename = None | |
| if self.test and "images_train" in filename: | |
| # Currently. "images_train" exists in test set, we write this for clearity | |
| condition_filename = filename | |
| mask_filename = filename.replace('images_train', 'masks') | |
| if self.condition_name == "normal": | |
| raise NotImplementedError("Testing with normal conditioning on custom data is not supported") | |
| else: | |
| cond_name_prefix = filename.split(".", 1)[0] + "." if cond_name is None else cond_name | |
| condition_filename = cond_name_prefix + filename.rsplit('.', 1)[1] | |
| mask_filename = filename.replace(self.target_name, self.mask_name) | |
| if self.condition_name == "normal": | |
| normal_condition_filename = filename.replace(self.target_name, "normal") | |
| return filename, condition_filename, mask_filename, normal_condition_filename, filename_targets | |
| def read_images(self, filename, condition_filename, normal_condition_filename): | |
| # image reading | |
| if self.target_name in ["gloss_shaded", "diffuse_shaded"]: | |
| target_im_0 = np.array(self.normalized_read(filename[0])) | |
| target_im_1 = np.array(self.normalized_read(filename[1])) | |
| target_im = np.clip(target_im_0 * target_im_1, 0, 1) | |
| else: | |
| target_im = np.array(self.normalized_read(filename)) | |
| cond_im = np.array(self.normalized_read(condition_filename)) | |
| if self.condition_name == "normal": | |
| normal_img = np.array(self.normalized_read(normal_condition_filename)) | |
| else: | |
| normal_img = None | |
| return target_im, cond_im, normal_img | |
| def image_post_processing(self, img_mask, target_im, cond_im, normal_img): | |
| # make sure image has 3 dimension | |
| if len(img_mask.shape) == 2: | |
| img_mask = img_mask[:, :, np.newaxis] | |
| else: | |
| img_mask = img_mask[:, :, :3] | |
| # transform into desired format | |
| target_im, crop_idx = self.load_im(target_im, img_mask, self.bg_color, crop_idx=True) | |
| target_im = np.uint8(self.tsize(target_im)) | |
| cond_im = np.uint8(self.tsize(self.load_im(cond_im, img_mask, self.bg_color))) | |
| if self.condition_name == "normal": | |
| normal_img = np.uint8(self.tsize(self.load_im(normal_img, img_mask, self.bg_color))) | |
| else: | |
| normal_img = None | |
| return target_im, cond_im, normal_img, crop_idx | |
| # def cartesian_to_spherical(self, xyz): | |
| # ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) | |
| # xy = xyz[:,0]**2 + xyz[:,1]**2 | |
| # z = np.sqrt(xy + xyz[:,2]**2) | |
| # theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down | |
| # #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up | |
| # azimuth = np.arctan2(xyz[:,1], xyz[:,0]) | |
| # return np.array([theta, azimuth, z]) | |
| def load_im(self, img, img_mask, color, crop_idx=False): | |
| ''' | |
| replace background pixel with random color in rendering | |
| ''' | |
| # our rendering do not have a valid alpha channel. | |
| # We use a seperate mask, which also do not have a valid alpha | |
| if img.shape[-1] == 3: | |
| img = np.concatenate([img, np.ones_like(img[..., :1])], axis=-1) | |
| # image maske shape align with image size | |
| if (img.shape[0] != img_mask.shape[0]) or (img.shape[1] != img_mask.shape[1]): | |
| img_mask = cv2.resize(img_mask, | |
| (img.shape[1], img.shape[0]), | |
| interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis] | |
| if isinstance(color, str): | |
| random_img = np.random.rand(*(img.shape)) | |
| img[img_mask[:, :, -1] <= 0.5] = random_img[img_mask[:, :, -1] <= 0.5] | |
| else: | |
| img[img_mask[:, :, -1] <= 0.5] = color | |
| if self.test: | |
| # crop out valid_mask | |
| img, crop_uv = self.center_crop(img[:, :, :3], img_mask) | |
| else: | |
| crop_uv = None | |
| # center crop | |
| if img.shape[0] > img.shape[1]: | |
| margin = int((img.shape[0] - img.shape[1]) // 2) | |
| img = img[margin:margin+img.shape[1]] | |
| elif img.shape[1] > img.shape[0]: | |
| margin = int((img.shape[1] - img.shape[0]) // 2) | |
| img = img[:, margin:margin+img.shape[0]] | |
| img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)) | |
| if crop_idx: | |
| return img, crop_uv | |
| return img | |
| def center_crop(self, img, mask, mask_ratio=.8): | |
| mask_uvs = np.vstack(np.nonzero(mask[:, :, -1] > 0.5)) | |
| min_uv, max_uv = np.min(mask_uvs, axis=-1), np.max(mask_uvs, axis=-1) | |
| img = img + (mask[..., -1:] <= 0.5) | |
| half_size = int(max(max_uv - min_uv) // 2) | |
| crop_length = (max_uv - min_uv) // 2 | |
| center_uv = min_uv + crop_length | |
| expand_hasl_size = int(half_size / mask_ratio) | |
| size = expand_hasl_size * 2 + 1 | |
| img_new = np.ones((size, size, 3)) | |
| img_new[expand_hasl_size-crop_length[0]:expand_hasl_size+crop_length[0]+1, expand_hasl_size-crop_length[1]:expand_hasl_size+crop_length[1]+1] = \ | |
| img[center_uv[0]-crop_length[0]:center_uv[0]+crop_length[0]+1, center_uv[1]-crop_length[1]:center_uv[1]+crop_length[1]+1] | |
| crop_uv = np.array([expand_hasl_size, crop_length[0], crop_length[1], center_uv[0], center_uv[1], size], dtype=int) | |
| return img_new, crop_uv | |
| def transform_normal(self, normal_input, cam): | |
| # load camera | |
| img_mask = torch.linalg.norm(normal_input, dim=-1) > 1.5 | |
| extrinsic, K = cam | |
| extrinsic = np.concatenate([extrinsic, np.zeros(4).reshape(1, 4)], axis=0) | |
| extrinsic[3, 3] = 1 | |
| pose = np.linalg.inv(extrinsic) | |
| temp = pose[1] + 0.0 | |
| pose[1] = -pose[2] | |
| pose[2] = temp | |
| extrinsic = torch.from_numpy(np.linalg.inv(pose)).float() | |
| # to normal | |
| normal_img = extrinsic[None, :3, :3] @ normal_input[..., :3].reshape(-1, 3, 1) | |
| normal_img = normal_img.reshape(normal_input.shape[0], normal_input.shape[1], 3) | |
| normal_img[img_mask] = 1.0 | |
| return normal_img | |
| def parse_item(self, target_im, cond_img, normal_img, filename, target_ids, **args): | |
| data = {} | |
| # we need to transform normal to cmaera frame | |
| if self.target_name == "normal": | |
| target_im = self.transform_normal(target_im, self.get_camera(filename, **args)) | |
| # normal conditioning | |
| if self.condition_name == "normal": | |
| normal_img = self.transform_normal(normal_img, self.get_camera(filename, **args)) | |
| data["image_target"] = target_im | |
| data["image_cond"] = cond_img | |
| if self.condition_name == "normal": | |
| data["img_normal"] = normal_img | |
| if self.test or self.return_paths: | |
| data["path"] = str(filename) | |
| data["label"] = torch.zeros(1).reshape(1, 1, 1)+target_ids | |
| if self.postprocess is not None: | |
| data = self.postprocess(data) | |
| return data | |
| def normalized_read(self, imgpath): | |
| img = np.array(imageio.imread(imgpath)) | |
| if img.dtype == np.uint8: | |
| img = img / 255.0 | |
| else: | |
| img = img ** (1 / 2.2) | |
| return img | |
| def process_im(self, im): | |
| im = Image.fromarray(im) | |
| im = im.convert("RGB") | |
| return self.tform(im) | |
| class ObjaverseDecoerWDS(ObjaverseDataDecoder): | |
| def __init__(self, **kargs) -> None: | |
| super().__init__(**kargs) | |
| def dict2tuple(self, data): | |
| returns = (data["image_target"], data["image_cond"],data["label"],) | |
| if self.condition_name == "normal": | |
| returns +=(data["img_normal"], ) | |
| if self.test or self.return_paths: | |
| returns += (data["path"],) | |
| return returns | |
| def tuple2dict(self, data): | |
| returns = {} | |
| returns["image_target"] = data[0] | |
| returns["image_cond"] = data[1] | |
| returns["label"] = data[2] | |
| if self.condition_name == "normal": | |
| returns["img_normal"] = data[3] | |
| if self.test or self.return_paths: | |
| returns["path"] = data[-1] | |
| return returns | |
| def data_filter(self, albedo, spec, diffuse_shad, spec_shad): | |
| returns = {} | |
| returns["image_target"] = data[0] | |
| returns["image_cond"] = data[1] | |
| if self.condition_name == "normal": | |
| returns["img_normal"] = data[2] | |
| if self.test or self.return_paths: | |
| returns["path"] = data[-1] | |
| return returns | |
| def get_camera(self, input_filename, sample): | |
| camera_file = input_filename.replace(f'{self.target_name}0001', \ | |
| 'camera').rsplit(".")[0] + ".pkl" | |
| mask_filename_byte = io.BytesIO(sample[camera_file]) | |
| cam = pickle.load(mask_filename_byte) | |
| return cam | |
| def process_sample(self, sample): | |
| # start_worker=time.time() | |
| results = [] | |
| for target_ids, target_name in enumerate(self.target_name_pool): | |
| _result = self.process_sample_single(sample, target_ids, target_name) | |
| results.append(self.dict2tuple(_result)) | |
| results = wds.filters.default_collation_fn(results) | |
| return results | |
| def batch_reordering(self, sample): | |
| batch_splits = [] | |
| for data_idx, _ in enumerate(sample): | |
| batch_splits.append( | |
| torch.cat( | |
| torch.chunk(sample[data_idx], dim=1, | |
| chunks=len(self.target_name_pool)), | |
| dim=0)[:,0] | |
| ) | |
| return self.tuple2dict(batch_splits) | |
| def process_sample_single(self, sample, target_ids, target_name): | |
| # get target image filename | |
| self.target_name = target_name | |
| target_file_name = self.target_name | |
| if self.target_name=="gloss_shaded": | |
| target_file_name = "gloss_direct" | |
| elif self.target_name=="diffuse_shaded": | |
| target_file_name = "diffuse_direct" | |
| for k in list(sample.keys()): | |
| if target_file_name not in k: | |
| continue | |
| target_key = k | |
| break | |
| # ############## | |
| # prev_time = start_worker | |
| # current_time = time.time() | |
| # print(f"find target takes: {current_time - prev_time}") | |
| # ############## | |
| filename, condition_filename, \ | |
| mask_filename, normal_condition_filename, filename_targets = self.path_parsing(target_key, "") | |
| # get file streams | |
| if filename_targets is None: | |
| filename_byte = io.BytesIO(sample[filename]) | |
| else: | |
| filename_byte = [io.BytesIO(sample[filename_target]) for filename_target in filename_targets] | |
| condition_filename_byte = io.BytesIO(sample[condition_filename]) | |
| normal_condition_filename_byte = io.BytesIO(sample[normal_condition_filename]) \ | |
| if self.condition_name == "normal" else None | |
| mask_filename_byte = io.BytesIO(sample[mask_filename]) | |
| # image reading | |
| target_im, cond_im, normal_img = self.read_images(filename_byte, | |
| condition_filename_byte, normal_condition_filename_byte) | |
| # mask reading | |
| img_mask = np.array(self.normalized_read(mask_filename_byte)) | |
| # post processing | |
| target_im, cond_im, normal_img, _ = self.image_post_processing(img_mask, target_im, cond_im, normal_img) | |
| # transform | |
| target_im = self.process_im(target_im) | |
| cond_im = self.process_im(cond_im) | |
| normal_img = self.process_im(normal_img) \ | |
| if self.condition_name == "normal" \ | |
| else None | |
| data = self.parse_item(target_im, cond_im, normal_img, filename, target_ids, sample=sample) | |
| # override for file path | |
| if self.test or self.return_paths: | |
| data["path"] = sample["__key__"] | |
| result = dict(__key__=sample["__key__"]) | |
| result.update(data) | |
| return result | |
| if __name__=="__main__": | |
| from torchvision import transforms | |
| from einops import rearrange | |
| torch.distributed.init_process_group(backend="nccl") | |
| image_transforms = [transforms.ToTensor(), | |
| transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] | |
| image_transforms = torchvision.transforms.Compose(image_transforms) | |
| image_transforms = { | |
| "size": 256, | |
| "totensor": image_transforms | |
| } | |
| data_list_dir = "/home/chenxi/code/material-diffusion/data/big_data_lists" | |
| tar_name_list = sorted(os.listdir(data_list_dir))[1:4] | |
| tar_list = [_name.rsplit("_num")[0]+".tar" for _name in tar_name_list] | |
| tar_dir = "/home/chenxi/code/material-diffusion/data/big_data_transed" | |
| tars = [os.path.join(tar_dir, _name) for _name in tar_list] | |
| dataset_size = 0 | |
| imgperobj = 10 | |
| print("list dirs...") | |
| for _name in tar_name_list: | |
| num_obj = int(_name.rsplit("_num_")[1].rsplit(".")[0]) | |
| print(num_obj, " : ", _name) | |
| dataset_size += num_obj * imgperobj | |
| decoder = ObjaverseDecoerWDS(image_transforms=image_transforms, | |
| return_paths=True) | |
| batch_size = 8 | |
| print('============= length of training dataset %d =============' % (dataset_size // batch_size // 2)) | |
| dataset = (wds.WebDataset(tars, | |
| repeat=0, | |
| nodesplitter=wds.shardlists.split_by_node) | |
| .shuffle(100) | |
| .map(decoder.process_sample) | |
| .map(decoder.dict2tuple) | |
| .batched(batch_size, partial=False) | |
| .map(decoder.tuple2dict) | |
| .with_epoch(dataset_size // batch_size // 2) | |
| .with_length(dataset_size // batch_size) | |
| ) | |
| from torch.utils.data import DataLoader | |
| # loader = DataLoader(dataset, batch_size=None, num_workers=8, shuffle=False) | |
| loader = (wds.WebLoader(dataset, batch_size=None, num_workers=2, shuffle=False) | |
| .map(decoder.dict2tuple) | |
| .unbatched() | |
| # .shuffle(100) | |
| .batched(batch_size) | |
| .map(decoder.tuple2dict) | |
| ) | |
| print("# loader length", len(dataset)) | |
| for epoch in range(2): | |
| ind = -1 | |
| for sample in loader: | |
| assert "image_target" in sample | |
| assert "image_cond" in sample | |
| assert "path" in sample | |
| ind += 1 | |
| if ind != 0: | |
| continue | |
| # replace to this for file path | |
| # worker_info = torch.utils.data.get_worker_info() | |
| # if worker_info is not None: | |
| # worker = worker_info.id | |
| # num_workers = worker_info.num_workers | |
| # data["path"] = sample["__url__"]+"--"+sample["__key__"] +f".{worker}/{num_workers}" | |
| # print(f"{ind}: shape {sample['image_target'].shape} {sample['path'][0].rsplit('/', 1)[-2]}") | |
| print("##############") | |
| for i in range(len(sample['path'])): | |
| print(f"epoch {epoch}, it {ind}: shape {sample['image_target'].shape} {sample['path'][i].rsplit('--', 1)[0].rsplit('/', 2)[-1]} {sample['path'][i].rsplit('--', 1)[1].rsplit('/', 3)[-3]} {sample['path'][i].rsplit('--', 1)[1].rsplit('/',4)[-4]} {sample['path'][i].rsplit('.', 1)[-1]} rank: {dist.get_rank()}") | |
| print("##############") | |
| print(sample["path"]) | |
| print(sample["path"]) | |
| print(f"NUmber of samples: {ind} {dataset_size} {len(dataset)} rank: {dist.get_rank()}") | |
| # 1. Remember samples are batched inside each worker, the outside data loader only sees one sample | |
| # 2. All batch, epoch, and length settings are only visible within each worker | |
| # 3. Unbatch and Suffle and then re-batch in loader result in between worker shuffle. | |
| # This also allows to control of loader batching and worker batching for CPU optimization of worker-loader data transfer. | |
| # https://github.com/webdataset/webdataset/issues/141#issuecomment-1043190147 | |
| # 4. It seems that data just repeat forever to satisfy with_epoch | |
| # 5. Torch datalogger requires the dataset to have a len() method, which is used to schdule sample idx | |
| # 6. DDP sampler will return its only length | |
| # 7. WebLoader does not need length, it only raises the end of the iteration when data is running out | |
| # 8. How does torch loader deal with datasets with fewer sizes than claims? | |
| # 9. Set epoch will make sampling start from the beginning when a new epoch starts. Observed by disable shuffle and one batch repeat | |
| # And each epoch will have a different sampling seed | |
| # 10. DataLoader with IterableDataset: expected unspecified sampler option. DDP sampler will not be usable. | |
| # !0. In summary: | |
| # For ddp multi-worker training, the worker splitter and node splitter will make sure tars are splitted into each worker | |
| # We have to manually adjust with_epoch with respect to num_worker and num_node and batch_size | |
| def nodesplitter(src, group=None): | |
| if torch.distributed.is_initialized(): | |
| if group is None: | |
| group = torch.distributed.group.WORLD | |
| rank = torch.distributed.get_rank(group=group) | |
| size = torch.distributed.get_world_size(group=group) | |
| print(f"nodesplitter: rank={rank} size={size}") | |
| count = 0 | |
| for i, item in enumerate(src): | |
| if i % size == rank: | |
| yield item | |
| count += 1 | |
| print(f"nodesplitter: rank={rank} size={size} count={count} DONE") | |
| else: | |
| yield from src |