Spaces:
Sleeping
Sleeping
File size: 7,357 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 | from dataclasses import dataclass
from typing import Literal, Optional
from optgs.dataset.data_types import BatchedViews
import numpy as np
import torch
import math
import torch.nn.functional as F
from pathlib import Path
from optgs.experimental.edgs.init import init_gaussians_with_corr
from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import InitializerOutput, NonlearnedInitializer, NonlearnedInitializerCfg
@dataclass
class InitializerEdgsCfg(NonlearnedInitializerCfg):
    """Configuration of the EDGS initializer.

    The fields are read by ``InitializerEdgs.forward``; the two methods derive
    sizes from ``sh_degree``.
    """

    # Selects this initializer; the only accepted value is "edgs".
    name: Literal["edgs"]
    # Spherical-harmonics degree; (sh_degree + 1) ** 2 coefficients are used.
    sh_degree: int
    # Forwarded to init_gaussians_with_corr as ``init_opacity``.
    init_opacity: float
    # Multiplier applied to the initial "scales" entry of the point dictionary.
    scaling_factor: float
    # Forwarded to init_gaussians_with_corr as ``roma_model_type``.
    roma_model_type: str
    # if >0, randomly sample this many gaussians from the initialized set
    sample_init_gaussians: int

    def get_gaussian_param_num(self):
        """Return the number of parameters stored per Gaussian.

        The count is ``3 + 4 + 3 * sh_d + 2 + 1`` with ``sh_d`` taken from
        :meth:`get_sh_d`.
        """
        # NOTE(review): the constant terms presumably stand for position (3),
        # rotation quaternion (4), scale (2) and opacity (1) -- TODO confirm.
        # TODO Naama: check where this is used, and if it is needed
        constant_terms = 3 + 4 + 2 + 1
        harmonics_terms = 3 * self.get_sh_d()
        return constant_terms + harmonics_terms

    def get_sh_d(self):
        """Return the number of spherical-harmonics coefficients, ``(sh_degree + 1) ** 2``."""
        return (self.sh_degree + 1) ** 2
