import sys
sys.path.insert(1, '.')

import copy
import csv
import glob
import json
import math
import os
import pickle
import random
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Dict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import webdataset as wds
from datasets import load_dataset
from einops import rearrange
from loguru import logger
from omegaconf import DictConfig, ListConfig
from PIL import Image
from torch import distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from tqdm import tqdm

from ldm.data.decoder import ObjaverseDataDecoder, ObjaverseDecoerWDS, nodesplitter
from ldm.data.objaverse_rendered import get_rendered_objaverse_list_v0
from ldm.util import instantiate_from_config
# Some hacky things to make experimentation easier
def make_transform_multi_folder_data(paths, caption_files=None, **kwargs):
    ds = make_multi_folder_data(paths, caption_files, **kwargs)
    return TransformDataset(ds)


def make_nfp_data(base_path):
    dirs = list(Path(base_path).glob("*/"))
    print(f"Found {len(dirs)} folders")
    print(dirs)
    tforms = [transforms.Resize(512), transforms.CenterCrop(512)]
    datasets = [NfpDataset(x, image_transforms=copy.copy(tforms),
                           default_caption="A view from a train window")
                for x in dirs]
    return torch.utils.data.ConcatDataset(datasets)
class VideoDataset(Dataset):
    def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2):
        self.root_dir = Path(root_dir)
        self.caption_file = caption_file
        self.n = n
        ext = "mp4"
        self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
        self.offset = offset

        if isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

        with open(self.caption_file) as f:
            reader = csv.reader(f)
            rows = [row for row in reader]
            self.captions = dict(rows)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Retry a few times; unreadable videos are skipped by resampling an index.
        for _ in range(10):
            try:
                return self._load_sample(index)
            except Exception as e:
                print(f"Failed to load sample {self.paths[index]}: {e}")
                index = random.randint(0, len(self.paths) - 1)
        raise RuntimeError("Failed to load a valid sample after 10 attempts")

    def _load_sample(self, index):
        n = self.n
        filename = self.paths[index]
        min_frame = 2 * self.offset + 2
        vid = cv2.VideoCapture(str(filename))
        max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        # Frame indices are 0-based, so the last valid index is max_frames - 1.
        curr_frame_n = random.randint(min_frame, max_frames - 1)
        vid.set(cv2.CAP_PROP_POS_FRAMES, curr_frame_n)
        _, curr_frame = vid.read()

        prev_frames = []
        for i in range(n):
            prev_frame_n = curr_frame_n - (i + 1) * self.offset
            vid.set(cv2.CAP_PROP_POS_FRAMES, prev_frame_n)
            _, prev_frame = vid.read()
            # OpenCV reads BGR; reverse the channel axis to get RGB.
            prev_frame = self.tform(Image.fromarray(prev_frame[..., ::-1]))
            prev_frames.append(prev_frame)

        vid.release()
        caption = self.captions[filename.name]
        data = {
            "image": self.tform(Image.fromarray(curr_frame[..., ::-1])),
            "prev": torch.cat(prev_frames, dim=-1),
            "txt": caption,
        }
        return data
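# A minimal usage sketch (the clip directory and caption CSV are hypothetical):
#   ds = VideoDataset("clips/", image_transforms=[], caption_file="captions.csv",
#                     offset=8, n=2)
#   sample = ds[0]
#   sample["image"] is an (H, W, 3) tensor in [-1, 1]; sample["prev"] stacks the
#   n previous frames channel-wise into (H, W, 3 * n).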
# end hacky things


def make_transforms(image_transforms):
    # if isinstance(image_transforms, ListConfig):
    #     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms = []
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    image_transforms = transforms.Compose(image_transforms)
    return image_transforms
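# The composed transform maps a PIL image to an (H, W, C) float tensor in
# [-1, 1]: ToTensor() gives (C, H, W) in [0, 1], x * 2. - 1. rescales, and
# rearrange moves channels last. A quick sanity check (illustrative only):
#   im = Image.new("RGB", (4, 4), (255, 255, 255))
#   make_transforms([])(im)  # -> shape (4, 4, 3), all values 1.0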
def make_multi_folder_data(paths, caption_files=None, **kwargs):
    """Make a concat dataset from multiple folders.

    Captions are not supported for the dict form yet.
    If paths is a list, that's ok; if it's a Dict, interpret it as
    k=folder, v=number of times to repeat that folder.
    """
    list_of_paths = []
    if isinstance(paths, (Dict, DictConfig)):
        assert caption_files is None, \
            "Caption files not yet supported for repeats"
        for folder_path, repeats in paths.items():
            list_of_paths.extend([folder_path] * repeats)
        paths = list_of_paths

    if caption_files is not None:
        datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
    else:
        datasets = [FolderData(p, **kwargs) for p in paths]
    return torch.utils.data.ConcatDataset(datasets)
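# Usage sketch for the dict form (folder names are hypothetical): each key is a
# folder and each value how many times it repeats in the ConcatDataset, e.g.
#   ds = make_multi_folder_data({"renders/setA": 2, "renders/setB": 1})
# yields setA's samples twice and setB's once per epoch.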
class NfpDataset(Dataset):
    def __init__(self,
                 root_dir,
                 image_transforms=[],
                 ext="jpg",
                 default_caption="",
                 ) -> None:
        """Assume sequential frames and a deterministic transform."""
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
        self.tform = make_transforms(image_transforms)

    def __len__(self):
        return len(self.paths) - 1

    def __getitem__(self, index):
        prev = self.paths[index]
        curr = self.paths[index + 1]
        data = {}
        data["image"] = self._load_im(curr)
        data["prev"] = self._load_im(prev)
        data["txt"] = self.default_caption
        return data

    def _load_im(self, filename):
        im = Image.open(filename).convert("RGB")
        return self.tform(im)
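# NfpDataset pairs consecutive frames: index i returns frame i as "prev" and
# frame i+1 as "image", hence __len__ is len(paths) - 1. With frames
# f0.jpg..f3.jpg, valid indices are 0..2 and ds[1] yields prev=f1, image=f2.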
class ObjaverseDataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, root_dir, batch_size, train=None, validation=None,
                 test=None, num_workers=4, objaverse_data_list=None, ext="png",
                 target_name="albedo", use_wds=True, tar_config=None, **kwargs):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs
        self.tar_config = tar_config
        self.use_wds = use_wds

        # The validation config takes precedence over train if both are given.
        if train is not None:
            dataset_config = train
        if validation is not None:
            dataset_config = validation

        image_transforms = [transforms.ToTensor(),
                            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
        image_transforms = torchvision.transforms.Compose(image_transforms)
        self.image_transforms = {
            "size": dataset_config.image_transforms.size,
            "totensor": image_transforms
        }
        self.target_name = target_name
        self.objaverse_data_list = objaverse_data_list
        self.ext = ext
    def naive_setup(self):
        # get the object data list
        if self.objaverse_data_list is None or \
                self.objaverse_data_list["image_list_cache_path"] == "None":
            # This is too slow..
            self.paths = sorted(list(Path(self.root_dir).rglob(f"*{self.target_name}*.{self.ext}")))
            if len(self.paths) == 0:
                # colmap format
                self.paths = sorted(list(Path(self.root_dir).rglob("*images_train/*.*")))
        else:
            self.paths = get_rendered_objaverse_list_v0(self.root_dir, self.target_name, self.ext, **self.objaverse_data_list)
        random.shuffle(self.paths)

        # train/val split: first 99% for training, last 1% for validation
        total_objects = len(self.paths)
        split = math.floor(total_objects / 100. * 99.)
        self.paths_val = self.paths[split:]
        self.paths_train = self.paths[:split]
        if self.rank == 0:
            print('============= length of dataset %d =============' % len(self.paths))
            print('============= length of training dataset %d =============' % len(self.paths_train))
            print('============= length of validation dataset %d =============' % len(self.paths_val))

        # split across GPUs
        self.paths_train = self._get_local_split(self.paths_train, self.world_size, self.rank)
        logger.info(
            f"[rank {self.rank}]: {len(self.paths_train)} images assigned."
        )
    def _get_tar_length(self, tar_list, img_per_obj):
        dataset_size = 0
        for _name in tar_list:
            num_obj = int(_name.rsplit("_num_")[1].rsplit(".")[0])
            dataset_size += num_obj * img_per_obj
        return dataset_size
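    # Shard list files are assumed to be named "<shard>_num_<objects>.<ext>",
    # e.g. "shard_0003_num_128.txt" -> 128 objects; with img_per_obj=12 that
    # shard contributes 128 * 12 = 1536 samples to the total dataset size.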
    def webdataset_setup(self, list_dir, tar_dir, img_per_obj, max_tars=None):
        # read the data list and calculate the dataset size
        tar_name_list = sorted(os.listdir(list_dir))
        if max_tars is not None:
            # for debugging on small-scale data
            tar_name_list = tar_name_list[:max_tars]
        total_tars = len(tar_name_list)
        # random shuffle
        random.shuffle(tar_name_list)
        print(f"Rank {self.rank} shuffle: {tar_name_list}")

        # train/test split
        self.test_tars = tar_name_list[math.floor(total_tars / 100. * 99.):]
        # make sure each node has at least one tar
        if len(self.test_tars) < self.world_size:
            self.test_tars += [self.test_tars[0]] * (self.world_size - len(self.test_tars))
        self.train_tars = tar_name_list[:math.floor(total_tars / 100. * 99.)]

        # truncate the training tars so every worker gets the same shard count
        total_workers = self.num_workers * self.world_size
        num_tars_train = (len(self.train_tars) // total_workers) * total_workers
        if num_tars_train != len(self.train_tars):
            print(f"[WARNING] Total train tars: {len(self.train_tars)}, truncated: {len(self.train_tars) - num_tars_train}, remaining: {num_tars_train}, total workers: {total_workers}")
            self.train_tars = self.train_tars[:num_tars_train]

        self.test_length = self._get_tar_length(self.test_tars, img_per_obj)
        self.train_length = self._get_tar_length(self.train_tars, img_per_obj)

        # map the list-file names back to the tar names
        test_tars = [_name.rsplit("_num")[0] + ".tar" for _name in self.test_tars]
        self.test_tars = [os.path.join(tar_dir, _name) for _name in test_tars]
        train_tars = [_name.rsplit("_num")[0] + ".tar" for _name in self.train_tars]
        self.train_tars = [os.path.join(tar_dir, _name) for _name in train_tars]

        if self.rank == 0:
            print('============= length of dataset %d =============' % (self.test_length + self.train_length))
            print('============= length of training dataset %d =============' % self.train_length)
            print('============= length of validation dataset %d =============' % self.test_length)
    def setup(self, stage=None):
        try:
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
        except Exception:
            # torch.distributed is not initialized; fall back to a single process
            self.world_size = 1
            self.rank = 0
        if self.rank == 0:
            print("#### Data ####")
        if self.use_wds:
            self.webdataset_setup(**self.tar_config)
        else:
            self.naive_setup()
    def _get_local_split(self, items: list, world_size: int, rank: int, seed: int = 6):
        """The local rank only loads a split of the dataset."""
        n_items = len(items)
        items_permute = np.random.RandomState(seed).permutation(items)
        if n_items % world_size == 0:
            padded_items = items_permute
        else:
            padding = np.random.RandomState(seed).choice(
                items, world_size - (n_items % world_size), replace=True
            )
            padded_items = np.concatenate([items_permute, padding])
            assert (
                len(padded_items) % world_size == 0
            ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}"
        n_per_rank = len(padded_items) // world_size
        local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)]
        return local_items
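    # Worked example (illustrative): 10 items with world_size=4 get padded with
    # 2 resampled items to 12, so each rank takes a contiguous slice of 3. The
    # fixed seed gives every rank the same permutation, so the slices stay disjoint.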
    def train_dataloader(self):
        if self.use_wds:
            loader = self.train_dataloader_wds()
        else:
            loader = self.train_dataloader_naive()
        return loader

    def val_dataloader(self):
        if self.use_wds:
            loader = self.val_dataloader_wds()
        else:
            loader = self.val_dataloader_naive()
        return loader
    def train_dataloader_naive(self):
        dataset = ObjaverseData(root_dir=self.root_dir,
                                image_transforms=self.image_transforms,
                                image_list=self.paths_train, target_name=self.target_name,
                                **self.kwargs)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader_naive(self):
        dataset = ObjaverseData(root_dir=self.root_dir,
                                image_transforms=self.image_transforms,
                                image_list=self.paths_val, target_name=self.target_name,
                                **self.kwargs)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
    def train_dataloader_wds(self):
        decoder = ObjaverseDecoerWDS(root_dir=self.root_dir,
                                     image_transforms=self.image_transforms,
                                     image_list=None, target_name=self.target_name,
                                     **self.kwargs)
        worker_batch = self.batch_size
        epoch_length = self.train_length // worker_batch // self.num_workers // self.world_size
        dataset = (wds.WebDataset(self.train_tars,
                                  shardshuffle=min(1000, len(self.train_tars)),
                                  nodesplitter=wds.shardlists.split_by_node)
                   .shuffle(5000, initial=1000)
                   .map(decoder.process_sample)
                   # .map(decoder.dict2tuple)
                   .batched(worker_batch, partial=False)
                   # .map(decoder.tuple2dict)
                   .map(decoder.batch_reordering)
                   .with_epoch(epoch_length)
                   .with_length(epoch_length)
                   )
        loader = (wds.WebLoader(dataset, batch_size=None, num_workers=self.num_workers, shuffle=False)
                  # .unbatched()
                  # .shuffle(1000)
                  # .batched(self.batch_size)
                  # .map(decoder.tuple2dict)
                  )
        print(f"# Training loader length for single worker {epoch_length} with {self.num_workers} workers")
        return loader
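    # Epoch-length arithmetic (illustrative numbers): with train_length=96_000
    # samples, batch_size=4, num_workers=4 and world_size=2, each worker yields
    # 96_000 // 4 // 4 // 2 = 3000 batches; with_epoch() cuts the otherwise
    # endless shard stream there so Lightning sees finite epochs.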
    def val_dataloader_wds(self):
        decoder = ObjaverseDecoerWDS(root_dir=self.root_dir,
                                     image_transforms=self.image_transforms,
                                     image_list=None, target_name=self.target_name,
                                     **self.kwargs)
        # adjust the worker count, as the test split has far fewer tars
        val_workers = min(self.num_workers, len(self.test_tars) // self.world_size)
        epoch_length = max(self.test_length // self.batch_size // val_workers // self.world_size, 1)
        dataset = (wds.WebDataset(self.test_tars,
                                  shardshuffle=min(1000, len(self.test_tars)),
                                  handler=wds.ignore_and_continue,
                                  nodesplitter=wds.shardlists.split_by_node)
                   .shuffle(1000)
                   .map(decoder.process_sample)
                   # .map(decoder.dict2tuple)
                   .batched(self.batch_size, partial=False)
                   .with_epoch(epoch_length)
                   .with_length(epoch_length)
                   )
        loader = (wds.WebLoader(dataset, batch_size=None, num_workers=val_workers, shuffle=False)
                  .unbatched()
                  .shuffle(1000)
                  .batched(self.batch_size)
                  # .map(decoder.tuple2dict)
                  .map(decoder.batch_reordering)
                  )
        print(f"# Validation loader length for single worker {epoch_length} with {val_workers} workers")
        return loader
    def test_dataloader(self):
        # testing will use all given data
        return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, test=True,
                                           image_transforms=self.image_transforms,
                                           image_list=self.paths, target_name=self.target_name,
                                           **self.kwargs),
                             batch_size=32, num_workers=self.num_workers, shuffle=False,
                             )
class ObjaverseData(ObjaverseDataDecoder, Dataset):
    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_list=None,
                 threads=64,
                 **kargs
                 ) -> None:
        """Create a dataset from Blender rendering results.

        If you pass in a root directory, it will be searched for images
        ending in ext (ext can be a list).
        """
        self.paths = image_list
        self.root_dir = Path(root_dir)
        ObjaverseDataDecoder.__init__(self, **kargs)

        # pre-load all data into memory
        print(f"Pre-loading data with {threads} threads...")
        self.all_target_im = np.zeros((len(self.paths), self.img_size, self.img_size, 3), dtype=np.uint8)
        self.all_cond_im = np.zeros((len(self.paths), self.img_size, self.img_size, 3), dtype=np.uint8)
        self.all_filename = ["empty"] * len(self.paths)
        if self.condition_name == "normal":
            self.all_normal_img = np.zeros((len(self.paths), self.img_size, self.img_size, 3), dtype=np.uint8)
        self.all_crop_idx = np.zeros((len(self.paths), 6), dtype=int)
        print("Arrays allocated..")

        def parallel_load(index):
            self.preload_item(index)
            pbar.update(1)

        pbar = tqdm(total=len(self.paths))
        with ThreadPool(threads) as pool:
            pool.map(parallel_load, range(len(self.paths)))
            pool.close()
            pool.join()
        print("Data pre-loading done...")

    def __len__(self):
        return len(self.paths)
    def load_mask(self, mask_filename, cond_im):
        # auto-detect the mask image file extension
        glob_files = glob.glob(mask_filename.rsplit(".", 1)[0] + ".*")
        if len(glob_files) == 0:
            print("Warning: no mask image found")
            img_mask = np.ones_like(cond_im)
            if cond_im.shape[-1] == 4:
                print("Use image alpha channel as mask")
                img_mask = img_mask * cond_im[:, :, -1:]
        elif len(glob_files) == 1:
            img_mask = np.array(self.normalized_read(glob_files[0]))
        else:
            raise NotImplementedError(f"Too many mask images found: {glob_files}")
        return img_mask
    def preload_item(self, index):
        path = self.paths[index]
        filename = os.path.join(path)
        filename, condition_filename, \
            mask_filename, normal_condition_filename, filename_targets = self.path_parsing(filename)

        # get file streams
        if filename_targets is None:
            filename_read = filename
        else:
            filename_read = filename_targets

        # image reading
        target_im, cond_im, normal_img = self.read_images(filename_read,
                                                          condition_filename, normal_condition_filename)
        # mask reading
        img_mask = self.load_mask(mask_filename, cond_im)
        # post-processing
        target_im, cond_im, normal_img, crop_idx = self.image_post_processing(img_mask, target_im, cond_im, normal_img)
        if self.test:
            # keep the crop indices so the valid region can be recovered later
            self.all_crop_idx[index] = crop_idx

        # store results
        self.all_target_im[index] = target_im
        self.all_cond_im[index] = cond_im
        self.all_filename[index] = filename
        if self.condition_name == "normal":
            self.all_normal_img[index] = normal_img
    def get_camera(self, input_filename):
        camera_file = input_filename.replace(f'{self.target_name}0001',
                                             'camera').rsplit(".")[0] + ".pkl"
        cam_dir, cam_name = camera_file.rsplit("/", 1)
        # pad the name to width 15 to match the saved camera file names
        cam_name = f"{cam_name:>15}"
        camera_file = os.path.join(cam_dir, cam_name)
        cam = pickle.load(open(camera_file, 'rb'))
        return cam
    def __getitem__(self, index):
        target_im = self.process_im(self.all_target_im[index])
        cond_img = self.process_im(self.all_cond_im[index])
        filename = self.all_filename[index]
        normal_img = self.process_im(self.all_normal_img[index]) \
            if self.condition_name == "normal" \
            else None
        sample = self.parse_item(target_im, cond_img, normal_img, filename)
        if self.test:
            sample["crop_idx"] = self.all_crop_idx[index]
        return sample
if __name__ == "__main__":
    import pyhocon

    class DictAsMember(dict):
        def __getattr__(self, name):
            value = self[name]
            if isinstance(value, dict):
                value = DictAsMember(value)
            return value

    def ConfigAsMember(config):
        config_dict = DictAsMember(config)
        for key in config_dict.keys():
            if isinstance(config_dict[key], pyhocon.config_tree.ConfigTree):
                config_dict[key] = ConfigAsMember(config_dict[key])
        return config_dict

    train_config = DictAsMember({
        "validation": False,
        "image_transforms": {"size": 256}
    })
    val_config = DictAsMember({
        "validation": True,
        "image_transforms": {"size": 256}
    })
    objaverse_data_list = DictAsMember({
        "image_list_cache_path": "image_lists/half_400000_image_list.npz",
    })
    data_module = ObjaverseDataModuleFromConfig(root_dir='/mnt/volumes/perception/hujunkang/codes/renders/material-diffusion/data/objaverse_rendering',
                                                batch_size=4, train=train_config, validation=val_config,
                                                test=None, num_workers=1, objaverse_data_list=objaverse_data_list, ext="png",
                                                target_name="albedo", use_wds=False, tar_config=None)
    data_module.setup()
    train_dataloader_naive = data_module.train_dataloader_naive()
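    # A hedged sketch of the WebDataset path (the list/tar directories and the
    # img_per_obj value are hypothetical; the keys follow webdataset_setup()):
    #   wds_module = ObjaverseDataModuleFromConfig(
    #       root_dir='data/objaverse_rendering', batch_size=4,
    #       train=train_config, validation=val_config, num_workers=4,
    #       target_name="albedo", use_wds=True,
    #       tar_config={"list_dir": "tar_lists/", "tar_dir": "tars/",
    #                   "img_per_obj": 12, "max_tars": 8})
    #   wds_module.setup()
    #   train_loader = wds_module.train_dataloader()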