File size: 7,357 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from dataclasses import dataclass
from typing import Literal, Optional

from optgs.dataset.data_types import BatchedViews
import numpy as np
import torch
import math
import torch.nn.functional as F
from pathlib import Path
from optgs.experimental.edgs.init import init_gaussians_with_corr
from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import InitializerOutput, NonlearnedInitializer, NonlearnedInitializerCfg


@dataclass
class InitializerEdgsCfg(NonlearnedInitializerCfg):
    """Configuration of the "edgs" initializer (see ``InitializerEdgs``)."""

    # Discriminator selecting this initializer.
    name: Literal["edgs"]
    # Spherical-harmonics degree; forwarded to `points_to_gaussians`.
    sh_degree: int
    # Initial opacity forwarded to `init_gaussians_with_corr`.
    init_opacity: float
    # Multiplier applied in-place to points_dict["scales"] after initialization.
    scaling_factor: float
    # Model type string forwarded to `init_gaussians_with_corr`.
    roma_model_type: str

    # If > 0, randomly keep at most this many of the initialized gaussians.
    sample_init_gaussians: int

    def get_gaussian_param_num(self):
        """Return the number of parameters per Gaussian.

        Computed as ``3 + 4 + 3 * sh_d + 2 + 1`` with ``sh_d`` taken from
        ``get_sh_d()``.
        """
        # NOTE(review): presumably 3 = position, 4 = rotation, 3 * sh_d = SH
        # colour coefficients; the meaning of the trailing "2 + 1" is not
        # evident from this file -- confirm against the consumer.
        # TODO Naama: check where this is used, and if it is needed
        return 3 + 4 + 3 * self.get_sh_d() + 2 + 1

    def get_sh_d(self):
        """Return ``(sh_degree + 1) ** 2``, the SH coefficient count per channel."""
        return (self.sh_degree + 1) ** 2


