File size: 8,210 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional

import numpy as np
import torch
import torch.nn.functional as F

from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, NonlearnedInitializerCfg


@dataclass
class InitializerPointcloudCfg(NonlearnedInitializerCfg):
    """Configuration for the point-cloud based initializer."""

    # Selects this initializer variant; must be the literal "pointcloud".
    name: Literal["pointcloud"]
    # Directory containing one `<scene_id>.ply` file per scene.
    path: Path
    # Multiplier applied to the KNN-derived per-point scales.
    scaling_factor: float
    # Opacity assigned to every loaded point.
    init_opacity: float
    # Spherical-harmonics degree; (sh_degree + 1) ** 2 coefficients per channel.
    sh_degree: int
    # When true, points whose RGB values sum to zero are dropped.
    filter_zero_rgb: bool
    # World transform (nested list, 4x4) applied to the point positions.
    # Required when the PLY lives in a different coordinate system than the
    # camera poses. For ScanNet++/NeRFstudio the PLY is in COLMAP space while
    # the cameras are in NeRFstudio space; the mapping (x, y, z) -> (y, x, -z)
    # corresponds to:
    #   [[0,1,0,0],[1,0,0,0],[0,0,-1,0],[0,0,0,1]]
    # Use null to leave the positions untouched.
    world_transform: Optional[list]

    def get_sh_d(self):
        """Return the number of SH coefficients per colour channel."""
        degree_plus_one = self.sh_degree + 1
        return degree_plus_one * degree_plus_one

    def get_gaussian_param_num(self):
        """Return the flat parameter count of a single Gaussian.

        The sum is 3 + 4 + 3 * sh_d + 2 + 1, where sh_d is `get_sh_d()`.
        NOTE(review): the meaning of the individual constant terms is not
        documented in this file - confirm against the consumer of this value.
        """
        return 3 + 4 + 3 * self.get_sh_d() + 2 + 1


