Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import math | |
| from dataclasses import dataclass, field | |
| import os | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from tgs.utils.config import parse_structured | |
| from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays | |
| from tgs.utils.typing import * | |
| def _parse_scene_list_single(scene_list_path: str): | |
| if scene_list_path.endswith(".json"): | |
| with open(scene_list_path) as f: | |
| all_scenes = json.loads(f.read()) | |
| elif scene_list_path.endswith(".txt"): | |
| with open(scene_list_path) as f: | |
| all_scenes = [p.strip() for p in f.readlines()] | |
| else: | |
| all_scenes = [scene_list_path] | |
| return all_scenes | |
def _parse_scene_list(scene_list_path: Union[str, List[str]]):
    """Expand a single path or a list of paths into a flat list of scenes."""
    entries = [scene_list_path] if isinstance(scene_list_path, str) else scene_list_path
    scenes: List[str] = []
    for entry in entries:
        scenes.extend(_parse_scene_list_single(entry))
    return scenes
@dataclass
class CustomImageDataModuleConfig:
    """Configuration for the custom-image orbit dataset.

    Fix: the class uses ``field(default_factory=...)`` but was missing the
    ``@dataclass`` decorator, so instances never received these defaults
    (``background_color`` would have been a bare ``Field`` object) and
    structured-config parsing via ``parse_structured`` would not work.
    """

    # Path(s) to input image(s): a single path, a .txt list, or a .json list.
    image_list: Any = ""
    # RGB background composited behind the alpha-matted input, in [0, 1].
    background_color: Tuple[float, float, float] = field(
        default_factory=lambda: (1.0, 1.0, 1.0)
    )
    # Use a fixed canonical conditioning pose instead of a spherical one.
    relative_pose: bool = False

    # Conditioning (input-view) image resolution and camera parameters.
    cond_height: int = 512
    cond_width: int = 512
    cond_camera_distance: float = 1.6
    cond_fovy_deg: float = 40.0
    cond_elevation_deg: float = 0.0
    cond_azimuth_deg: float = 0.0
    num_workers: int = 16

    # Evaluation (rendered orbit) resolution and camera parameters.
    eval_height: int = 512
    eval_width: int = 512
    eval_batch_size: int = 1
    eval_elevation_deg: float = 0.0
    eval_camera_distance: float = 1.6
    eval_fovy_deg: float = 40.0
    # Total orbit views; emitted in chunks of num_views_output per item.
    n_test_views: int = 120
    num_views_output: int = 120
    # If True, each item carries only the conditioning data (view 0).
    only_3dgs: bool = False
