| from pathlib import Path |
| import json |
| import numpy as np |
| import PIL.Image as Image |
| import torch |
| import torchvision.transforms.functional as F |
| from torch.utils.data import Dataset |
| from vhap.util.log import get_logger |
|
|
|
|
# Logger for this module, named after the module itself.
logger = get_logger(__name__)
|
|
|
|
class NeRFDataset(Dataset):
    """Dataset over a scene described by ``transforms*.json`` files.

    Each item corresponds to one entry of ``transforms['frames']``. It bundles the
    image, camera intrinsics and extrinsics, and optionally a foreground mask and
    FLAME parameters. All of these are loaded from paths recorded in that frame,
    relative to ``root_folder``.
    """

    # Maps the ``division`` constructor argument to the transforms file that is read.
    _TRANSFORMS_FILES = {
        None: "transforms.json",
        "train": "transforms_train.json",
        "val": "transforms_val.json",
        "test": "transforms_test.json",
    }

    # Class-level default, so ``apply_to_tensor`` works even when the flag was never
    # set on the instance. Previously the attribute did not exist at all and every
    # call raised AttributeError.
    img_to_tensor = False

    def __init__(
        self,
        root_folder,
        division=None,
        camera_convention_conversion=None,
        target_extrinsic_type='w2c',
        use_fg_mask=False,
        use_flame_param=False,
        img_to_tensor=False,
    ):
        """
        Args:
            root_folder: Path to the dataset, with the following directory layout::

                <root_folder>/
                |
                |---<images>/
                |   |---00000.jpg
                |   |...
                |
                |---<fg_masks>/
                |   |---00000.png
                |   |...
                |
                |---<flame_param>/
                |   |---00000.npz
                |   |...
                |
                |---transforms_backup.json        # backup of the original transforms.json
                |---transforms_backup_flame.json  # backup of the original transforms.json with flame_param
                |---transforms.json               # the final transforms.json
                |---transforms_train.json         # the final transforms.json for training
                |---transforms_val.json           # the final transforms.json for validation
                |---transforms_test.json          # the final transforms.json for testing

            division: ``None`` reads ``transforms.json``. ``"train"``, ``"val"`` or
                ``"test"`` reads the matching ``transforms_<division>.json``.
            camera_convention_conversion: Stored on the instance as-is. It is not
                used by this class.
            target_extrinsic_type: ``"w2c"`` returns the inverse of each frame's
                ``transform_matrix``. ``"c2w"`` returns the matrix unchanged.
            use_fg_mask: Load ``frame['fg_mask_path']`` when the frame has one.
            use_flame_param: Load ``frame['flame_param_path']`` (an ``.npz``) when
                the frame has one.
            img_to_tensor: Enables the conversions performed by ``apply_to_tensor``.

        The transforms JSON must contain ``frames``, ``timestep_indices`` and
        ``camera_indices``.

        Raises:
            NotImplementedError: If ``division`` is not one of the values above.
            ValueError: If ``len(timestep_indices) != max(timestep_indices) + 1``.
        """
        super().__init__()
        self.root_folder = Path(root_folder)
        self.division = division
        self.camera_convention_conversion = camera_convention_conversion
        self.target_extrinsic_type = target_extrinsic_type
        self.use_fg_mask = use_fg_mask
        self.use_flame_param = use_flame_param
        self.img_to_tensor = img_to_tensor

        logger.info(f"Loading NeRF scene from: {root_folder}")

        transform_path = self.root_folder / self._transforms_filename(division)
        logger.info(f"division: {division}")

        # Use a context manager so the file handle is closed promptly. JSON is UTF-8.
        with open(transform_path, "r", encoding="utf-8") as f:
            self.transforms = json.load(f)
        logger.info(f"number of timesteps: {len(self.transforms['timestep_indices'])}, number of cameras: {len(self.transforms['camera_indices'])}")

        self._validate_timestep_indices(self.transforms['timestep_indices'])

    @classmethod
    def _transforms_filename(cls, division):
        """Return the transforms file name for ``division``.

        Raises:
            NotImplementedError: If ``division`` is unknown.
        """
        try:
            return cls._TRANSFORMS_FILES[division]
        except (KeyError, TypeError):  # TypeError: unhashable division value
            raise NotImplementedError(f"Unknown division type: {division}") from None

    @staticmethod
    def _validate_timestep_indices(timestep_indices):
        """Check that ``len(timestep_indices) == max(timestep_indices) + 1``.

        An empty list passes. An explicit exception is used instead of ``assert``
        so the check is not stripped when Python runs with ``-O``.

        Raises:
            ValueError: If the check fails.
        """
        if timestep_indices and len(timestep_indices) != max(timestep_indices) + 1:
            raise ValueError(
                "expected len(timestep_indices) == max(timestep_indices) + 1, got "
                f"{len(timestep_indices)} indices with max {max(timestep_indices)}"
            )

    def __len__(self):
        """Return the number of frames listed in the transforms file."""
        return len(self.transforms['frames'])

    @staticmethod
    def _build_intrinsics(frame):
        """Build the 3x3 intrinsic matrix from the frame's ``fl_x``, ``fl_y``, ``cx`` and ``cy``."""
        K = torch.eye(3)
        # Fills K[0, 0]=fl_x, K[1, 1]=fl_y, K[0, 2]=cx and K[1, 2]=cy.
        K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor(
            [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]]
        )
        return K

    def _to_target_extrinsic(self, c2w):
        """Convert a camera-to-world matrix into the configured extrinsic type.

        Raises:
            NotImplementedError: If ``target_extrinsic_type`` is unknown.
        """
        if self.target_extrinsic_type == "w2c":
            return c2w.inverse()
        if self.target_extrinsic_type == "c2w":
            return c2w
        raise NotImplementedError(f"Unknown extrinsic type: {self.target_extrinsic_type}")

    def __getitem__(self, i):
        """Load frame ``i``.

        Returns:
            dict with keys ``timestep_index``, ``camera_index``, ``intrinsics`` (3x3
            tensor), ``extrinsics`` (shape follows ``transform_matrix``, presumably
            4x4; confirm), ``image_height``, ``image_width``, ``image`` (numpy array)
            and ``image_path``.
            Optionally also ``fg_mask`` and ``fg_mask_path`` (when ``use_fg_mask`` is
            set and the frame has ``fg_mask_path``), and ``flame_param`` as a dict of
            arrays (when ``use_flame_param`` is set and the frame has
            ``flame_param_path``).
        """
        frame = self.transforms['frames'][i]

        K = self._build_intrinsics(frame)
        extrinsic = self._to_target_extrinsic(torch.tensor(frame['transform_matrix']))

        img_path = self.root_folder / frame['file_path']
        # np.array materializes the pixel data, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = np.array(img)

        item = {
            'timestep_index': frame['timestep_index'],
            'camera_index': frame['camera_index'],
            'intrinsics': K,
            'extrinsics': extrinsic,
            'image_height': frame['h'],
            'image_width': frame['w'],
            'image': image,
            'image_path': img_path,
        }

        if self.use_fg_mask and 'fg_mask_path' in frame:
            fg_mask_path = self.root_folder / frame['fg_mask_path']
            with Image.open(fg_mask_path) as mask:
                item["fg_mask"] = np.array(mask)
            item["fg_mask_path"] = fg_mask_path

        if self.use_flame_param and 'flame_param_path' in frame:
            # NOTE(review): allow_pickle=True lets the file unpickle arbitrary objects.
            # Only load FLAME parameter files from trusted sources.
            # The NpzFile holds an open file, so close it once every array is read.
            with np.load(self.root_folder / frame['flame_param_path'], allow_pickle=True) as npz:
                item["flame_param"] = dict(npz)

        return item

    def apply_to_tensor(self, item):
        """Convert the ``"rgb"`` and ``"alpha_map"`` entries of ``item`` to tensors, in place.

        Conversion uses ``torchvision.transforms.functional.to_tensor`` and happens
        only when ``self.img_to_tensor`` is true. Keys that are absent are skipped.
        ``__getitem__`` stores its image under ``"image"`` and its mask under
        ``"fg_mask"``, and those keys are not touched here.

        Returns:
            The same ``item`` dict.
        """
        if self.img_to_tensor:
            for key in ("rgb", "alpha_map"):
                if key in item:
                    item[key] = F.to_tensor(item[key])
        return item
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load a scene, print its size and the keys of one sample,
    # then iterate over the whole dataset through a DataLoader.
    from dataclasses import dataclass
    from typing import Optional

    import tyro
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    @dataclass
    class Args:
        root_folder: str  # scene folder containing transforms*.json
        # The fields below are not used by this script. subject/sequence were
        # previously required, and are now optional so they need not be passed.
        subject: Optional[str] = None
        sequence: Optional[str] = None
        use_landmark: bool = False
        batchify_all_views: bool = False

    args = tyro.cli(Args)

    dataset = NeRFDataset(root_folder=args.root_folder)
    print(len(dataset))

    # Guard against an empty frame list, where dataset[0] would raise IndexError.
    if len(dataset) > 0:
        sample = dataset[0]
        print(sample.keys())

    # batch_size=None disables automatic batching, so samples are yielded one by one.
    dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    for _ in tqdm(dataloader):
        pass
|
|