Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Literal, Optional | |
| from optgs.dataset.data_types import BatchedViews | |
| import numpy as np | |
| import torch | |
| import math | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| from optgs.experimental.edgs.init import init_gaussians_with_corr | |
| from optgs.experimental.initializers_utils import knn, points_to_gaussians | |
| from optgs.model.types import Gaussians | |
| from optgs.scene_trainer.common.gaussian_adapter import build_covariance | |
| from optgs.scene_trainer.initializer.initializer import InitializerOutput, NonlearnedInitializer, NonlearnedInitializerCfg | |
@dataclass
class InitializerEdgsCfg(NonlearnedInitializerCfg):
    """Configuration for the EDGS correspondence-based Gaussian initializer.

    Decorated with ``@dataclass`` so the fields below become real dataclass
    fields (and therefore constructor arguments / loadable config entries);
    without the decorator they would only be bare class annotations.
    """

    name: Literal["edgs"]
    # Maximum spherical-harmonics degree of the produced Gaussians.
    sh_degree: int
    # Passed as `init_opacity` to init_gaussians_with_corr by InitializerEdgs.
    init_opacity: float
    # Multiplier applied to points_dict["scales"] after initialization.
    scaling_factor: float
    # Passed as `roma_model_type` to init_gaussians_with_corr by InitializerEdgs.
    roma_model_type: str
    # If > 0, randomly sample this many gaussians from the initialized set.
    sample_init_gaussians: int

    def get_gaussian_param_num(self):
        """Return the number of scalar parameters per Gaussian.

        NOTE(review): the breakdown is presumably means (3) + rotation
        quaternion (4) + RGB SH coefficients (3 * sh_d) + 2 + opacity (1);
        the `2` term does not match 3 scale components - confirm.
        """
        sh_d = self.get_sh_d()
        # TODO Naama: check where this is used, and if it is needed
        init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1
        return init_gaussian_param_num

    def get_sh_d(self):
        """Return the number of SH coefficients per color channel, (sh_degree + 1) ** 2."""
        sh_d = (self.sh_degree + 1) ** 2
        return sh_d
