Spaces:
Sleeping
Sleeping
File size: 16,099 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 | from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from plyfile import PlyData
from optgs.dataset.colmap.utils import Parser
from optgs.dataset.data_types import BatchedViews
from optgs.experimental.initializers_utils import knn, points_to_gaussians
from optgs.misc.general_utils import SkipBatchException
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import build_covariance
from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, NonlearnedInitializerCfg
@dataclass
class InitializerColmapCfg(NonlearnedInitializerCfg):
    """Settings for building initial Gaussians from a COLMAP (or PLY) reconstruction."""

    name: Literal["colmap"]
    # Root COLMAP directory; the initializer appends a scene sub-directory when a scene is given.
    path: Path
    # Forwarded to the COLMAP parser as ``normalize`` and encoded in the cache file name.
    normalize_world_space: bool
    # Multiplier applied to the kNN-derived per-point scales.
    scaling_factor: float
    # Constant initial opacity; the mean (gaussian) or upper bound (uniform) when randomized.
    init_opacity: float
    # Spherical-harmonics degree of the produced Gaussians.
    sh_degree: int
    # Enables the DL3DV conventions (scene-name prefix stripping, frame alignment).
    dl3dv_settings: bool
    # Drop points whose RGB colour is exactly (0, 0, 0).
    filter_zero_rgb: bool
    # Draw per-point opacities at random instead of using ``init_opacity`` for all.
    randomize_opacity: bool
    # Distribution used when ``randomize_opacity`` is set.
    randomize_opacity_distribution: Literal["uniform", "gaussian"]
    randomize_opacity_std: float  # Standard deviation for gaussian distribution
    randomize_opacity_min: float  # Minimum value for uniform distribution
    points3d_subdir: Optional[str]  # if set, overrides dl3dv_settings/default subdir logic
    points3d_ply_filename: Optional[str]  # if set, loads points from this PLY file (relative to scene dir) instead of COLMAP binary
    override_dataset_poses: bool  # if true, overrides the dataset poses with the COLMAP poses (after applying T_world transform)

    def get_sh_d(self):
        """Return the number of SH coefficients per colour channel: ``(sh_degree + 1) ** 2``."""
        return (self.sh_degree + 1) ** 2

    def get_gaussian_param_num(self):
        """Return the per-Gaussian parameter count used at initialization.

        Computed as ``3 + 4 + 3 * get_sh_d() + 2 + 1``: presumably 3 position
        values, 4 rotation (quaternion) values and SH coefficients for 3 colour
        channels; the meaning of the trailing ``2 + 1`` is not evident here
        (TODO confirm).
        """
        position, rotation = 3, 4
        extra = 2 + 1  # kept exactly as in the original layout
        return position + rotation + 3 * self.get_sh_d() + extra