class CustomImageOrbitDataset(Dataset):
    """Dataset pairing one conditioning image with a precomputed camera orbit.

    At construction time it builds ``n_test_views`` cameras evenly spaced in
    azimuth around the origin (rays, intrinsics, c2w poses); ``__getitem__``
    loads one RGBA image, composites it over the background color, and
    returns it together with a chunk of ``num_views_output`` orbit views.

    Fix: the four ``torch.cross`` calls relied on the deprecated implicit
    ``dim`` inference; ``dim=-1`` is passed explicitly. For the shapes used
    here ((B, 3) and (3,)) the inferred dimension equals -1, so behavior is
    unchanged while the deprecation warning is removed.
    """

    def __init__(self, cfg: Any) -> None:
        super().__init__()
        self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg)

        self.n_views = self.cfg.n_test_views
        # Orbit views are emitted in whole chunks of num_views_output.
        assert self.n_views % self.cfg.num_views_output == 0

        self.all_scenes = _parse_scene_list(self.cfg.image_list)

        # Evenly spaced azimuths over a full circle; the duplicate 360° sample
        # is dropped by slicing to n_views.
        azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[
            : self.n_views
        ]
        elevation_deg: Float[Tensor, "B"] = torch.full_like(
            azimuth_deg, self.cfg.eval_elevation_deg
        )
        camera_distances: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_camera_distance
        )

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )

        # default scene center at origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.n_views, 1)

        fovy_deg: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_fovy_deg
        )
        fovy = fovy_deg * math.pi / 180

        # Build orthonormal look-at camera frames (right, up, -lookat columns).
        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = (
            0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
        )
        directions_unit_focal = get_ray_directions(
            H=self.cfg.eval_height,
            W=self.cfg.eval_width,
            focal=1.0,
        )
        directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
            None, :, :, :
        ].repeat(self.n_views, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )

        # must use normalize=True to normalize directions here
        # NOTE(review): the comment above mentions normalize=True but the call
        # only passes keepdim=True — confirm get_rays normalizes by default.
        rays_o, rays_d = get_rays(directions, c2w, keepdim=True)

        intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov(
            self.cfg.eval_fovy_deg * math.pi / 180,
            H=self.cfg.eval_height,
            W=self.cfg.eval_width,
            bs=self.n_views,
        )
        # Resolution-normalized intrinsics (focal lengths and principal point
        # divided by image width/height).
        intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone()
        intrinsic_normed[..., 0, 2] /= self.cfg.eval_width
        intrinsic_normed[..., 1, 2] /= self.cfg.eval_height
        intrinsic_normed[..., 0, 0] /= self.cfg.eval_width
        intrinsic_normed[..., 1, 1] /= self.cfg.eval_height

        self.rays_o, self.rays_d = rays_o, rays_d
        self.intrinsic = intrinsic
        self.intrinsic_normed = intrinsic_normed
        self.c2w = c2w
        self.camera_positions = camera_positions

        self.background_color = torch.as_tensor(self.cfg.background_color)

        # condition (input view) camera
        self.intrinsic_cond = get_intrinsic_from_fov(
            np.deg2rad(self.cfg.cond_fovy_deg),
            H=self.cfg.cond_height,
            W=self.cfg.cond_width,
        )
        self.intrinsic_normed_cond = self.intrinsic_cond.clone()
        self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width
        self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height
        self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width
        self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height

        if self.cfg.relative_pose:
            # Fixed canonical conditioning pose at cond_camera_distance.
            self.c2w_cond = torch.as_tensor(
                [
                    [0, 0, 1, self.cfg.cond_camera_distance],
                    [1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 0, 1],
                ]
            ).float()
        else:
            # Conditioning pose from spherical coordinates, same convention
            # as the orbit cameras above.
            cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180
            cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180
            cond_camera_position: Float[Tensor, "3"] = torch.as_tensor(
                [
                    self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.cos(cond_azimuth),
                    self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.sin(cond_azimuth),
                    self.cfg.cond_camera_distance * np.sin(cond_elevation),
                ], dtype=torch.float32
            )

            cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position)
            cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)
            cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1)
            cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up, dim=-1), dim=-1)
            cond_up = F.normalize(torch.cross(cond_right, cond_lookat, dim=-1), dim=-1)
            cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat(
                [torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]],
                dim=-1,
            )
            cond_c2w: Float[Tensor, "4 4"] = torch.cat(
                [cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0
            )
            cond_c2w[3, 3] = 1.0
            self.c2w_cond = cond_c2w

    def __len__(self):
        """One item per scene (only_3dgs) or per chunk of output views."""
        if self.cfg.only_3dgs:
            return len(self.all_scenes)
        else:
            return len(self.all_scenes) * self.n_views // self.cfg.num_views_output

    def __getitem__(self, index):
        """Load one conditioning image and bundle it with a chunk of orbit views."""
        if self.cfg.only_3dgs:
            scene_index = index
            view_index = [0]
        else:
            # Map flat index -> (scene, chunk of num_views_output views).
            scene_index = index * self.cfg.num_views_output // self.n_views
            view_start = index % (self.n_views // self.cfg.num_views_output)
            view_index = list(range(self.n_views))[view_start * self.cfg.num_views_output :
                                (view_start + 1) * self.cfg.num_views_output]

        img_path = self.all_scenes[scene_index]
        # Load as RGBA in [0, 1], resized to the conditioning resolution.
        img_cond = torch.from_numpy(
            np.asarray(
                Image.fromarray(imageio.v2.imread(img_path))
                .convert("RGBA")
                .resize((self.cfg.cond_width, self.cfg.cond_height))
            )
            / 255.0
        ).float()
        mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:]
        # Alpha-composite RGB over the configured background color.
        rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[
            :, :, :3
        ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond)

        out = {
            "rgb_cond": rgb_cond.unsqueeze(0),
            "c2w_cond": self.c2w_cond.unsqueeze(0),
            "mask_cond": mask_cond.unsqueeze(0),
            "intrinsic_cond": self.intrinsic_cond.unsqueeze(0),
            "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0),
            "view_index": torch.as_tensor(view_index),
            "rays_o": self.rays_o[view_index],
            "rays_d": self.rays_d[view_index],
            "intrinsic": self.intrinsic[view_index],
            "intrinsic_normed": self.intrinsic_normed[view_index],
            "c2w": self.c2w[view_index],
            "camera_positions": self.camera_positions[view_index],
        }
        # Flip the y/z camera axes in place — presumably an OpenGL->OpenCV
        # pose-convention change; confirm against the consumer of c2w.
        out["c2w"][..., :3, 1:3] *= -1
        out["c2w_cond"][..., :3, 1:3] *= -1
        # Instance id is the image filename without its extension.
        instance_id = os.path.split(img_path)[-1].split('.')[0]
        out["index"] = torch.as_tensor(scene_index)
        out["background_color"] = self.background_color
        out["instance_id"] = instance_id
        return out

    def collate(self, batch):
        """Default-collate the batch and attach the render height/width."""
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
        return batch