class InitializerEdgs(NonlearnedInitializer[InitializerEdgsCfg]):
    """Non-learned initializer that builds Gaussians via init_gaussians_with_corr.

    Camera poses and intrinsics of the context views are converted into
    3DGS-style view / projection matrices. These are fed to
    ``init_gaussians_with_corr``. Its point dictionary is optionally cached on
    disk, rescaled, subsampled, and finally turned into a batched ``Gaussians``.
    """

    # Near/far clipping planes of the projection matrices handed to
    # init_gaussians_with_corr.
    _ZNEAR = 0.01
    _ZFAR = 100.0

    def __init__(self, cfg: InitializerEdgsCfg) -> None:
        super().__init__(cfg)

    @staticmethod
    def _focal2fov(focal: float, pixels: int) -> float:
        """Return the field of view (radians) spanned by `pixels` at focal length `focal` (in pixels)."""
        return 2 * math.atan(pixels / (2 * focal))

    @staticmethod
    def _get_projection_matrix(znear: float, zfar: float, fov_x: float, fov_y: float) -> torch.Tensor:
        """Build a (non-transposed) 4x4 perspective projection for a symmetric frustum.

        Because left == -right and bottom == -top, the principal-point terms
        P[0, 2] and P[1, 2] are 0, i.e. the principal point is assumed to be at
        the image centre. The result is a CPU float32 tensor.
        """
        tan_half_fov_y = math.tan(fov_y / 2)
        tan_half_fov_x = math.tan(fov_x / 2)
        top = tan_half_fov_y * znear
        bottom = -top
        right = tan_half_fov_x * znear
        left = -right
        z_sign = 1.0

        proj = torch.zeros(4, 4)
        proj[0, 0] = 2.0 * znear / (right - left)
        proj[1, 1] = 2.0 * znear / (top - bottom)
        proj[0, 2] = (right + left) / (right - left)
        proj[1, 2] = (top + bottom) / (top - bottom)
        proj[3, 2] = z_sign
        proj[2, 2] = z_sign * zfar / (zfar - znear)
        proj[2, 3] = -(zfar * znear) / (zfar - znear)
        return proj

    def forward(
        self,
        context: BatchedViews,
        visualization_dump: Optional[dict] = None,
        cached_data_path: Optional[Path] = None,
        **kwargs,
    ) -> InitializerOutput:
        """Initialize Gaussians from the context views.

        Args:
            context: Batched context views with a batch dimension of 1 (removed
                with squeeze(0)). Reads "image" ([N, 3, H, W] after squeeze),
                "extrinsics" (camera-to-world, [N, 4, 4] after squeeze) and
                "intrinsics" (normalized by image size, [N, 3, 3] after squeeze).
            visualization_dump: Unused.
            cached_data_path: Optional directory. If it contains
                "points_dict.pt", that file is loaded instead of recomputing;
                otherwise the freshly computed points_dict is saved there.
            **kwargs: Ignored.

        Returns:
            InitializerOutput whose Gaussians carry a leading batch dim of 1;
            `features` and `depths` are None.
        """
        device = context["extrinsics"].device

        # unpack context (batch_dim = 1)
        viewpoints_img = context["image"].squeeze(0)  # [N, 3, H, W]
        h, w = viewpoints_img.shape[2], viewpoints_img.shape[3]

        # poses
        viewpoints_c2w = context["extrinsics"].squeeze(0).clone()  # [N, 4, 4]
        camera_centers = viewpoints_c2w[..., :3, 3]
        viewpoints_w2c = torch.inverse(viewpoints_c2w)  # [N, 4, 4]
        # convert to column-major (transpose of each matrix)
        viewpoints_w2c = viewpoints_w2c.permute(0, 2, 1)

        # intrinsics: un-normalize by multiplying rows by image size
        viewpoints_intrinsics = context["intrinsics"].squeeze(0).clone()  # [N, 3, 3]
        viewpoints_intrinsics[:, 0, :] *= w
        viewpoints_intrinsics[:, 1, :] *= h

        # Per-view (transposed) projection matrices. Only fx/fy are used; the
        # principal point (cx, cy) is ignored by the symmetric frustum.
        # Moved to the poses' device/dtype (previously hard-coded `.cuda()`,
        # which failed on CPU, on a different GPU, or with non-float32 poses).
        viewpoints_proj = torch.stack(
            [
                self._get_projection_matrix(
                    znear=self._ZNEAR,
                    zfar=self._ZFAR,
                    fov_x=self._focal2fov(intrinsic[0, 0].item(), w),
                    fov_y=self._focal2fov(intrinsic[1, 1].item(), h),
                ).transpose(0, 1)
                for intrinsic in viewpoints_intrinsics
            ],
            dim=0,
        ).to(device=device, dtype=viewpoints_w2c.dtype)  # [N, 4, 4]

        # compute full projection matrices
        viewpoints_full_proj = viewpoints_w2c.bmm(viewpoints_proj)  # [N, 4, 4]

        # check if points_dict is stored on disk already (cached)
        cache_path = None if cached_data_path is None else cached_data_path / "points_dict.pt"
        points_dict = None
        if cache_path is not None:
            print("Checking for cached points_dict at:", str(cached_data_path))
            if cache_path.exists():
                # map_location lets a cache written on one device load on another.
                points_dict = torch.load(cache_path, map_location=device)
                print("Loaded cached points_dict from:", str(cache_path))
            else:
                print("No cached points_dict found at:", str(cache_path))

        if points_dict is None:
            # recompute points_dict
            _, _, points_dict = init_gaussians_with_corr(
                viewpoints_img=viewpoints_img,  # [N, 3, H, W]
                viewpoints_w2c=viewpoints_w2c,  # [N, 4, 4]
                viewpoints_proj=viewpoints_full_proj,  # [N, 4, 4]
                camera_centers=camera_centers,  # [N, 3]
                init_opacity=self.cfg.init_opacity,
                roma_model_type=self.cfg.roma_model_type,
                verbose=False,
            )
            if cache_path is not None:
                print("Saving points_dict to cache at:", str(cache_path))
                cached_data_path.mkdir(parents=True, exist_ok=True)
                torch.save(points_dict, cache_path)

        # Scaling happens after caching, so the cache stores unscaled values.
        # Out-of-place to avoid mutating a tensor that may be shared elsewhere.
        points_dict["scales"] = points_dict["scales"] * self.cfg.scaling_factor

        # printing some stats (min/max would raise on empty tensors)
        for k, v in points_dict.items():
            if v.numel() == 0:
                print(f"points_dict[{k}]: shape={v.shape}, dtype={v.dtype}, (empty)")
                continue
            print(f"points_dict[{k}]: shape={v.shape}, dtype={v.dtype}, min={v.min().item()}, max={v.max().item()}")

        # downsample if needed: randomly sample a subset of gaussians
        if self.cfg.sample_init_gaussians > 0:
            total_points = points_dict["xyz"].shape[0]
            sample_num = min(self.cfg.sample_init_gaussians, total_points)
            sampled_indices = torch.randperm(total_points)[:sample_num]
            points_dict = {k: v[sampled_indices] for k, v in points_dict.items()}
            print("Nr points after sampling:", points_dict["xyz"].shape[0])

        # pre-activation values on device
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        harmonics = torch.cat([gaussians_dict["sh0"], gaussians_dict["shN"]], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # post-activation values
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        print("Nr gaussians initialized:", means.shape[0])

        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, 3, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )
        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None,
        )