Spaces:
Build error
Build error
| from pathlib import Path | |
| import json | |
| import numpy as np | |
| import PIL.Image as Image | |
| import torch | |
| import torchvision.transforms.functional as F | |
| from torch.utils.data import Dataset | |
| from vhap.util.log import get_logger | |
| logger = get_logger(__name__) | |
| class NeRFDataset(Dataset): | |
| def __init__( | |
| self, | |
| root_folder, | |
| division=None, | |
| camera_convention_conversion=None, | |
| target_extrinsic_type='w2c', | |
| use_fg_mask=False, | |
| use_flame_param=False, | |
| ): | |
| """ | |
| Args: | |
| root_folder: Path to dataset with the following directory layout | |
| <root_folder>/ | |
| | | |
| |---<images>/ | |
| | |---00000.jpg | |
| | |... | |
| | | |
| |---<fg_masks>/ | |
| | |---00000.png | |
| | |... | |
| | | |
| |---<flame_param>/ | |
| | |---00000.npz | |
| | |... | |
| | | |
| |---transforms_backup.json # backup of the original transforms.json | |
| |---transforms_backup_flame.json # backup of the original transforms.json with flame_param | |
| |---transforms.json # the final transforms.json | |
| |---transforms_train.json # the final transforms.json for training | |
| |---transforms_val.json # the final transforms.json for validation | |
| |---transforms_test.json # the final transforms.json for testing | |
| """ | |
| super().__init__() | |
| self.root_folder = Path(root_folder) | |
| self.division = division | |
| self.camera_convention_conversion = camera_convention_conversion | |
| self.target_extrinsic_type = target_extrinsic_type | |
| self.use_fg_mask = use_fg_mask | |
| self.use_flame_param = use_flame_param | |
| logger.info(f"Loading NeRF scene from: {root_folder}") | |
| # data division | |
| if division is None: | |
| tranform_path = self.root_folder / "transforms.json" | |
| elif division == "train": | |
| tranform_path = self.root_folder / "transforms_train.json" | |
| elif division == "val": | |
| tranform_path = self.root_folder / "transforms_val.json" | |
| elif division == "test": | |
| tranform_path = self.root_folder / "transforms_test.json" | |
| else: | |
| raise NotImplementedError(f"Unknown division type: {division}") | |
| logger.info(f"division: {division}") | |
| self.transforms = json.load(open(tranform_path, "r")) | |
| logger.info(f"number of timesteps: {len(self.transforms['timestep_indices'])}, number of cameras: {len(self.transforms['camera_indices'])}") | |
| assert len(self.transforms['timestep_indices']) == max(self.transforms['timestep_indices']) + 1 | |
| def __len__(self): | |
| return len(self.transforms['frames']) | |
| def __getitem__(self, i): | |
| frame = self.transforms['frames'][i] | |
| # 'timestep_index', 'timestep_index_original', 'timestep_id', 'camera_index', 'camera_id', 'cx', 'cy', 'fl_x', 'fl_y', 'h', 'w', 'camera_angle_x', 'camera_angle_y', 'transform_matrix', 'file_path', 'fg_mask_path', 'flame_param_path'] | |
| K = torch.eye(3) | |
| K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor( | |
| [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]] | |
| ) | |
| c2w = torch.tensor(frame['transform_matrix']) | |
| if self.target_extrinsic_type == "w2c": | |
| extrinsic = c2w.inverse() | |
| elif self.target_extrinsic_type == "c2w": | |
| extrinsic = c2w | |
| else: | |
| raise NotImplementedError(f"Unknown extrinsic type: {self.target_extrinsic_type}") | |
| img_path = self.root_folder / frame['file_path'] | |
| item = { | |
| 'timestep_index': frame['timestep_index'], | |
| 'camera_index': frame['camera_index'], | |
| 'intrinsics': K, | |
| 'extrinsics': extrinsic, | |
| 'image_height': frame['h'], | |
| 'image_width': frame['w'], | |
| 'image': np.array(Image.open(img_path)), | |
| 'image_path': img_path, | |
| } | |
| if self.use_fg_mask and 'fg_mask_path' in frame: | |
| fg_mask_path = self.root_folder / frame['fg_mask_path'] | |
| item["fg_mask"] = np.array(Image.open(fg_mask_path)) | |
| item["fg_mask_path"] = fg_mask_path | |
| if self.use_flame_param and 'flame_param_path' in frame: | |
| npz = np.load(self.root_folder / frame['flame_param_path'], allow_pickle=True) | |
| item["flame_param"] = dict(npz) | |
| return item | |
| def apply_to_tensor(self, item): | |
| if self.img_to_tensor: | |
| if "rgb" in item: | |
| item["rgb"] = F.to_tensor(item["rgb"]) | |
| # if self.rgb_range_shift: | |
| # item["rgb"] = (item["rgb"] - 0.5) / 0.5 | |
| if "alpha_map" in item: | |
| item["alpha_map"] = F.to_tensor(item["alpha_map"]) | |
| return item | |
| if __name__ == "__main__": | |
| from tqdm import tqdm | |
| from dataclasses import dataclass | |
| import tyro | |
| from torch.utils.data import DataLoader | |
| class Args: | |
| root_folder: str | |
| subject: str | |
| sequence: str | |
| use_landmark: bool = False | |
| batchify_all_views: bool = False | |
| args = tyro.cli(Args) | |
| dataset = NeRFDataset(root_folder=args.root_folder) | |
| print(len(dataset)) | |
| sample = dataset[0] | |
| print(sample.keys()) | |
| dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) | |
| for item in tqdm(dataloader): | |
| pass | |