ritianyu commited on
Commit
e4b4a0d
·
1 Parent(s): 3b93851
InfiniDepth/gs/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- """Lightweight Gaussian Splatting inference utilities."""
2
-
3
- from .types import Gaussians
4
- from .predictor import GSPixelAlignPredictor
5
- from .ply import export_ply
6
-
7
- __all__ = [
8
- "Gaussians",
9
- "GSPixelAlignPredictor",
10
- "export_ply",
11
- ]
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/gs/adapter.py DELETED
@@ -1,90 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
- from .projection import get_world_rays
8
- from .types import Gaussians
9
-
10
-
11
- def rgb_to_sh(rgb: torch.Tensor) -> torch.Tensor:
12
- c0 = 0.28209479177387814
13
- return (rgb - 0.5) / c0
14
-
15
-
16
- @dataclass
17
- class GaussianAdapterCfg:
18
- gaussian_scale_min: float = 1e-10
19
- gaussian_scale_max: float = 5.0
20
- sh_degree: int = 2
21
-
22
-
23
- class GaussianAdapter(nn.Module):
24
- def __init__(self, cfg: GaussianAdapterCfg) -> None:
25
- super().__init__()
26
- self.cfg = cfg
27
- self.register_buffer("sh_mask", torch.ones((self.d_sh,), dtype=torch.float32), persistent=False)
28
- for degree in range(1, self.cfg.sh_degree + 1):
29
- self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * (0.25**degree)
30
-
31
- @property
32
- def d_sh(self) -> int:
33
- return (self.cfg.sh_degree + 1) ** 2
34
-
35
- @property
36
- def d_in(self) -> int:
37
- return 7 + 3 * self.d_sh
38
-
39
- def forward(
40
- self,
41
- image: torch.Tensor,
42
- extrinsics: torch.Tensor,
43
- intrinsics: torch.Tensor,
44
- coordinates_xy: torch.Tensor,
45
- depths: torch.Tensor,
46
- opacities: torch.Tensor,
47
- raw_gaussians: torch.Tensor,
48
- ) -> Gaussians:
49
- """Build world-space gaussians from per-point raw parameters.
50
-
51
- image: [B, 3, H, W]
52
- extrinsics: [B, 4, 4] camera-to-world
53
- intrinsics: [B, 3, 3]
54
- coordinates_xy: [B, N, 2] pixel-space (x, y)
55
- depths: [B, N]
56
- opacities: [B, N]
57
- raw_gaussians: [B, N, 7 + 3*d_sh]
58
- """
59
- b, _, h, w = image.shape
60
- scales_raw, rotations_raw, sh_raw = torch.split(raw_gaussians, [3, 4, 3 * self.d_sh], dim=-1)
61
- scales = torch.clamp(
62
- F.softplus(scales_raw - 4.0),
63
- min=self.cfg.gaussian_scale_min,
64
- max=self.cfg.gaussian_scale_max,
65
- )
66
- rotations = rotations_raw / (torch.norm(rotations_raw, dim=-1, keepdim=True) + 1e-8)
67
-
68
- harmonics = sh_raw.view(b, -1, 3, self.d_sh) * self.sh_mask.view(1, 1, 1, -1)
69
-
70
- # Initialize DC term from image color sampled at gaussian centers.
71
- x = coordinates_xy[..., 0]
72
- y = coordinates_xy[..., 1]
73
- grid_x = (x / max(float(w), 1.0)) * 2.0 - 1.0
74
- grid_y = (y / max(float(h), 1.0)) * 2.0 - 1.0
75
- grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(2) # [B, N, 1, 2]
76
- sampled_rgb = F.grid_sample(image, grid, mode="bilinear", align_corners=False)
77
- sampled_rgb = sampled_rgb.squeeze(-1).permute(0, 2, 1) # [B, N, 3]
78
- harmonics[..., 0] = harmonics[..., 0] + rgb_to_sh(sampled_rgb)
79
-
80
- origins, directions = get_world_rays(coordinates_xy, extrinsics, intrinsics)
81
- means = origins + directions * depths.unsqueeze(-1)
82
-
83
- return Gaussians(
84
- means=means,
85
- harmonics=harmonics,
86
- opacities=opacities,
87
- scales=scales,
88
- rotations=rotations,
89
- covariances=None,
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/gs/ply.py DELETED
@@ -1,232 +0,0 @@
1
- from pathlib import Path
2
-
3
- import numpy as np
4
- import torch
5
- from jaxtyping import Float
6
- from plyfile import PlyData, PlyElement
7
- from torch import Tensor
8
-
9
- def _construct_attributes(d_sh: int) -> list[str]:
10
- attrs = ["x", "y", "z", "nx", "ny", "nz", "f_dc_0", "f_dc_1", "f_dc_2"]
11
- n_rest = 3 * max(d_sh - 1, 0)
12
- attrs.extend([f"f_rest_{i}" for i in range(n_rest)])
13
- attrs.extend(["opacity", "scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3"])
14
- return attrs
15
-
16
- def export_ply(
17
- means: Float[Tensor, "gaussian 3"],
18
- harmonics: Float[Tensor, "gaussian 3 d_sh"],
19
- opacities: Float[Tensor, " gaussian"],
20
- path: str | Path,
21
- scales: Float[Tensor, "gaussian 3"] | None = None,
22
- rotations: Float[Tensor, "gaussian 4"] | None = None,
23
- covariances: Float[Tensor, "gaussian 3 3"] | None = None, # Use covariances directly
24
- shift_and_scale: bool = True,
25
- save_sh_dc_only: bool = True, # Changed default to False to preserve quality
26
- center_method: str = "mean", # "mean", "median", or "bbox_center"
27
- apply_coordinate_transform: bool = True, # Apply x90° rotation for viewer compatibility
28
- focal_length_px: float | tuple[float, float] | None = None,
29
- image_shape: tuple[int, int] | None = None, # (height, width)
30
- extrinsic_matrix: np.ndarray | torch.Tensor | None = None,
31
- color_space_index: int | None = None,
32
- ):
33
- path = Path(path)
34
-
35
- # Check input consistency
36
- if covariances is None and (scales is None or rotations is None):
37
- raise ValueError("Either provide covariances or both scales and rotations")
38
-
39
- # Fast covariance to scale/rotation conversion using batch operations
40
- if covariances is not None:
41
- # Batch eigenvalue decomposition - much faster than individual decompositions
42
- eigenvalues, eigenvectors = torch.linalg.eigh(covariances)
43
- scales = torch.sqrt(torch.clamp(eigenvalues, min=1e-8))
44
-
45
- # Fast batch conversion from rotation matrices to quaternions
46
- # Using direct mathematical conversion instead of scipy loops
47
- def rotation_matrix_to_quaternion_batch(R):
48
- """Fast batch conversion from rotation matrices to quaternions"""
49
- trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
50
-
51
- # Pre-allocate quaternion tensor
52
- quat = torch.zeros(R.shape[0], 4, dtype=R.dtype, device=R.device)
53
-
54
- # Case 1: trace > 0
55
- mask1 = trace > 0
56
- if mask1.any():
57
- s = torch.sqrt(trace[mask1] + 1.0) * 2 # s = 4 * qw
58
- quat[mask1, 0] = 0.25 * s # qw
59
- quat[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s # qx
60
- quat[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s # qy
61
- quat[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s # qz
62
-
63
- # Case 2: R[0,0] > R[1,1] and R[0,0] > R[2,2]
64
- mask2 = ~mask1 & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
65
- if mask2.any():
66
- s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2
67
- quat[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s # qw
68
- quat[mask2, 1] = 0.25 * s # qx
69
- quat[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s # qy
70
- quat[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s # qz
71
-
72
- # Case 3: R[1,1] > R[2,2]
73
- mask3 = ~mask1 & ~mask2 & (R[..., 1, 1] > R[..., 2, 2])
74
- if mask3.any():
75
- s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2
76
- quat[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s # qw
77
- quat[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s # qx
78
- quat[mask3, 2] = 0.25 * s # qy
79
- quat[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s # qz
80
-
81
- # Case 4: else
82
- mask4 = ~mask1 & ~mask2 & ~mask3
83
- if mask4.any():
84
- s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2
85
- quat[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s # qw
86
- quat[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s # qx
87
- quat[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s # qy
88
- quat[mask4, 3] = 0.25 * s # qz
89
-
90
- return quat
91
-
92
- # Ensure proper rotation matrices
93
- det = torch.det(eigenvectors)
94
- eigenvectors = torch.where(det.unsqueeze(-1).unsqueeze(-1) < 0,
95
- -eigenvectors, eigenvectors)
96
-
97
- # Fast batch conversion
98
- rotations = rotation_matrix_to_quaternion_batch(eigenvectors)
99
-
100
- # Apply centering - vectorized operations
101
- if shift_and_scale:
102
- if center_method == "mean":
103
- center = means.mean(dim=0)
104
- elif center_method == "median":
105
- center = means.median(dim=0).values
106
- elif center_method == "bbox_center":
107
- center = (means.min(dim=0).values + means.max(dim=0).values) / 2
108
- else:
109
- raise ValueError(f"Unknown center_method: {center_method}")
110
- means = means - center
111
-
112
- # Fast coordinate transformation using batch operations
113
- if apply_coordinate_transform:
114
- # X-axis 90° rotation matrix
115
- rot_x = torch.tensor([
116
- [1, 0, 0],
117
- [0, 0, -1],
118
- [0, 1, 0]
119
- ], dtype=means.dtype, device=means.device)
120
-
121
- # Apply to positions - batch matrix multiplication
122
- means = means @ rot_x.T
123
-
124
- # Apply to rotations - batch quaternion operations
125
- transform_quat = torch.tensor([0.7071068, 0.7071068, 0.0, 0.0],
126
- dtype=rotations.dtype, device=rotations.device) # 90° around X
127
-
128
- # Batch quaternion multiplication
129
- w1, x1, y1, z1 = transform_quat[0], transform_quat[1], transform_quat[2], transform_quat[3]
130
- w2, x2, y2, z2 = rotations[:, 0], rotations[:, 1], rotations[:, 2], rotations[:, 3]
131
-
132
- rotations = torch.stack([
133
- w1*w2 - x1*x2 - y1*y2 - z1*z2, # w
134
- w1*x2 + x1*w2 + y1*z2 - z1*y2, # x
135
- w1*y2 - x1*z2 + y1*w2 + z1*x2, # y
136
- w1*z2 + x1*y2 - y1*x2 + z1*w2 # z
137
- ], dim=1)
138
-
139
- # Convert to numpy for PLY writing - single conversion
140
- means_np = means.detach().cpu().numpy()
141
- scales_np = scales.detach().cpu().numpy()
142
- rotations_np = rotations.detach().cpu().numpy()
143
- opacities_np = opacities.detach().cpu().numpy()
144
- harmonics_np = harmonics.detach().cpu().numpy()
145
-
146
- # Process harmonics
147
- f_dc = harmonics_np[..., 0]
148
- f_rest = harmonics_np[..., 1:].reshape(harmonics_np.shape[0], -1)
149
-
150
- d_sh = harmonics_np.shape[-1]
151
- dtype_full = [
152
- (attribute, "f4")
153
- for attribute in _construct_attributes(1 if save_sh_dc_only else d_sh)
154
- ]
155
- elements = np.empty(means_np.shape[0], dtype=dtype_full)
156
-
157
- # Build attributes list
158
- attributes = [
159
- means_np,
160
- np.zeros_like(means_np), # normals
161
- f_dc,
162
- ]
163
-
164
- if not save_sh_dc_only:
165
- attributes.append(f_rest)
166
-
167
- # Apply inverse sigmoid to opacity for storage (viewer will apply sigmoid when loading)
168
- # logit(opacity) = log(opacity / (1 - opacity))
169
- opacities_clamped = np.clip(opacities_np, 1e-6, 1 - 1e-6) # Clamp to avoid log(0) or log(inf)
170
- opacities_logit = np.log(opacities_clamped / (1 - opacities_clamped))
171
-
172
- attributes.extend([
173
- opacities_logit.reshape(-1, 1),
174
- np.log(scales_np),
175
- rotations_np
176
- ])
177
-
178
- attributes = np.concatenate(attributes, axis=1)
179
- elements[:] = list(map(tuple, attributes))
180
- path.parent.mkdir(exist_ok=True, parents=True)
181
- ply_elements = [PlyElement.describe(elements, "vertex")]
182
-
183
- if focal_length_px is not None and image_shape is not None:
184
- image_height, image_width = image_shape
185
- if isinstance(focal_length_px, tuple):
186
- fx, fy = float(focal_length_px[0]), float(focal_length_px[1])
187
- else:
188
- fx = fy = float(focal_length_px)
189
-
190
- dtype_image_size = [("image_size", "u4")]
191
- image_size_array = np.empty(2, dtype=dtype_image_size)
192
- image_size_array[:] = np.array([image_width, image_height], dtype=np.uint32)
193
- ply_elements.append(PlyElement.describe(image_size_array, "image_size"))
194
-
195
- dtype_intrinsic = [("intrinsic", "f4")]
196
- intrinsic_array = np.empty(9, dtype=dtype_intrinsic)
197
- intrinsic = np.array(
198
- [
199
- fx,
200
- 0.0,
201
- image_width * 0.5,
202
- 0.0,
203
- fy,
204
- image_height * 0.5,
205
- 0.0,
206
- 0.0,
207
- 1.0,
208
- ],
209
- dtype=np.float32,
210
- )
211
- intrinsic_array[:] = intrinsic.flatten()
212
- ply_elements.append(PlyElement.describe(intrinsic_array, "intrinsic"))
213
-
214
- dtype_extrinsic = [("extrinsic", "f4")]
215
- extrinsic_array = np.empty(16, dtype=dtype_extrinsic)
216
- if extrinsic_matrix is None:
217
- extrinsic_np = np.eye(4, dtype=np.float32)
218
- elif torch.is_tensor(extrinsic_matrix):
219
- extrinsic_np = extrinsic_matrix.detach().cpu().numpy().astype(np.float32)
220
- else:
221
- extrinsic_np = np.asarray(extrinsic_matrix, dtype=np.float32)
222
- if extrinsic_np.shape != (4, 4):
223
- raise ValueError(f"extrinsic_matrix must have shape (4,4), got {extrinsic_np.shape}")
224
- extrinsic_array[:] = extrinsic_np.flatten()
225
- ply_elements.append(PlyElement.describe(extrinsic_array, "extrinsic"))
226
-
227
- dtype_color_space = [("color_space", "u1")]
228
- color_space_array = np.empty(1, dtype=dtype_color_space)
229
- color_space_array[:] = np.array([1 if color_space_index is None else color_space_index], dtype=np.uint8)
230
- ply_elements.append(PlyElement.describe(color_space_array, "color_space"))
231
-
232
- PlyData(ply_elements).write(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/gs/predictor.py DELETED
@@ -1,139 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- from torch import nn
5
-
6
- from .adapter import GaussianAdapter, GaussianAdapterCfg
7
- from .projection import sample_image_grid
8
- from .types import Gaussians
9
-
10
-
11
- @dataclass
12
- class GSPredictorCfg:
13
- rgb_feature_dim: int = 64
14
- depth_feature_dim: int = 32
15
- dino_reduced_dim: int = 128
16
- gaussian_regressor_channels: int = 64
17
- num_surfaces: int = 1
18
- gaussian_scale_min: float = 1e-10
19
- gaussian_scale_max: float = 5.0
20
- sh_degree: int = 2
21
-
22
-
23
- class GSPixelAlignPredictor(nn.Module):
24
- def __init__(self, dino_feature_dim: int = 1024, cfg: GSPredictorCfg | None = None) -> None:
25
- super().__init__()
26
- self.cfg = cfg or GSPredictorCfg()
27
- cfg = self.cfg
28
-
29
- self.rgb_encoder = nn.Sequential(
30
- nn.Conv2d(3, 32, 3, 1, 1),
31
- nn.GELU(),
32
- nn.Conv2d(32, cfg.rgb_feature_dim, 3, 1, 1),
33
- nn.GELU(),
34
- )
35
- self.depth_encoder = nn.Sequential(
36
- nn.Conv2d(1, 16, 3, 1, 1),
37
- nn.GELU(),
38
- nn.Conv2d(16, cfg.depth_feature_dim, 3, 1, 1),
39
- nn.GELU(),
40
- )
41
- self.dino_projector = nn.Sequential(
42
- nn.Conv2d(dino_feature_dim, 256, 1),
43
- nn.GELU(),
44
- nn.Conv2d(256, cfg.dino_reduced_dim, 1),
45
- )
46
-
47
- reg_in = cfg.rgb_feature_dim + cfg.depth_feature_dim + cfg.dino_reduced_dim
48
- self.gaussian_regressor = nn.Sequential(
49
- nn.Conv2d(reg_in, cfg.gaussian_regressor_channels, 3, 1, 1),
50
- nn.GELU(),
51
- nn.Conv2d(cfg.gaussian_regressor_channels, cfg.gaussian_regressor_channels, 3, 1, 1),
52
- )
53
-
54
- self.gaussian_adapter = GaussianAdapter(
55
- GaussianAdapterCfg(
56
- gaussian_scale_min=cfg.gaussian_scale_min,
57
- gaussian_scale_max=cfg.gaussian_scale_max,
58
- sh_degree=cfg.sh_degree,
59
- )
60
- )
61
-
62
- num_gaussian_parameters = self.gaussian_adapter.d_in + 2 + 1
63
- head_in = cfg.gaussian_regressor_channels + cfg.rgb_feature_dim + cfg.dino_reduced_dim
64
- self.gaussian_head = nn.Sequential(
65
- nn.Conv2d(head_in, num_gaussian_parameters, 3, 1, 1, padding_mode="replicate"),
66
- nn.GELU(),
67
- nn.Conv2d(num_gaussian_parameters, num_gaussian_parameters, 3, 1, 1, padding_mode="replicate"),
68
- )
69
-
70
- @torch.no_grad()
71
- def load_from_infinisplat_checkpoint(self, checkpoint_path: str) -> None:
72
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
73
- state_dict = checkpoint.get("state_dict", checkpoint)
74
-
75
- own_sd = self.state_dict()
76
- load_sd = {}
77
- for k, _ in own_sd.items():
78
- prefixed = f"encoder.{k}"
79
- if prefixed in state_dict and state_dict[prefixed].shape == own_sd[k].shape:
80
- load_sd[k] = state_dict[prefixed]
81
- self.load_state_dict(load_sd, strict=False)
82
-
83
- def _tokens_to_feature_map(self, dino_tokens: torch.Tensor, h: int, w: int) -> torch.Tensor:
84
- b, n_all, c = dino_tokens.shape
85
- patch_h = h // 16
86
- patch_w = w // 16
87
- n_patch = patch_h * patch_w
88
- if n_all < n_patch:
89
- raise ValueError(f"Invalid token count: got {n_all}, expected at least {n_patch}")
90
- n_reg = n_all - n_patch
91
- patch_tokens = dino_tokens[:, n_reg:, :] # [B, patch_h*patch_w, C]
92
- patch_tokens = patch_tokens.reshape(b, patch_h, patch_w, c).permute(0, 3, 1, 2)
93
- return torch.nn.functional.interpolate(
94
- patch_tokens, size=(h, w), mode="bilinear", align_corners=False
95
- )
96
-
97
- def forward(
98
- self,
99
- image: torch.Tensor,
100
- depthmap: torch.Tensor,
101
- dino_tokens: torch.Tensor,
102
- intrinsics: torch.Tensor,
103
- extrinsics: torch.Tensor,
104
- ) -> Gaussians:
105
- b, _, h, w = image.shape
106
- dino_map = self._tokens_to_feature_map(dino_tokens, h, w)
107
-
108
- rgb_feat = self.rgb_encoder(image)
109
- depth_feat = self.depth_encoder(depthmap)
110
- dino_feat = self.dino_projector(dino_map)
111
-
112
- reg_input = torch.cat([rgb_feat, depth_feat, dino_feat], dim=1)
113
- reg_feat = self.gaussian_regressor(reg_input)
114
- head_input = torch.cat([reg_feat, rgb_feat, dino_feat], dim=1)
115
- raw = self.gaussian_head(head_input) # [B, Cg, H, W]
116
-
117
- raw = raw.permute(0, 2, 3, 1).reshape(b, h * w, -1) # [B, HW, Cg]
118
- opacities = torch.sigmoid(raw[..., :1]).squeeze(-1) # [B, HW]
119
- gaussian_core = raw[..., 1:] # [B, HW, Cg-1]
120
-
121
- # One surface per pixel in this lightweight integration.
122
- offset_xy = torch.sigmoid(gaussian_core[..., :2]) # [B, HW, 2], in [0,1]
123
- raw_gaussians = gaussian_core[..., 2:] # [B, HW, 7+3*d_sh]
124
-
125
- base = sample_image_grid(h, w, image.device).unsqueeze(0).expand(b, -1, -1)
126
- coords = base + (offset_xy - 0.5)
127
- coords[..., 0].clamp_(0.0, float(w - 1))
128
- coords[..., 1].clamp_(0.0, float(h - 1))
129
-
130
- depths = depthmap[:, 0].reshape(b, -1)
131
- return self.gaussian_adapter(
132
- image=image,
133
- extrinsics=extrinsics,
134
- intrinsics=intrinsics,
135
- coordinates_xy=coords,
136
- depths=depths,
137
- opacities=opacities,
138
- raw_gaussians=raw_gaussians,
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/gs/projection.py DELETED
@@ -1,53 +0,0 @@
1
- import torch
2
-
3
-
4
- def homogenize_points(points: torch.Tensor) -> torch.Tensor:
5
- return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
6
-
7
-
8
- def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor:
9
- return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1)
10
-
11
-
12
- def transform_cam2world(homogeneous: torch.Tensor, extrinsics: torch.Tensor) -> torch.Tensor:
13
- return torch.matmul(extrinsics, homogeneous.unsqueeze(-1)).squeeze(-1)
14
-
15
-
16
- def unproject(coordinates_xy: torch.Tensor, z: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
17
- """Unproject pixel-space xy to camera space using z depth.
18
-
19
- coordinates_xy: [B, N, 2] in pixel coordinates (x, y)
20
- z: [B, N]
21
- intrinsics: [B, 3, 3] in pixel units
22
- """
23
- coordinates_h = homogenize_points(coordinates_xy) # [B, N, 3]
24
- intr_inv = torch.linalg.inv(intrinsics) # [B, 3, 3]
25
- rays = torch.matmul(intr_inv.unsqueeze(1), coordinates_h.unsqueeze(-1)).squeeze(-1)
26
- return rays * z.unsqueeze(-1)
27
-
28
-
29
- def get_world_rays(
30
- coordinates_xy: torch.Tensor,
31
- extrinsics: torch.Tensor,
32
- intrinsics: torch.Tensor,
33
- ) -> tuple[torch.Tensor, torch.Tensor]:
34
- """Return world-space ray origins and directions.
35
-
36
- coordinates_xy: [B, N, 2] in pixel coordinates (x, y)
37
- extrinsics: [B, 4, 4] camera-to-world
38
- intrinsics: [B, 3, 3] pixel intrinsics
39
- """
40
- ones = torch.ones_like(coordinates_xy[..., 0])
41
- directions_cam = unproject(coordinates_xy, ones, intrinsics)
42
- directions_cam = directions_cam / torch.clamp(directions_cam[..., 2:], min=1e-6)
43
- directions_world = transform_cam2world(homogenize_vectors(directions_cam), extrinsics)[..., :3]
44
- origins_world = extrinsics[:, None, :3, 3].expand_as(directions_world)
45
- return origins_world, directions_world
46
-
47
-
48
- def sample_image_grid(h: int, w: int, device: torch.device) -> torch.Tensor:
49
- """Return pixel center coordinates with shape [H*W, 2], order (x, y)."""
50
- ys = torch.arange(h, device=device, dtype=torch.float32) + 0.5
51
- xs = torch.arange(w, device=device, dtype=torch.float32) + 0.5
52
- grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
53
- return torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/gs/types.py DELETED
@@ -1,14 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
3
-
4
- import torch
5
-
6
-
7
- @dataclass
8
- class Gaussians:
9
- means: torch.Tensor # [B, N, 3]
10
- harmonics: torch.Tensor # [B, N, 3, d_sh]
11
- opacities: torch.Tensor # [B, N]
12
- scales: torch.Tensor # [B, N, 3]
13
- rotations: torch.Tensor # [B, N, 4]
14
- covariances: Optional[torch.Tensor] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/utils/depth_video_utils.py DELETED
@@ -1,250 +0,0 @@
1
- import os
2
- from typing import Optional
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
-
8
- from .inference_utils import default_dir_by_input_file, default_video_file_by_input
9
- from .io_utils import filter_depth_noise_numpy
10
- from .io_utils import save_sampled_point_clouds
11
- from .moge_utils import estimate_metric_depth_with_moge2
12
- from .sampling_utils import SAMPLING_METHODS
13
- from .vis_utils import colorize_depth_maps
14
-
15
-
16
- def prepare_rgb_frame(
17
- frame_bgr: np.ndarray,
18
- input_size: tuple[int, int],
19
- device: torch.device,
20
- ) -> tuple[torch.Tensor, torch.Tensor, tuple[int, int]]:
21
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
22
- org_h, org_w = frame_rgb.shape[:2]
23
- org_img = torch.from_numpy(frame_rgb).permute(2, 0, 1).unsqueeze(0).float() / 255.0
24
-
25
- resized = cv2.resize(frame_rgb, input_size[::-1], interpolation=cv2.INTER_AREA)
26
- image = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).float() / 255.0
27
- image = image.to(device)
28
- return org_img, image, (org_h, org_w)
29
-
30
-
31
- def depth_frame_to_metric_depth(depth_frame: np.ndarray, depth_video_scale: float) -> np.ndarray:
32
- if depth_frame.ndim == 3:
33
- # Assume grayscale/depth-like content is stored in channels.
34
- depth_raw = depth_frame[:, :, 0]
35
- else:
36
- depth_raw = depth_frame
37
- return depth_raw.astype(np.float32) / max(depth_video_scale, 1e-8)
38
-
39
-
40
- def sample_sparse_prompt(
41
- depth: np.ndarray,
42
- depth_mask: np.ndarray,
43
- num_samples: int,
44
- ) -> np.ndarray:
45
- valid_depth = depth * depth_mask
46
- if (valid_depth > 0.1).sum() <= num_samples:
47
- return valid_depth
48
-
49
- flat = valid_depth.reshape(-1)
50
- nonzero_index = np.array(list(np.nonzero(flat > 0.1))).squeeze()
51
- index = np.random.permutation(nonzero_index)[:num_samples]
52
- sample_mask = np.ones_like(flat)
53
- sample_mask[index] = 0.0
54
- flat[sample_mask.astype(bool)] = 0.0
55
- return flat.reshape(valid_depth.shape)
56
-
57
-
58
- def prepare_prompt_from_depth_frame(
59
- depth_frame: np.ndarray,
60
- input_size: tuple[int, int],
61
- depth_video_scale: float,
62
- num_samples: int,
63
- min_prompt: float,
64
- max_prompt: float,
65
- enable_noise_filter: bool,
66
- filter_std_threshold: float,
67
- filter_median_threshold: float,
68
- filter_gradient_threshold: float,
69
- filter_min_neighbors: int,
70
- device: torch.device,
71
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
72
- depth = depth_frame_to_metric_depth(depth_frame, depth_video_scale)
73
- depth = cv2.resize(depth, input_size[::-1], interpolation=cv2.INTER_NEAREST)
74
-
75
- if enable_noise_filter:
76
- initial_mask = ((depth > min_prompt) & (depth < max_prompt)).astype(np.float32)
77
- depth, depth_mask = filter_depth_noise_numpy(
78
- depth=depth,
79
- depth_mask=initial_mask,
80
- std_threshold=filter_std_threshold,
81
- median_threshold=filter_median_threshold,
82
- gradient_threshold=filter_gradient_threshold,
83
- min_neighbors=filter_min_neighbors,
84
- )
85
- else:
86
- depth_mask = ((depth > min_prompt) & (depth < max_prompt)).astype(np.float32)
87
-
88
- prompt_depth = sample_sparse_prompt(depth, depth_mask, num_samples=num_samples)
89
-
90
- gt_depth_t = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0).float().to(device)
91
- prompt_depth_t = torch.from_numpy(prompt_depth).unsqueeze(0).unsqueeze(0).float().to(device)
92
- depth_mask_t = torch.from_numpy(depth_mask).unsqueeze(0).unsqueeze(0).float().to(device)
93
- return gt_depth_t, prompt_depth_t, depth_mask_t
94
-
95
-
96
- def ensure_depth_map(pred_depth: torch.Tensor, h_sample: int, w_sample: int) -> torch.Tensor:
97
- if pred_depth.ndim == 4 and pred_depth.shape[-2:] == (h_sample, w_sample):
98
- return pred_depth
99
-
100
- if pred_depth.ndim == 3:
101
- b, d1, d2 = pred_depth.shape
102
- if d1 == h_sample * w_sample and d2 == 1:
103
- return pred_depth.permute(0, 2, 1).reshape(b, 1, h_sample, w_sample)
104
- if d1 == 1 and d2 == h_sample * w_sample:
105
- return pred_depth.reshape(b, 1, h_sample, w_sample)
106
- if d1 == h_sample and d2 == w_sample:
107
- return pred_depth.unsqueeze(1)
108
-
109
- raise ValueError(
110
- f"Unsupported pred_depth shape: {tuple(pred_depth.shape)} for target ({h_sample}, {w_sample})"
111
- )
112
-
113
-
114
- def build_query_coords(
115
- h_sample: int,
116
- w_sample: int,
117
- device: torch.device,
118
- ) -> torch.Tensor:
119
- return SAMPLING_METHODS["2d_uniform"]((h_sample, w_sample)).unsqueeze(0).to(device)
120
-
121
-
122
- def resolve_video_output_paths(
123
- input_video_path: str,
124
- depth_output_video_path: Optional[str],
125
- pcd_output_dir: Optional[str],
126
- save_depth_video: bool,
127
- save_pcd: bool,
128
- ) -> tuple[str, str]:
129
- resolved_depth_video_path = depth_output_video_path or default_video_file_by_input(
130
- input_video_path,
131
- "pred_depth_video",
132
- "pred_depth.mp4",
133
- )
134
- resolved_pcd_output_dir = pcd_output_dir or default_dir_by_input_file(input_video_path, "pred_pcd_frames")
135
-
136
- if save_depth_video:
137
- os.makedirs(os.path.dirname(resolved_depth_video_path) or ".", exist_ok=True)
138
- if save_pcd:
139
- os.makedirs(resolved_pcd_output_dir, exist_ok=True)
140
- return resolved_depth_video_path, resolved_pcd_output_dir
141
-
142
-
143
- def prepare_video_prompt(
144
- depth_frame: Optional[np.ndarray],
145
- image: torch.Tensor,
146
- input_size: tuple[int, int],
147
- depth_video_scale: float,
148
- num_samples: int,
149
- min_prompt: float,
150
- max_prompt: float,
151
- enable_noise_filter: bool,
152
- filter_std_threshold: float,
153
- filter_median_threshold: float,
154
- filter_gradient_threshold: float,
155
- filter_min_neighbors: int,
156
- moge2_pretrained: str,
157
- moge2_use_fp16: bool,
158
- moge2_resolution_level: int,
159
- moge2_num_tokens: Optional[int],
160
- moge2_threshold: float,
161
- device: torch.device,
162
- ) -> tuple[torch.Tensor, torch.Tensor]:
163
- if depth_frame is not None:
164
- _, prompt_depth, depth_mask = prepare_prompt_from_depth_frame(
165
- depth_frame=depth_frame,
166
- input_size=input_size,
167
- depth_video_scale=depth_video_scale,
168
- num_samples=num_samples,
169
- min_prompt=min_prompt,
170
- max_prompt=max_prompt,
171
- enable_noise_filter=enable_noise_filter,
172
- filter_std_threshold=filter_std_threshold,
173
- filter_median_threshold=filter_median_threshold,
174
- filter_gradient_threshold=filter_gradient_threshold,
175
- filter_min_neighbors=filter_min_neighbors,
176
- device=device,
177
- )
178
- return prompt_depth, depth_mask
179
-
180
- pred_depth_prompt, depth_mask = estimate_metric_depth_with_moge2(
181
- image=image,
182
- pretrained_model_name_or_path=moge2_pretrained,
183
- use_fp16=moge2_use_fp16,
184
- resolution_level=moge2_resolution_level,
185
- num_tokens=moge2_num_tokens,
186
- threshold=moge2_threshold,
187
- )
188
- return pred_depth_prompt.to(device), depth_mask.to(device)
189
-
190
-
191
- def write_depth_video_frame(
192
- pred_depthmap: torch.Tensor,
193
- depth_writer: Optional[cv2.VideoWriter],
194
- writer_size: Optional[tuple[int, int]],
195
- depth_output_video_path: str,
196
- final_fps: float,
197
- ) -> tuple[cv2.VideoWriter, tuple[int, int]]:
198
- depth_np = pred_depthmap[0, 0].detach().cpu().numpy()
199
- valid = np.isfinite(depth_np) & (depth_np > 0)
200
- if np.any(valid):
201
- depth_min, depth_max = np.percentile(depth_np[valid], [1.0, 99.0]).tolist()
202
- if depth_max <= depth_min:
203
- depth_max = depth_min + 1e-6
204
- else:
205
- depth_min, depth_max = 0.0, 1.0
206
-
207
- vis_depth = colorize_depth_maps(depth_np, min_depth=depth_min, max_depth=depth_max, cmap="Spectral")
208
- vis_bgr = cv2.cvtColor(vis_depth, cv2.COLOR_RGB2BGR)
209
-
210
- if depth_writer is None:
211
- writer_size = (vis_bgr.shape[1], vis_bgr.shape[0])
212
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
213
- depth_writer = cv2.VideoWriter(depth_output_video_path, fourcc, float(final_fps), writer_size)
214
- if not depth_writer.isOpened():
215
- raise RuntimeError(f"Failed to open depth video writer: {depth_output_video_path}")
216
-
217
- if writer_size is None:
218
- raise RuntimeError("writer_size should not be None after depth_writer initialization.")
219
-
220
- if (vis_bgr.shape[1], vis_bgr.shape[0]) != writer_size:
221
- vis_bgr = cv2.resize(vis_bgr, writer_size, interpolation=cv2.INTER_AREA)
222
- depth_writer.write(vis_bgr)
223
- return depth_writer, writer_size
224
-
225
-
226
- def save_video_frame_point_cloud(
227
- query_2d_uniform_coord: torch.Tensor,
228
- pred_2d_uniform_depth: torch.Tensor,
229
- image: torch.Tensor,
230
- fx: float,
231
- fy: float,
232
- cx: float,
233
- cy: float,
234
- pcd_output_dir: str,
235
- frame_id: int,
236
- enable_filter_flying_points: bool,
237
- ) -> str:
238
- pcd_save_path = os.path.join(pcd_output_dir, f"frame_{frame_id:06d}.ply")
239
- save_sampled_point_clouds(
240
- sampled_coord=query_2d_uniform_coord.squeeze(0).detach().cpu(),
241
- sampled_depth=pred_2d_uniform_depth.squeeze(0).squeeze(-1).detach().cpu(),
242
- rgb_image=image.squeeze(0).detach().cpu(),
243
- fx=fx,
244
- fy=fy,
245
- cx=cx,
246
- cy=cy,
247
- output_path=pcd_save_path,
248
- filter_flying_points=enable_filter_flying_points,
249
- )
250
- return pcd_save_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/utils/gs_utils.py DELETED
@@ -1,289 +0,0 @@
1
- import math
2
- import os
3
- from typing import Optional
4
-
5
- import imageio.v2 as imageio
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
-
10
- from InfiniDepth.gs import Gaussians
11
- from InfiniDepth.gs.projection import homogenize_points, transform_cam2world, unproject
12
-
13
-
14
- def _safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
15
- return v / torch.clamp(torch.norm(v), min=eps)
16
-
17
-
18
- def _look_at_c2w(position: torch.Tensor, target: torch.Tensor, up_hint: torch.Tensor) -> torch.Tensor:
19
- forward = _safe_normalize(target - position)
20
- # Camera basis is stored as [right, up, forward]. Using cross(forward, up)
21
- # flips the x-axis and produces a horizontally mirrored render. Keep the
22
- # original up hint and derive a right-handed basis instead.
23
- right = torch.cross(up_hint, forward, dim=0)
24
- if torch.norm(right) < 1e-6:
25
- right = torch.cross(
26
- torch.tensor([1.0, 0.0, 0.0], device=position.device, dtype=position.dtype),
27
- forward,
28
- dim=0,
29
- )
30
- right = _safe_normalize(right)
31
- up = _safe_normalize(torch.cross(forward, right, dim=0))
32
-
33
- c2w = torch.eye(4, device=position.device, dtype=position.dtype)
34
- c2w[:3, 0] = right
35
- c2w[:3, 1] = up
36
- c2w[:3, 2] = forward
37
- c2w[:3, 3] = position
38
- return c2w
39
-
40
-
41
- def _build_orbit_poses(
42
- base_c2w: torch.Tensor,
43
- target: torch.Tensor,
44
- num_frames: int,
45
- radius: float,
46
- vertical: float,
47
- forward_amp: float,
48
- ) -> list[torch.Tensor]:
49
- base_pos = base_c2w[:3, 3]
50
- right = base_c2w[:3, 0]
51
- up = base_c2w[:3, 1]
52
- forward = base_c2w[:3, 2]
53
-
54
- poses: list[torch.Tensor] = []
55
- n = max(2, int(num_frames))
56
- for i in range(n):
57
- theta = 2.0 * math.pi * float(i) / float(n)
58
- offset = (
59
- right * (radius * math.sin(theta))
60
- + up * (vertical * math.sin(2.0 * theta))
61
- + forward * (forward_amp * 0.5 * (1.0 - math.cos(theta)))
62
- )
63
- pos = base_pos + offset
64
- poses.append(_look_at_c2w(pos, target, up))
65
- return poses
66
-
67
-
68
- def _build_swing_poses(
69
- base_c2w: torch.Tensor,
70
- num_frames: int,
71
- radius: float,
72
- forward_amp: float,
73
- ) -> list[torch.Tensor]:
74
- base_pos = base_c2w[:3, 3]
75
- right = base_c2w[:3, 0]
76
- forward = base_c2w[:3, 2]
77
-
78
- key_offsets = [
79
- torch.zeros(3, device=base_pos.device, dtype=base_pos.dtype),
80
- -right * radius,
81
- right * radius,
82
- forward * forward_amp,
83
- torch.zeros(3, device=base_pos.device, dtype=base_pos.dtype),
84
- ]
85
-
86
- poses: list[torch.Tensor] = []
87
- seg_frames = max(1, int(num_frames) // (len(key_offsets) - 1))
88
- for seg in range(len(key_offsets) - 1):
89
- p0 = base_pos + key_offsets[seg]
90
- p1 = base_pos + key_offsets[seg + 1]
91
- for i in range(seg_frames):
92
- alpha = 1.0 if seg_frames == 1 else float(i) / float(seg_frames - 1)
93
- pos = (1.0 - alpha) * p0 + alpha * p1
94
- pose = base_c2w.clone()
95
- pose[:3, 3] = pos
96
- if seg > 0 and i == 0:
97
- continue
98
- poses.append(pose)
99
- return poses
100
-
101
-
102
- def _scale_intrinsics_for_render(
103
- intrinsics: torch.Tensor,
104
- src_h: int,
105
- src_w: int,
106
- dst_h: int,
107
- dst_w: int,
108
- ) -> torch.Tensor:
109
- scaled = intrinsics.clone()
110
- sx = float(dst_w) / float(src_w)
111
- sy = float(dst_h) / float(src_h)
112
- scaled[0, 0] *= sx
113
- scaled[1, 1] *= sy
114
- scaled[0, 2] *= sx
115
- scaled[1, 2] *= sy
116
- return scaled
117
-
118
-
119
- def _render_gaussian_frame(
120
- rasterization_fn,
121
- means: torch.Tensor,
122
- harmonics: torch.Tensor,
123
- opacities: torch.Tensor,
124
- scales: torch.Tensor,
125
- rotations: torch.Tensor,
126
- c2w: torch.Tensor,
127
- intrinsics: torch.Tensor,
128
- render_h: int,
129
- render_w: int,
130
- bg_color: tuple[float, float, float],
131
- ) -> np.ndarray:
132
- xyzs = means.unsqueeze(0).float() # [1, N, 3]
133
- opacitys = opacities.unsqueeze(0).float() # [1, N]
134
- rotations_b = rotations.unsqueeze(0).float() # [1, N, 4]
135
- scales_b = scales.unsqueeze(0).float() # [1, N, 3]
136
-
137
- # [N, 3, d_sh] -> [1, N, d_sh, 3]
138
- features = harmonics.unsqueeze(0).permute(0, 1, 3, 2).contiguous().float()
139
- d_sh = features.shape[-2]
140
- sh_degree = int(round(math.sqrt(float(d_sh)) - 1.0))
141
-
142
- w2c = torch.linalg.inv(c2w).unsqueeze(0).unsqueeze(0).float() # [1, 1, 4, 4]
143
- Ks = intrinsics.unsqueeze(0).unsqueeze(0).float() # [1, 1, 3, 3]
144
- backgrounds = torch.tensor(bg_color, dtype=torch.float32, device=xyzs.device).view(1, 1, 3)
145
-
146
- rendering, _, _ = rasterization_fn(
147
- xyzs,
148
- rotations_b,
149
- scales_b,
150
- opacitys,
151
- features,
152
- w2c,
153
- Ks,
154
- render_w,
155
- render_h,
156
- sh_degree=sh_degree,
157
- render_mode="RGB+D",
158
- packed=False,
159
- backgrounds=backgrounds,
160
- covars=None,
161
- eps2d=1e-8,
162
- )
163
-
164
- rgb = rendering[0, 0, :, :, :3].clamp(0.0, 1.0)
165
- return (rgb * 255.0).to(torch.uint8).cpu().numpy()
166
-
167
-
168
- def _render_novel_video(
169
- means: torch.Tensor,
170
- harmonics: torch.Tensor,
171
- opacities: torch.Tensor,
172
- scales: torch.Tensor,
173
- rotations: torch.Tensor,
174
- base_c2w: torch.Tensor,
175
- intrinsics: torch.Tensor,
176
- render_h: int,
177
- render_w: int,
178
- video_path: str,
179
- trajectory: str,
180
- num_frames: int,
181
- fps: int,
182
- radius: float,
183
- vertical: float,
184
- forward_amp: float,
185
- bg_color: tuple[float, float, float],
186
- ) -> None:
187
- try:
188
- from gsplat import rasterization as rasterization_fn
189
- except ImportError as exc:
190
- raise RuntimeError("Novel-view rendering requires gsplat. Please install gsplat first.") from exc
191
-
192
- target = means.mean(dim=0)
193
- if trajectory == "swing":
194
- poses = _build_swing_poses(base_c2w, num_frames, radius, forward_amp)
195
- else:
196
- poses = _build_orbit_poses(base_c2w, target, num_frames, radius, vertical, forward_amp)
197
-
198
- video_dir = os.path.dirname(video_path)
199
- if video_dir:
200
- os.makedirs(video_dir, exist_ok=True)
201
-
202
- try:
203
- with imageio.get_writer(
204
- video_path,
205
- fps=float(max(1, fps)),
206
- codec="libx264",
207
- macro_block_size=1,
208
- ) as writer:
209
- for pose in poses:
210
- frame_rgb = _render_gaussian_frame(
211
- rasterization_fn=rasterization_fn,
212
- means=means,
213
- harmonics=harmonics,
214
- opacities=opacities,
215
- scales=scales,
216
- rotations=rotations,
217
- c2w=pose,
218
- intrinsics=intrinsics,
219
- render_h=render_h,
220
- render_w=render_w,
221
- bg_color=bg_color,
222
- )
223
- writer.append_data(frame_rgb)
224
- except Exception as exc:
225
- raise RuntimeError(f"Failed to write video with imageio: {video_path}") from exc
226
-
227
-
228
- def _build_sparse_uniform_gaussians(
229
- dense_gaussians,
230
- query_3d_uniform_coord: torch.Tensor,
231
- pred_depth_3d: torch.Tensor,
232
- intrinsics: torch.Tensor,
233
- extrinsics: torch.Tensor,
234
- h: int,
235
- w: int,
236
- ) -> Gaussians:
237
- """Convert dense pixel gaussians to sparse 3d-uniform gaussians.
238
- """
239
- if dense_gaussians.means.shape[0] != 1:
240
- raise ValueError("Current strict-aligned sparse interpolation only supports batch size 1.")
241
-
242
- sparse_coords_normalized = query_3d_uniform_coord[0] # [N,2], [y,x]
243
- sparse_depths = pred_depth_3d[0] # [N,1]
244
-
245
- # Convert normalized coordinates to pixel coordinates
246
- p_y = ((sparse_coords_normalized[:, 0] + 1.0) * (h / 2.0)) - 0.5
247
- p_x = ((sparse_coords_normalized[:, 1] + 1.0) * (w / 2.0)) - 0.5
248
- xy_coords = torch.stack([p_x, p_y], dim=-1) # [N,2], [x,y]
249
-
250
- depth_values = sparse_depths.squeeze(-1)
251
- camera_points = unproject(xy_coords.unsqueeze(0), depth_values.unsqueeze(0), intrinsics)[0]
252
- camera_points_hom = homogenize_points(camera_points)
253
- world_points = transform_cam2world(camera_points_hom.unsqueeze(0), extrinsics)[0]
254
- sparse_pts_world = world_points[..., :3] # [N,3]
255
-
256
- grid = sparse_coords_normalized[:, [1, 0]].unsqueeze(0).unsqueeze(0) # [1,1,N,2]
257
-
258
- def sample_attribute(attr):
259
- if attr.dim() == 2:
260
- attr_spatial = attr.view(1, 1, h, w)
261
- sampled = F.grid_sample(attr_spatial, grid, mode="bilinear", align_corners=False)
262
- return sampled.squeeze(0).squeeze(0)
263
- if attr.dim() == 3:
264
- d = attr.shape[-1]
265
- attr_spatial = attr.view(1, h, w, d).permute(0, 3, 1, 2)
266
- sampled = F.grid_sample(attr_spatial, grid, mode="bilinear", align_corners=False)
267
- return sampled.squeeze(2).permute(0, 2, 1)
268
- if attr.dim() == 4:
269
- d1, d2 = attr.shape[-2:]
270
- attr_flat = attr.view(1, h, w, d1 * d2).permute(0, 3, 1, 2)
271
- sampled = F.grid_sample(attr_flat, grid, mode="bilinear", align_corners=False)
272
- return sampled.squeeze(2).permute(0, 2, 1).view(1, -1, d1, d2)
273
- raise ValueError(f"Unsupported attribute dimension: {attr.dim()}")
274
-
275
- sparse_harmonics = sample_attribute(dense_gaussians.harmonics)
276
- sparse_opacities = sample_attribute(dense_gaussians.opacities)
277
- sparse_scales = sample_attribute(dense_gaussians.scales)
278
- sparse_rotations = sample_attribute(dense_gaussians.rotations)
279
- sparse_rotations = sparse_rotations / (torch.norm(sparse_rotations, dim=-1, keepdim=True) + 1e-8)
280
-
281
- return Gaussians(
282
- means=sparse_pts_world.unsqueeze(0),
283
- covariances=None,
284
- harmonics=sparse_harmonics,
285
- opacities=sparse_opacities,
286
- scales=sparse_scales,
287
- rotations=sparse_rotations,
288
- )
289
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
InfiniDepth/utils/inference_utils.py CHANGED
@@ -6,8 +6,6 @@ import cv2
6
  import torch
7
  import torch.nn.functional as F
8
 
9
- from InfiniDepth.gs import Gaussians
10
-
11
  from .io_utils import load_depth
12
  from .moge_utils import estimate_metric_depth_with_moge2
13
  from .vis_utils import build_sky_model, run_skyseg
@@ -331,67 +329,4 @@ def build_camera_matrices(
331
  device=device,
332
  ).unsqueeze(0).expand(batch, -1, -1)
333
  extrinsics = torch.eye(4, dtype=torch.float32, device=device).unsqueeze(0).expand(batch, -1, -1)
334
- return fx, fy, cx, cy, intrinsics, extrinsics
335
-
336
-
337
- def filter_gaussians_by_depth_ratio(
338
- pixel_gaussians: Gaussians,
339
- extrinsics: torch.Tensor,
340
- keep_far_ratio: float,
341
- ) -> tuple[Gaussians, int, int, float, float]:
342
- camera_position = extrinsics[0, :3, 3]
343
- gaussian_means = pixel_gaussians.means[0]
344
- distances = torch.norm(gaussian_means - camera_position.unsqueeze(0), dim=-1)
345
- max_depth = distances.max()
346
- depth_threshold = max_depth * keep_far_ratio
347
- near_mask = distances <= depth_threshold
348
- num_filtered = int((~near_mask).sum().item())
349
- num_kept = int(near_mask.sum().item())
350
- filtered_gaussians = Gaussians(
351
- means=pixel_gaussians.means[:, near_mask, :],
352
- covariances=None,
353
- harmonics=pixel_gaussians.harmonics[:, near_mask, :, :],
354
- opacities=pixel_gaussians.opacities[:, near_mask],
355
- scales=pixel_gaussians.scales[:, near_mask, :],
356
- rotations=pixel_gaussians.rotations[:, near_mask, :],
357
- )
358
- return filtered_gaussians, num_filtered, num_kept, float(depth_threshold.item()), float(max_depth.item())
359
-
360
-
361
- def filter_gaussians_by_min_opacity(pixel_gaussians: Gaussians, min_opacity: float) -> Gaussians:
362
- if min_opacity <= 0.0:
363
- return pixel_gaussians
364
- keep = pixel_gaussians.opacities[0] >= min_opacity
365
- return Gaussians(
366
- means=pixel_gaussians.means[:, keep, :],
367
- covariances=None,
368
- harmonics=pixel_gaussians.harmonics[:, keep, :, :],
369
- opacities=pixel_gaussians.opacities[:, keep],
370
- scales=pixel_gaussians.scales[:, keep, :],
371
- rotations=pixel_gaussians.rotations[:, keep, :],
372
- )
373
-
374
-
375
- def unpack_gaussians_for_export(
376
- pixel_gaussians: Gaussians,
377
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
378
- return (
379
- pixel_gaussians.means[0],
380
- pixel_gaussians.harmonics[0],
381
- pixel_gaussians.opacities[0],
382
- pixel_gaussians.scales[0],
383
- pixel_gaussians.rotations[0],
384
- )
385
-
386
-
387
- def resolve_ply_output_path(
388
- input_image_path: str,
389
- model_type: str,
390
- output_ply_dir: Optional[str] = None,
391
- output_ply_name: Optional[str] = None,
392
- ) -> tuple[str, str]:
393
- ply_dir = output_ply_dir or default_dir_by_input_file(input_image_path, "pred_gs")
394
- os.makedirs(ply_dir, exist_ok=True)
395
- stem = os.path.splitext(os.path.basename(input_image_path))[0]
396
- ply_name = output_ply_name or f"{model_type}_{stem}_gaussians.ply"
397
- return ply_dir, os.path.join(ply_dir, ply_name)
 
6
  import torch
7
  import torch.nn.functional as F
8
 
 
 
9
  from .io_utils import load_depth
10
  from .moge_utils import estimate_metric_depth_with_moge2
11
  from .vis_utils import build_sky_model, run_skyseg
 
329
  device=device,
330
  ).unsqueeze(0).expand(batch, -1, -1)
331
  extrinsics = torch.eye(4, dtype=torch.float32, device=device).unsqueeze(0).expand(batch, -1, -1)
332
+ return fx, fy, cx, cy, intrinsics, extrinsics