class InitializerColmap(NonlearnedInitializer[InitializerColmapCfg]):
    """Non-learned initializer that builds Gaussians from a COLMAP (or PLY) point cloud.

    Points, colours and camera poses are parsed from the scene directory under
    ``cfg.path`` and cached to a compressed ``.npz`` next to the scene, so later
    runs skip the COLMAP parse. The points are then optionally filtered,
    aligned (DL3DV), subsampled and padded before being converted into a
    batch-of-one ``Gaussians``.
    """

    def __init__(self, cfg: InitializerColmapCfg) -> None:
        super().__init__(cfg)

    def _npz_path(self, datadir: Path) -> Path:
        """Return the cache file path for ``datadir`` under the current config.

        The file name encodes every config option that is fed into the parse:
        world-space normalization, the points3d sub-directory (when set) and the
        PLY source (when set). Configurations that load different data therefore
        never share a cache file. With ``points3d_subdir`` unset the name is
        unchanged from earlier versions, so existing caches stay valid.
        """
        suffix = "_norm" if self.cfg.normalize_world_space else ""
        if self.cfg.points3d_subdir is not None:
            # Flatten path separators so the sub-directory can live in a file name.
            subdir_tag = str(self.cfg.points3d_subdir).replace("\\", "_").replace("/", "_")
            suffix += f"_sub_{subdir_tag}"
        if self.cfg.points3d_ply_filename is not None:
            ply_stem = Path(self.cfg.points3d_ply_filename).stem
            return datadir / f"colmap_points_cache_ply_{ply_stem}{suffix}.npz"
        return datadir / f"colmap_points_cache{suffix}.npz"

    def _load_colmap(self, datadir: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load COLMAP points, colours and camera-to-world poses for one scene.

        On first access the raw COLMAP binary files are parsed (or a PLY file
        when ``points3d_ply_filename`` is set), and a compact ``.npz`` cache is
        saved next to the scene folder. On subsequent calls only the ``.npz`` is
        loaded. Caching is best-effort: read/write failures print a warning and
        fall back to re-parsing.

        Args:
            datadir: Scene directory containing the COLMAP reconstruction.

        Returns:
            ``(points, points_rgb, camtoworlds)``. From a fresh parse these are
            ``(N, 3)`` float32 points, ``(N, 3)`` uint8 colours and
            ``(M, 4, 4)`` poses as produced by ``Parser``.

        Raises:
            IOError: If ``points3d_ply_filename`` is set but the file is missing.
        """
        npz_path = self._npz_path(datadir)
        if npz_path.exists():
            try:
                # The context manager closes the underlying zip handle. Indexing
                # reads each array fully into memory, so the returned arrays
                # stay valid after the file is closed.
                with np.load(npz_path) as data:
                    return data["points"], data["points_rgb"], data["camtoworlds"]
            except PermissionError:
                print(f"Warning: No read permission for cache {npz_path}. Attempting to delete and regenerate.")
                try:
                    os.unlink(npz_path)
                except Exception as del_e:
                    print(f"Warning: Could not delete {npz_path} ({del_e}). Will re-parse but cannot cache.")
            except Exception as e:
                print(f"Warning: Failed to load cache {npz_path} ({e}). Re-parsing COLMAP data.")

        # Always parse COLMAP cameras/images for the poses.
        parser = Parser(
            data_dir=str(datadir),
            factor=1,
            normalize=self.cfg.normalize_world_space,
            load_images=False,
            dl3dv_settings=False,
            points3d_subdir=self.cfg.points3d_subdir,
            verbose=False,
        )
        camtoworlds = parser.camtoworlds  # (M, 4, 4) float64

        if self.cfg.points3d_ply_filename is not None:
            # Load 3-D points from a PLY file located directly in the scene dir.
            ply_path = datadir / self.cfg.points3d_ply_filename
            if not ply_path.exists():
                raise IOError(f"PLY file not found: {ply_path}")
            vertex = PlyData.read(str(ply_path))["vertex"]
            points = np.stack(
                [np.asarray(vertex["x"]), np.asarray(vertex["y"]), np.asarray(vertex["z"])],
                axis=1,
            ).astype(np.float32)
            points_rgb = np.stack(
                [np.asarray(vertex["red"]), np.asarray(vertex["green"]), np.asarray(vertex["blue"])],
                axis=1,
            ).astype(np.uint8)
        else:
            points = parser.points  # (N, 3) float32
            points_rgb = parser.points_rgb  # (N, 3) uint8

        # TODO Patricia: Fix permission denied
        # Write atomically: save to a temp file in the same directory (already
        # ending in .npz, so numpy does not append another suffix), then rename.
        tmp_path: Optional[str] = None
        try:
            tmp_fd, tmp_path = tempfile.mkstemp(dir=datadir, suffix=".npz")
            os.close(tmp_fd)
            np.savez_compressed(tmp_path, points=points, points_rgb=points_rgb, camtoworlds=camtoworlds)
            os.chmod(tmp_path, 0o664)  # group-readable so other users can use this cache
            os.replace(tmp_path, npz_path)  # atomic on POSIX
        except Exception as e:
            # Best-effort cleanup: a failure here must not turn a cache miss into a crash.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
            print(f"Warning: Failed to save COLMAP cache to {npz_path} ({e}). This may cause slow loading in the future.")
        return points, points_rgb, camtoworlds

    def _resolve_datadir(self, kwargs: dict, verbose: bool = False) -> Path:
        """Return the scene directory selected by ``kwargs['scene']``, or ``cfg.path`` if none is given."""
        if not self.cfg.path.exists():
            raise ValueError(f"COLMAP dir {self.cfg.path} does not exist.")
        if "scene" not in kwargs:
            return self.cfg.path

        scene_name = kwargs["scene"]
        assert len(scene_name) == 1, f"Only single scene initialization supported. {scene_name}"
        scene_name = scene_name[0]
        if self.cfg.dl3dv_settings:
            scene_name = scene_name.replace("dl3dv_", "")
        if verbose:
            print(f"Initializing scene '{scene_name}' from COLMAP at {self.cfg.path}.")
        datadir = self.cfg.path / scene_name
        if not datadir.exists():
            raise ValueError(f"COLMAP scene dir {datadir} does not exist.")
        return datadir

    def _align_to_dl3dv_frame(
        self,
        xyz: torch.Tensor,
        context: BatchedViews,
        target: BatchedViews,
        camtoworlds: np.ndarray,
        device: Optional[torch.device],
    ) -> torch.Tensor:
        """Map COLMAP points into the DL3DV dataset frame, optionally overriding the dataset poses.

        When ``cfg.override_dataset_poses`` is set, ``context['extrinsics']`` and
        ``target['extrinsics']`` are replaced in place by the transformed COLMAP
        poses selected by each batch's ``index``.
        """
        # In some configurations the batch is moved to device later, so keep the
        # pose tensors on the batch's current device.
        batch_device = target['extrinsics'].device
        pose_dtype = context['extrinsics'].dtype  # extrinsics are (b, V, 4, 4)
        c2w_colmap = torch.from_numpy(camtoworlds).to(device=batch_device, dtype=pose_dtype)  # (M, 4, 4)

        # COLMAP -> dataset world transform, hard coded for DL3DV COLMAP
        # reconstructions. An earlier version derived it as
        # c2w_dataset[0] @ c2w_colmap[0].inverse() with entries snapped to {-1, 0, 1}.
        T_world = torch.tensor([[0., 1., 0., 0.],
                                [1., 0., 0., 0.],
                                [0., 0., -1., 0.],
                                [0., 0., 0., 1.]], device=batch_device, dtype=pose_dtype)
        c2w_dataset_predicted = T_world @ c2w_colmap

        # Assume only one scene in the batch
        context_x_flipped = context['x_flipped'][0]
        target_x_flipped = target['x_flipped'][0]
        assert context_x_flipped == target_x_flipped, "Context and target x_flipped values must match."
        x_flipped = context_x_flipped

        flip_transform = torch.eye(4, device=batch_device, dtype=pose_dtype)
        flip_transform[0, 0] = -1.0
        if x_flipped:
            c2w_dataset_predicted = flip_transform @ c2w_dataset_predicted @ flip_transform

        # Override the dataset poses with the COLMAP ones to ensure consistency.
        if self.cfg.override_dataset_poses:
            context['extrinsics'] = c2w_dataset_predicted[context['index'][0]][None, ...]  # (1, V, 4, 4)
            target['extrinsics'] = c2w_dataset_predicted[target['index'][0]][None, ...]

        # Apply the same world transform (and x flip) to the points in homogeneous form.
        xyz = xyz.to(device)
        xyz_h = torch.cat([xyz, torch.ones_like(xyz[:, :1])], dim=-1).T  # (4, N)
        xyz_h = T_world.to(device) @ xyz_h
        if x_flipped:
            xyz_h[0] *= -1.0
        return xyz_h[:3, :].T

    def _init_opacities(self, num: int, device: torch.device) -> torch.Tensor:
        """Return ``num`` initial opacities on ``device``, constant or randomized per the config."""
        cfg = self.cfg
        if not cfg.randomize_opacity:
            # Allocate on the points' device, like the randomized branches below.
            return torch.full((num,), cfg.init_opacity, device=device)
        if cfg.randomize_opacity_distribution == "uniform":
            # Uniform between randomize_opacity_min and init_opacity.
            span = cfg.init_opacity - cfg.randomize_opacity_min
            return (torch.rand(num, device=device) * span) + cfg.randomize_opacity_min
        if cfg.randomize_opacity_distribution == "gaussian":
            # Normal around init_opacity, clamped to the valid [0, 1] range.
            opacities = torch.normal(cfg.init_opacity, cfg.randomize_opacity_std, size=(num,), device=device)
            return opacities.clamp(0, 1)
        raise ValueError(f"Unknown randomize_opacity_distribution: {cfg.randomize_opacity_distribution}")

    def forward(
        self,
        context: BatchedViews,
        visualization_dump: Optional[dict] = None,
        device: Optional[torch.device] = None,
        **kwargs
    ) -> InitializerOutput:
        """Build initial Gaussians for a single scene from its COLMAP reconstruction.

        Args:
            context: Context views. Only used when ``cfg.dl3dv_settings`` is set,
                and then possibly modified in place (see ``override_dataset_poses``).
            visualization_dump: Unused.
            device: Device for the produced tensors.
            **kwargs: ``scene`` (optional, a length-1 sequence with the scene
                name) selects ``cfg.path / scene``. ``target`` (target views) is
                required when ``cfg.dl3dv_settings`` is set.

        Returns:
            ``InitializerOutput`` with a batch-of-one ``Gaussians``; ``features``
            and ``depths`` are ``None``.

        Raises:
            ValueError: Missing COLMAP/scene directory, or unknown
                ``randomize_opacity_distribution``.
            SkipBatchException: No points remain after filtering/subsampling.
            AssertionError: Multiple scenes, a missing ``target`` under
                ``dl3dv_settings``, or mismatched ``x_flipped`` flags.
        """
        verbose = False

        datadir = self._resolve_datadir(kwargs, verbose)

        # run COLMAP parser (cached after first load)
        points_xyz, points_rgb, camtoworlds = self._load_colmap(datadir)
        if verbose:
            print(f"Loaded {points_xyz.shape[0]} points from COLMAP.")

        xyz = torch.from_numpy(points_xyz).float().to(device)
        rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device)

        if self.cfg.filter_zero_rgb:
            # Points with 0,0,0 RGB are often outliers in COLMAP reconstructions.
            valid_mask = rgbs.sum(dim=-1) > 0
            xyz = xyz[valid_mask]
            rgbs = rgbs[valid_mask]

        if self.cfg.dl3dv_settings:
            assert "target" in kwargs, "Target key is required in kwargs for COLMAP initializer with dl3dv format."
            xyz = self._align_to_dl3dv_frame(xyz, context, kwargs["target"], camtoworlds, device)

        # -- Step 1: subsampling augmentation --------------------------------------
        min_sub = self.cfg.train_min_gaussians_subsample if self.training else self.cfg.eval_min_gaussians_subsample
        max_sub = self.cfg.train_max_gaussians_subsample if self.training else self.cfg.eval_max_gaussians_subsample
        if min_sub is not None or max_sub is not None:
            target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub)
            if xyz.shape[0] > target_count:
                indices = torch.randperm(xyz.shape[0], device=xyz.device)[:target_count]
                xyz = xyz[indices]
                rgbs = rgbs[indices]

        # -- Step 2: subsample to the fixed count before knn (so distances are correct)
        # If there are more points than the fixed count, drop the excess (for DDP consistency).
        fixed_num = self.cfg.train_fixed_gaussians_num if self.training else self.cfg.eval_fixed_gaussians_num
        if fixed_num is not None and xyz.shape[0] > fixed_num:
            indices = torch.randperm(xyz.shape[0], device=xyz.device)[:fixed_num]
            xyz = xyz[indices]
            rgbs = rgbs[indices]

        if xyz.shape[0] == 0:
            black_gaussians_num = (points_rgb == 0).all(axis=-1).sum()
            raise SkipBatchException(f"No valid points found in COLMAP data for scene {datadir}. Skipping batch. "
                                     f"Originally {points_xyz.shape[0]} points. Black gaussian num {black_gaussians_num}.")

        # -- Step 3: knn-based scale initialisation --------------------------------
        # Mean squared distance over knn columns 1..3; column 0 is skipped
        # (presumably the point itself — confirm against `knn`).
        dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
        dist_avg = torch.sqrt(dist2_avg)
        scales = dist_avg.unsqueeze(-1).repeat(1, 3)  # [N, 3], isotropic

        opacities = self._init_opacities(xyz.shape[0], xyz.device)

        nr_valid = xyz.shape[0]

        # -- Step 4: pad to the fixed count for DDP consistency --------------------
        if fixed_num is not None and xyz.shape[0] < fixed_num:
            pad = fixed_num - xyz.shape[0]
            xyz = F.pad(xyz, (0, 0, 0, pad), value=0.0)
            rgbs = F.pad(rgbs, (0, 0, 0, pad), value=0.0)
            scales = F.pad(scales, (0, 0, 0, pad), value=1e-10)
            opacities = F.pad(opacities, (0, pad), value=1e-10)
            # TODO Naama: might be a problem if we don't freeze zero-grad gaussians

        points_dict = {
            "xyz": xyz,
            "rgb": rgbs,
            "scales": scales,
            "opacities": opacities,
        }
        points_dict["scales"] *= self.cfg.scaling_factor

        # pre-activation values on device
        gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device)
        means = gaussians_dict["xyz"]
        sh0 = gaussians_dict["sh0"]
        shN = gaussians_dict["shN"]
        if shN is not None:
            harmonics = torch.cat([sh0, shN], dim=1)  # [N, sh_d, 3]
        else:
            harmonics = sh0
        harmonics = harmonics.permute(0, 2, 1)  # [N, 3, sh_d]
        rotations_unnorm = gaussians_dict["rotations_unnorm"]

        # post-activation values
        opacities = torch.sigmoid(gaussians_dict["opacities_raw"])
        scales = torch.exp(gaussians_dict["scales_raw"])
        rotations = F.normalize(rotations_unnorm, dim=-1)
        covariances = build_covariance(scale=scales, rotation_xyzw=rotations)

        gaussians = Gaussians(
            means=means.unsqueeze(0),
            covariances=covariances.unsqueeze(0),
            harmonics=harmonics.unsqueeze(0),  # [1, N, C, sh_d]
            opacities=opacities.unsqueeze(0),
            scales=scales.unsqueeze(0),
            rotations=rotations.unsqueeze(0),
            rotations_unnorm=rotations_unnorm.unsqueeze(0),
            nr_valid=nr_valid
        )

        return InitializerOutput(
            gaussians=gaussians,
            features=None,
            depths=None
        )

    @staticmethod
    def _sample_num_gaussians(
        available: int,
        min_val: int | float | None,
        max_val: int | float | None,
    ) -> int:
        """Draw a random target point count for subsampling, capped at ``available``.

        Integer bounds are absolute counts drawn uniformly from
        ``[min_val, max_val]`` (inclusive). Float bounds are ratios in ``(0, 1]``
        of ``available``, drawn uniformly. If both bounds are ``None``,
        ``available`` is returned unchanged.
        """
        if min_val is None and max_val is None:
            return available
        assert min_val is not None and max_val is not None, \
            "Both min and max must be set together for Gaussian subsampling."
        # Exact type comparison (not isinstance) keeps e.g. bool and int distinct.
        assert type(min_val) == type(max_val), \
            "min and max must be the same type (both int or both float)."
        if isinstance(min_val, int):
            count = torch.randint(min_val, max_val + 1, (1,)).item()
        else:
            assert 0.0 < min_val <= 1.0 and 0.0 < max_val <= 1.0, \
                "Float subsampling ratios must be in (0, 1]."
            ratio = torch.empty(1).uniform_(min_val, max_val).item()
            count = int(available * ratio)
        return min(count, available)
|