class InitializerEdgs(NonlearnedInitializer[InitializerEdgsCfg]):
    """Non-learned initializer that seeds Gaussians via ``init_gaussians_with_corr``.

    ``forward`` converts the context cameras into world-to-camera and
    projection matrices, obtains a dictionary of per-point attributes (from an
    on-disk cache when available, otherwise by running
    ``init_gaussians_with_corr``), optionally subsamples it, and packs the
    result into a ``Gaussians`` container.
    """

    def __init__(self, cfg: InitializerEdgsCfg) -> None:
        super().__init__(cfg)

    @staticmethod
    def _focal2fov(focal, pixels):
        """Return the field of view in radians for a focal length and an image extent in pixels."""
        return 2 * math.atan(pixels / (2 * focal))

    @staticmethod
    def _get_projection_matrix(znear, zfar, fovX, fovY):
        """Build a 4x4 perspective projection matrix.

        Args:
            znear: Distance of the near clipping plane.
            zfar: Distance of the far clipping plane.
            fovX: Horizontal field of view in radians.
            fovY: Vertical field of view in radians.

        Returns:
            A ``torch.zeros(4, 4)``-based tensor (default dtype, CPU) holding
            the projection matrix; the caller moves it to the target device.
        """
        tanHalfFovY = math.tan((fovY / 2))
        tanHalfFovX = math.tan((fovX / 2))

        top = tanHalfFovY * znear
        bottom = -top
        right = tanHalfFovX * znear
        left = -right

        P = torch.zeros(4, 4)
        z_sign = 1.0

        P[0, 0] = 2.0 * znear / (right - left)
        P[1, 1] = 2.0 * znear / (top - bottom)
        # The frustum is symmetric (left == -right, bottom == -top), so these
        # two off-centre terms evaluate to zero.
        P[0, 2] = (right + left) / (right - left)
        P[1, 2] = (top + bottom) / (top - bottom)
        P[3, 2] = z_sign
        P[2, 2] = z_sign * zfar / (zfar - znear)
        P[2, 3] = -(zfar * znear) / (zfar - znear)
        return P

    def forward(
        self,
        context: BatchedViews,
        visualization_dump: Optional[dict] = None,
        cached_data_path: Optional[Path] = None,
        **kwargs
    ) -> InitializerOutput:
        """Initialize Gaussians for one scene from its context views.

        Args:
            context: Batched views; the keys ``"image"``, ``"extrinsics"`` and
                ``"intrinsics"`` are read. A leading batch dimension of size 1
                is squeezed away. Intrinsics are treated as normalized by the
                image size and are rescaled to pixels here.
            visualization_dump: Accepted for interface compatibility; not used
                by this implementation.
            cached_data_path: Optional directory. If it contains
                ``points_dict.pt`` that file is loaded instead of recomputing
                the points; otherwise the freshly computed points are saved
                there. A plain string is accepted as well.
            **kwargs: Ignored.

        Returns:
            ``InitializerOutput`` whose ``gaussians`` carry a leading batch
            dimension of 1; ``features`` and ``depths`` are ``None``.
        """
        device = context["extrinsics"].device

        # unpack context (batch_dim = 1)
        viewpoints_img = context["image"].squeeze(0)  # [N, 3, H, W]
        h, w = viewpoints_img.shape[2], viewpoints_img.shape[3]

        # poses
        viewpoints_c2w = context["extrinsics"].squeeze(0).clone()  # [N, 4, 4]
        camera_centers = viewpoints_c2w[..., :3, 3]
        viewpoints_w2c = torch.inverse(viewpoints_c2w)  # [N, 4, 4]
        # convert to column-major
        viewpoints_w2c = viewpoints_w2c.permute(0, 2, 1)

        # intrinsics
        viewpoints_intrinsics = context["intrinsics"].squeeze(0).clone()  # [N, 3, 3]
        # un-normalize intrinsics by multiplying by image size
        viewpoints_intrinsics[:, 0, :] *= w
        viewpoints_intrinsics[:, 1, :] *= h

        # Clipping planes are the same for every view.
        znear = 0.01
        zfar = 100.0

        viewpoints_proj = []
        for intrinsic in viewpoints_intrinsics:
            fx = intrinsic[0, 0]
            fy = intrinsic[1, 1]
            fovY = self._focal2fov(fy, h)
            fovX = self._focal2fov(fx, w)
            proj = self._get_projection_matrix(
                znear=znear, zfar=zfar, fovX=fovX, fovY=fovY
            ).transpose(0, 1)
            # Follow the device/dtype of the input cameras instead of forcing
            # the default CUDA device, so the bmm below also works on CPU and
            # on non-default GPUs.
            viewpoints_proj.append(
                proj.to(device=device, dtype=viewpoints_w2c.dtype)
            )
        viewpoints_proj = torch.stack(viewpoints_proj, dim=0)  # [N, 4, 4]

        # compute full projection matrices
        # (both factors are transposed, so this is the transpose of proj @ w2c)
        viewpoints_full_proj = viewpoints_w2c.bmm(viewpoints_proj)  # [N, 4, 4]

        # check if points_dict is stored on disk already (cached)
        found_cached = False
        if cached_data_path is not None:
            # Accept str as well as Path.
            cached_data_path = Path(cached_data_path)
            print("Checking for cached points_dict at:", str(cached_data_path))
            cache_path = cached_data_path / "points_dict.pt"
            if cache_path.exists():
                # NOTE: torch.load unpickles the file; only use cache
                # directories whose contents are trusted.
                points_dict = torch.load(cache_path)
                print("Loaded cached points_dict from:", str(cache_path))
                found_cached = True
            else:
                print("No cached points_dict found at:", str(cache_path))

        if not found_cached:
            # recompute points_dict
            _, _, points_dict = init_gaussians_with_corr(
                viewpoints_img=viewpoints_img,  # [N, 3, H, W]
                viewpoints_w2c=viewpoints_w2c,  # [N, 4, 4]
                viewpoints_proj=viewpoints_full_proj,  # [N, 4, 4]
                camera_centers=camera_centers,  # [N, 3]
                init_opacity=self.cfg.init_opacity,
                roma_model_type=self.cfg.roma_model_type,
                verbose=False
            )
            if cached_data_path is not None:
                print("Saving points_dict to cache at:", str(cache_path))
                cached_data_path.mkdir(parents=True, exist_ok=True)
                torch.save(points_dict, cache_path)

        # The cache stores unscaled values; scaling is applied after load/save
        # so changing scaling_factor does not invalidate the cache.
        points_dict["scales"] *= self.cfg.scaling_factor

        # printing some stats
        for k, v in points_dict.items():
            print(f"points_dict[{k}]: shape={v.shape}, dtype={v.dtype}, min={v.min().item()}, max={v.max().item()}")

        # downsample if needed
        if self.cfg.sample_init_gaussians > 0:
            # randomly sample a subset of gaussians
            total_points = points_dict["xyz"].shape[0]
            sample_num = min(self.cfg.sample_init_gaussians, total_points)
            sampled_indices = torch.randperm(total_points)[:sample_num]
            points_dict = {k: v[sampled_indices] for k, v in points_dict.items()}
            print("Nr points after sampling:", points_dict["xyz"].shape[0])

        # pre-activation values on device
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        harmonics = torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # post-activation values
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        print("Nr gaussians initialized:", means.shape[0])

        # Add the batch dimension (size 1) back for the output container.
        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, 3, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )

        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None
        )
|