class InitializerPointcloud(NonlearnedInitializer[InitializerPointcloudCfg]):
    """Create initial Gaussians from a per-scene PLY point cloud.

    The point cloud `<cfg.path>/<scene>.ply` is loaded, optionally transformed
    and filtered, randomly subsampled, and converted into a `Gaussians` object
    whose scales come from nearest-neighbour distances.
    """

    def __init__(self, cfg: InitializerPointcloudCfg) -> None:
        super().__init__(cfg)

    @staticmethod
    def _load_ply(ply_path: Path) -> tuple[np.ndarray, np.ndarray]:
        """Load a binary little-endian PLY file with double XYZ and uchar RGB.

        The vertex payload is decoded with a fixed record layout: three
        little-endian float64 values followed by three uint8 values. Only the
        `format` and `element vertex` header lines are inspected; the listed
        properties are not validated against that layout.

        Args:
            ply_path: Path of the PLY file to read.

        Returns:
            xyz: [N, 3] float32 array of 3D positions
            rgb: [N, 3] uint8 array of colors

        Raises:
            ValueError: If the header is unterminated, declares a format other
                than `binary_little_endian`, lacks an `element vertex` line,
                or if the vertex payload is shorter than the header promises.
        """
        vertex_dtype = np.dtype([
            ("x", "<f8"), ("y", "<f8"), ("z", "<f8"),
            ("r", "u1"), ("g", "u1"), ("b", "u1"),
        ])

        with open(ply_path, "rb") as f:
            num_vertices = None
            while True:
                raw_line = f.readline()
                if not raw_line:
                    # readline() keeps returning b"" at EOF; without this check
                    # a file lacking "end_header" would spin forever.
                    raise ValueError(f"PLY file {ply_path} has no 'end_header' line.")
                line = raw_line.decode("ascii", errors="replace").strip()
                tokens = line.split()
                if tokens and tokens[0] == "format":
                    if len(tokens) < 2 or tokens[1] != "binary_little_endian":
                        # The fixed dtype below would decode anything else as garbage.
                        raise ValueError(
                            f"PLY file {ply_path} has unsupported format line '{line}'; "
                            "expected binary_little_endian."
                        )
                if line.startswith("element vertex"):
                    num_vertices = int(tokens[-1])
                if line == "end_header":
                    break

            if num_vertices is None:
                raise ValueError(f"PLY file {ply_path} declares no 'element vertex' count.")

            expected_bytes = num_vertices * vertex_dtype.itemsize
            payload = f.read(expected_bytes)

        if len(payload) != expected_bytes:
            raise ValueError(
                f"PLY file {ply_path} is truncated: expected {expected_bytes} bytes "
                f"of vertex data, got {len(payload)}."
            )

        data = np.frombuffer(payload, dtype=vertex_dtype)
        xyz = np.stack([data["x"], data["y"], data["z"]], axis=1).astype(np.float32)
        rgb = np.stack([data["r"], data["g"], data["b"]], axis=1)
        return xyz, rgb

    @staticmethod
    def _random_subset(xyz: torch.Tensor, rgbs: torch.Tensor, count: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return a random subset of `count` points; no-op if there are not more than `count`."""
        if xyz.shape[0] <= count:
            return xyz, rgbs
        keep = torch.randperm(xyz.shape[0], device=xyz.device)[:count]
        return xyz[keep], rgbs[keep]

    def forward(
            self,
            context,
            visualization_dump: Optional[dict] = None,
            device: Optional[torch.device] = None,
            **kwargs
    ) -> InitializerOutput:
        """Load the scene's point cloud and convert it into Gaussians.

        Args:
            context: Unused by this initializer.
            visualization_dump: Unused by this initializer.
            device: Device on which the tensors are created.
            **kwargs: Must contain `scene`, a length-1 sequence holding the
                scene name used to locate `<cfg.path>/<scene>.ply`.

        Returns:
            An `InitializerOutput` whose `gaussians` tensors each carry a
            leading dimension of size 1; `features` and `depths` are None.

        Raises:
            ValueError: If `scene` is missing or the PLY file does not exist.
        """
        # Resolve PLY path
        if "scene" not in kwargs:
            raise ValueError("Scene name is required for pointcloud initializer.")
        scene_name = kwargs["scene"]
        assert len(scene_name) == 1, f"Only single scene initialization supported. {scene_name}"
        scene_name = scene_name[0]
        # Path(...) is a no-op for Path values and tolerates a plain string
        # coming from a config loader.
        ply_path = Path(self.cfg.path) / f"{scene_name}.ply"

        if not ply_path.exists():
            raise ValueError(f"PLY file {ply_path} does not exist.")

        print(f"Loading point cloud from {ply_path}")

        # Load PLY
        points_xyz, points_rgb = self._load_ply(ply_path)
        print(f"Loaded {points_xyz.shape[0]} points.")

        xyz = torch.from_numpy(points_xyz).float().to(device)
        rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device)  # colors scaled to [0, 1]

        # Apply world transform to align point cloud with camera coordinate system
        if self.cfg.world_transform is not None:
            T = torch.tensor(self.cfg.world_transform, dtype=torch.float32, device=device)
            # Transform: new_xyz = (T @ [xyz, 1])[:3]
            xyz_h = torch.cat([xyz, torch.ones(xyz.shape[0], 1, device=device)], dim=-1)  # [N, 4]
            xyz = (T @ xyz_h.T)[:3].T  # [N, 3]

        # Filter zero-RGB points
        if self.cfg.filter_zero_rgb:
            valid_mask = rgbs.sum(dim=-1) > 0
            xyz = xyz[valid_mask]
            rgbs = rgbs[valid_mask]

        # ── Step 1: subsampling augmentation ─────────────────────────────────────
        # The *_gaussians_subsample / *_fixed_gaussians_num options are read from
        # cfg but not declared in this file (presumably on the base config).
        if self.training:
            min_sub = self.cfg.train_min_gaussians_subsample
            max_sub = self.cfg.train_max_gaussians_subsample
            fixed_num = self.cfg.train_fixed_gaussians_num
        else:
            min_sub = self.cfg.eval_min_gaussians_subsample
            max_sub = self.cfg.eval_max_gaussians_subsample
            fixed_num = self.cfg.eval_fixed_gaussians_num

        if min_sub is not None or max_sub is not None:
            target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub)
            xyz, rgbs = self._random_subset(xyz, rgbs, target_count)

        # ── Step 2: subsample to fixed count (for DDP consistency) ────────────
        if fixed_num is not None:
            xyz, rgbs = self._random_subset(xyz, rgbs, fixed_num)

        # KNN → scales
        # knn(xyz, 4)[:, 1:] drops the first column (presumably the point's
        # zero distance to itself - confirm against the knn helper).
        dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
        dist_avg = torch.sqrt(dist2_avg)
        scales = dist_avg.unsqueeze(-1).repeat(1, 3)  # [N, 3]
        # Create opacities next to the positions instead of on the default
        # device so all entries of points_dict live on the same device.
        opacities = torch.full(
            (xyz.shape[0],), self.cfg.init_opacity, dtype=xyz.dtype, device=xyz.device
        )

        # Pad to fixed size for distributed training
        if self.training and fixed_num is not None:
            pad_size = fixed_num - xyz.shape[0]
            if pad_size > 0:
                # Padding rows sit at the origin with near-zero scale and opacity.
                xyz = F.pad(xyz, (0, 0, 0, pad_size), mode='constant', value=0.0)
                rgbs = F.pad(rgbs, (0, 0, 0, pad_size), mode='constant', value=0.0)
                scales = F.pad(scales, (0, 0, 0, pad_size), mode='constant', value=1e-10)
                opacities = F.pad(opacities, (0, pad_size), mode='constant', value=1e-10)

        points_dict = {
            "xyz": xyz,
            "rgb": rgbs,
            "scales": scales * self.cfg.scaling_factor,
            "opacities": opacities,
        }

        # Convert to Gaussian representation
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)

        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        harmonics = torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # Map the raw parameters returned by the helper through their activations.
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
        )

        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None,
        )

    @staticmethod
    def _sample_num_gaussians(available: int, min_sub: int | float | None, max_sub: int | float | None) -> int:
        """Sample a target Gaussian count from the [min_sub, max_sub] range.

        A missing bound defaults to the other one. If `min_sub` is an int the
        bounds are absolute counts (inclusive); otherwise they are ratios of
        `available`.

        Args:
            available: Number of points currently available.
            min_sub: Lower bound (count or ratio), or None.
            max_sub: Upper bound (count or ratio), or None.

        Returns:
            The sampled count, capped at `available`.

        Raises:
            ValueError: If both bounds are None.
        """
        if min_sub is None and max_sub is None:
            raise ValueError("At least one of min_sub / max_sub must be set.")
        if min_sub is None:
            min_sub = max_sub
        if max_sub is None:
            max_sub = min_sub

        if isinstance(min_sub, int):
            # randint's upper bound is exclusive, hence the + 1.
            target = torch.randint(min_sub, max_sub + 1, (1,)).item()
        else:  # float → ratio of available
            ratio = torch.empty(1).uniform_(min_sub, max_sub).item()
            target = int(available * ratio)

        return min(target, available)