class InitializerEdgs(NonlearnedInitializer[InitializerEdgsCfg]):
    """Non-learned initializer that builds Gaussians via `init_gaussians_with_corr`.

    The context views are converted into per-view world-to-camera and full
    projection matrices, handed to `init_gaussians_with_corr` (or a cached
    result is loaded from disk), optionally subsampled, and finally converted
    into a `Gaussians` object.
    """

    def __init__(self, cfg: InitializerEdgsCfg) -> None:
        super().__init__(cfg)

    @staticmethod
    def _focal2fov(focal, pixels):
        """Return the full field of view in radians for a focal length and image extent in pixels."""
        return 2 * math.atan(pixels / (2 * focal))

    @staticmethod
    def _get_projection_matrix(znear, zfar, fovX, fovY):
        """Build a 4x4 perspective projection matrix (CPU, default dtype) from clip planes and FoVs in radians."""
        tan_half_fov_y = math.tan(fovY / 2)
        tan_half_fov_x = math.tan(fovX / 2)

        top = tan_half_fov_y * znear
        bottom = -top
        right = tan_half_fov_x * znear
        left = -right

        z_sign = 1.0

        P = torch.zeros(4, 4)
        P[0, 0] = 2.0 * znear / (right - left)
        P[1, 1] = 2.0 * znear / (top - bottom)
        P[0, 2] = (right + left) / (right - left)
        P[1, 2] = (top + bottom) / (top - bottom)
        P[3, 2] = z_sign
        P[2, 2] = z_sign * zfar / (zfar - znear)
        P[2, 3] = -(zfar * znear) / (zfar - znear)
        return P

    def forward(
        self,
        context: BatchedViews,
        visualization_dump: Optional[dict] = None,
        cached_data_path: Optional[Path] = None,
        **kwargs
    ) -> InitializerOutput:
        """Initialize Gaussians for one scene from its context views.

        Args:
            context: Batched views; the keys "image", "extrinsics" and
                "intrinsics" are read. The leading batch dimension is
                squeezed, so a batch size of 1 is assumed. Intrinsics are
                treated as normalized and are scaled by the image size.
            visualization_dump: Accepted for interface compatibility; not
                used by this initializer.
            cached_data_path: Optional directory. If it contains
                ``points_dict.pt`` that file is loaded instead of recomputing
                the points; otherwise the freshly computed points are saved
                there (the directory is created if needed).
            **kwargs: Ignored.

        Returns:
            An `InitializerOutput` whose ``gaussians`` carry a leading batch
            dimension of 1; ``features`` and ``depths`` are ``None``.
        """
        device = context["extrinsics"].device

        # unpack context (batch_dim = 1)
        viewpoints_img = context["image"].squeeze(0)  # [N, 3, H, W]
        h, w = viewpoints_img.shape[2], viewpoints_img.shape[3]

        # poses
        viewpoints_c2w = context["extrinsics"].squeeze(0).clone()  # [N, 4, 4]
        camera_centers = viewpoints_c2w[..., :3, 3]
        viewpoints_w2c = torch.inverse(viewpoints_c2w)  # [N, 4, 4]

        # Transpose every matrix (the "column-major" convention); the
        # projection matrices below are transposed the same way, so that
        # w2c^T @ P^T == (P @ w2c)^T.
        viewpoints_w2c = viewpoints_w2c.permute(0, 2, 1)

        # intrinsics
        viewpoints_intrinsics = context["intrinsics"].squeeze(0).clone()  # [N, 3, 3]
        # un-normalize intrinsics by multiplying by image size
        viewpoints_intrinsics[:, 0, :] *= w
        viewpoints_intrinsics[:, 1, :] *= h

        znear = 0.01
        zfar = 100.0
        viewpoints_proj = []
        for intrinsic in viewpoints_intrinsics:
            fx = intrinsic[0, 0]
            fy = intrinsic[1, 1]
            fovY = self._focal2fov(fy, h)
            fovX = self._focal2fov(fx, w)
            proj = self._get_projection_matrix(
                znear=znear, zfar=zfar, fovX=fovX, fovY=fovY
            ).transpose(0, 1)
            # Follow the device/dtype of the input poses instead of forcing
            # the default CUDA device; otherwise the bmm below fails whenever
            # the context lives on the CPU or on another GPU.
            proj = proj.to(device=device, dtype=viewpoints_w2c.dtype)
            viewpoints_proj.append(proj)
        viewpoints_proj = torch.stack(viewpoints_proj, dim=0)  # [N, 4, 4]

        # compute full projection matrices
        viewpoints_full_proj = viewpoints_w2c.bmm(viewpoints_proj)  # [N, 4, 4]

        # check if points_dict is stored on disk already (cached)
        found_cached = False
        if cached_data_path is not None:
            print("Checking for cached points_dict at:", str(cached_data_path))
            cache_path = cached_data_path / "points_dict.pt"
            if cache_path.exists():
                points_dict = torch.load(cache_path)
                print("Loaded cached points_dict from:", str(cache_path))
                found_cached = True
            else:
                print("No cached points_dict found at:", str(cache_path))

        if not found_cached:
            # recompute points_dict
            _, _, points_dict = init_gaussians_with_corr(
                viewpoints_img=viewpoints_img,  # [N, 3, H, W]
                viewpoints_w2c=viewpoints_w2c,  # [N, 4, 4]
                viewpoints_proj=viewpoints_full_proj,  # [N, 4, 4]
                camera_centers=camera_centers,  # [N, 3]
                init_opacity=self.cfg.init_opacity,
                roma_model_type=self.cfg.roma_model_type,
                verbose=False
            )
            if cached_data_path is not None:
                # The cache stores the unscaled points; scaling_factor is
                # applied below for both the cached and the fresh path.
                print("Saving points_dict to cache at:", str(cache_path))
                cached_data_path.mkdir(parents=True, exist_ok=True)
                torch.save(points_dict, cache_path)

        points_dict["scales"] *= self.cfg.scaling_factor

        # printing some stats
        for k, v in points_dict.items():
            print(f"points_dict[{k}]: shape={v.shape}, dtype={v.dtype}, min={v.min().item()}, max={v.max().item()}")

        # downsample if needed
        if self.cfg.sample_init_gaussians > 0:
            # randomly sample a subset of gaussians (without replacement)
            total_points = points_dict["xyz"].shape[0]
            sample_num = min(self.cfg.sample_init_gaussians, total_points)
            sampled_indices = torch.randperm(total_points)[:sample_num]
            points_dict = {k: v[sampled_indices] for k, v in points_dict.items()}
            print("Nr points after sampling:", points_dict["xyz"].shape[0])

        # pre-activation values on device
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)

        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        harmonics = torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # post-activation values
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        print("Nr gaussians initialized:", means.shape[0])

        # add the leading batch dimension expected downstream
        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, 3, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )

        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None
        )