Spaces:
Running on Zero
Running on Zero
update
Browse files- InfiniDepth/gs/__init__.py +0 -11
- InfiniDepth/gs/adapter.py +0 -90
- InfiniDepth/gs/ply.py +0 -232
- InfiniDepth/gs/predictor.py +0 -139
- InfiniDepth/gs/projection.py +0 -53
- InfiniDepth/gs/types.py +0 -14
- InfiniDepth/utils/depth_video_utils.py +0 -250
- InfiniDepth/utils/gs_utils.py +0 -289
- InfiniDepth/utils/inference_utils.py +1 -66
InfiniDepth/gs/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
"""Lightweight Gaussian Splatting inference utilities."""
|
| 2 |
-
|
| 3 |
-
from .types import Gaussians
|
| 4 |
-
from .predictor import GSPixelAlignPredictor
|
| 5 |
-
from .ply import export_ply
|
| 6 |
-
|
| 7 |
-
__all__ = [
|
| 8 |
-
"Gaussians",
|
| 9 |
-
"GSPixelAlignPredictor",
|
| 10 |
-
"export_ply",
|
| 11 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/gs/adapter.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from .projection import get_world_rays
|
| 8 |
-
from .types import Gaussians
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def rgb_to_sh(rgb: torch.Tensor) -> torch.Tensor:
    """Map RGB values to degree-0 spherical-harmonic coefficients.

    The colour is shifted by 0.5 and divided by the constant value of the
    degree-0 SH basis function, so an input of 0.5 becomes a zero coefficient.
    """
    sh_dc_basis = 0.28209479177387814  # 1 / (2 * sqrt(pi))
    centred = rgb - 0.5
    return centred / sh_dc_basis
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
@dataclass
class GaussianAdapterCfg:
    """Settings read by ``GaussianAdapter``."""

    # Lower bound used when clamping the activated gaussian scales.
    gaussian_scale_min: float = 1e-10
    # Upper bound used when clamping the activated gaussian scales.
    gaussian_scale_max: float = 5.0
    # Highest spherical-harmonic degree; the adapter derives
    # (sh_degree + 1) ** 2 coefficients per colour channel from it.
    sh_degree: int = 2
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class GaussianAdapter(nn.Module):
    """Turn per-point raw network outputs into world-space gaussians."""

    def __init__(self, cfg: GaussianAdapterCfg) -> None:
        super().__init__()
        self.cfg = cfg

        # Per-coefficient damping of the predicted harmonics: the DC term keeps
        # weight 1, every coefficient of degree d >= 1 gets 0.1 * 0.25 ** d.
        mask = torch.ones((self.d_sh,), dtype=torch.float32)
        for degree in range(1, self.cfg.sh_degree + 1):
            mask[degree**2 : (degree + 1) ** 2] = 0.1 * (0.25**degree)
        # Non-persistent: the mask is derived from cfg, not stored in checkpoints.
        self.register_buffer("sh_mask", mask, persistent=False)

    @property
    def d_sh(self) -> int:
        """Number of SH coefficients per colour channel."""
        return (self.cfg.sh_degree + 1) ** 2

    @property
    def d_in(self) -> int:
        """Raw channels consumed per gaussian: 3 scales + 4 rotation + 3 * d_sh."""
        return 7 + 3 * self.d_sh

    def forward(
        self,
        image: torch.Tensor,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        coordinates_xy: torch.Tensor,
        depths: torch.Tensor,
        opacities: torch.Tensor,
        raw_gaussians: torch.Tensor,
    ) -> Gaussians:
        """Assemble world-space gaussians from raw per-point parameters.

        Args:
            image: [B, 3, H, W] colour image sampled for the DC colour term.
            extrinsics: [B, 4, 4] camera-to-world matrices.
            intrinsics: [B, 3, 3] pixel intrinsics.
            coordinates_xy: [B, N, 2] pixel-space (x, y) gaussian centres.
            depths: [B, N] depth along each pixel ray.
            opacities: [B, N], passed through unchanged.
            raw_gaussians: [B, N, 7 + 3 * d_sh] scales, rotation and SH.

        Returns:
            A ``Gaussians`` record with ``covariances`` left as ``None``.
        """
        batch, _, height, width = image.shape
        d_sh = self.d_sh
        scales_raw, rotations_raw, sh_raw = torch.split(raw_gaussians, [3, 4, 3 * d_sh], dim=-1)

        # Shifted softplus keeps scales positive; the clamp bounds them by cfg.
        scales = torch.clamp(
            F.softplus(scales_raw - 4.0),
            min=self.cfg.gaussian_scale_min,
            max=self.cfg.gaussian_scale_max,
        )

        # Normalise the quaternion; the epsilon guards an all-zero prediction.
        rotation_norm = torch.norm(rotations_raw, dim=-1, keepdim=True)
        rotations = rotations_raw / (rotation_norm + 1e-8)

        harmonics = sh_raw.view(batch, -1, 3, d_sh) * self.sh_mask.view(1, 1, 1, -1)

        # Seed the DC band with the image colour at each gaussian centre.
        # Pixel coordinates are mapped to grid_sample's [-1, 1] range.
        norm_w = max(float(width), 1.0)
        norm_h = max(float(height), 1.0)
        grid = torch.stack(
            [
                (coordinates_xy[..., 0] / norm_w) * 2.0 - 1.0,
                (coordinates_xy[..., 1] / norm_h) * 2.0 - 1.0,
            ],
            dim=-1,
        ).unsqueeze(2)  # [B, N, 1, 2]
        sampled_rgb = F.grid_sample(image, grid, mode="bilinear", align_corners=False)
        sampled_rgb = sampled_rgb.squeeze(-1).permute(0, 2, 1)  # [B, N, 3]
        harmonics[..., 0] = harmonics[..., 0] + rgb_to_sh(sampled_rgb)

        # Place each centre along its world-space pixel ray.
        origins, directions = get_world_rays(coordinates_xy, extrinsics, intrinsics)
        means = origins + directions * depths.unsqueeze(-1)

        return Gaussians(
            means=means,
            harmonics=harmonics,
            opacities=opacities,
            scales=scales,
            rotations=rotations,
            covariances=None,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/gs/ply.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
from jaxtyping import Float
|
| 6 |
-
from plyfile import PlyData, PlyElement
|
| 7 |
-
from torch import Tensor
|
| 8 |
-
|
| 9 |
-
def _construct_attributes(d_sh: int) -> list[str]:
|
| 10 |
-
attrs = ["x", "y", "z", "nx", "ny", "nz", "f_dc_0", "f_dc_1", "f_dc_2"]
|
| 11 |
-
n_rest = 3 * max(d_sh - 1, 0)
|
| 12 |
-
attrs.extend([f"f_rest_{i}" for i in range(n_rest)])
|
| 13 |
-
attrs.extend(["opacity", "scale_0", "scale_1", "scale_2", "rot_0", "rot_1", "rot_2", "rot_3"])
|
| 14 |
-
return attrs
|
| 15 |
-
|
| 16 |
-
def _rotation_matrix_to_quaternion_batch(R: Tensor) -> Tensor:
    """Convert rotation matrices ``[N, 3, 3]`` to ``(w, x, y, z)`` quaternions ``[N, 4]``.

    Each matrix is routed to one of four branches depending on which
    quaternion component is largest, so the square root and the division
    stay well conditioned.
    """
    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
    quat = torch.zeros(R.shape[0], 4, dtype=R.dtype, device=R.device)

    # Case 1: trace > 0 (w is the dominant component).
    mask1 = trace > 0
    if mask1.any():
        s = torch.sqrt(trace[mask1] + 1.0) * 2  # s = 4 * qw
        quat[mask1, 0] = 0.25 * s
        quat[mask1, 1] = (R[mask1, 2, 1] - R[mask1, 1, 2]) / s
        quat[mask1, 2] = (R[mask1, 0, 2] - R[mask1, 2, 0]) / s
        quat[mask1, 3] = (R[mask1, 1, 0] - R[mask1, 0, 1]) / s

    # Case 2: R[0, 0] is the largest diagonal entry (x dominant).
    mask2 = ~mask1 & (R[..., 0, 0] > R[..., 1, 1]) & (R[..., 0, 0] > R[..., 2, 2])
    if mask2.any():
        s = torch.sqrt(1.0 + R[mask2, 0, 0] - R[mask2, 1, 1] - R[mask2, 2, 2]) * 2
        quat[mask2, 0] = (R[mask2, 2, 1] - R[mask2, 1, 2]) / s
        quat[mask2, 1] = 0.25 * s
        quat[mask2, 2] = (R[mask2, 0, 1] + R[mask2, 1, 0]) / s
        quat[mask2, 3] = (R[mask2, 0, 2] + R[mask2, 2, 0]) / s

    # Case 3: R[1, 1] > R[2, 2] (y dominant).
    mask3 = ~mask1 & ~mask2 & (R[..., 1, 1] > R[..., 2, 2])
    if mask3.any():
        s = torch.sqrt(1.0 + R[mask3, 1, 1] - R[mask3, 0, 0] - R[mask3, 2, 2]) * 2
        quat[mask3, 0] = (R[mask3, 0, 2] - R[mask3, 2, 0]) / s
        quat[mask3, 1] = (R[mask3, 0, 1] + R[mask3, 1, 0]) / s
        quat[mask3, 2] = 0.25 * s
        quat[mask3, 3] = (R[mask3, 1, 2] + R[mask3, 2, 1]) / s

    # Case 4: everything else (z dominant).
    mask4 = ~mask1 & ~mask2 & ~mask3
    if mask4.any():
        s = torch.sqrt(1.0 + R[mask4, 2, 2] - R[mask4, 0, 0] - R[mask4, 1, 1]) * 2
        quat[mask4, 0] = (R[mask4, 1, 0] - R[mask4, 0, 1]) / s
        quat[mask4, 1] = (R[mask4, 0, 2] + R[mask4, 2, 0]) / s
        quat[mask4, 2] = (R[mask4, 1, 2] + R[mask4, 2, 1]) / s
        quat[mask4, 3] = 0.25 * s

    return quat


def _covariances_to_scales_rotations(covariances: Tensor) -> tuple[Tensor, Tensor]:
    """Factor covariances ``[N, 3, 3]`` into scales ``[N, 3]`` and quaternions ``[N, 4]``."""
    eigenvalues, eigenvectors = torch.linalg.eigh(covariances)
    # Scales are the square roots of the eigenvalues, floored to stay positive.
    scales = torch.sqrt(torch.clamp(eigenvalues, min=1e-8))

    # eigh may return a reflection; negating a 3x3 matrix flips the sign of
    # its determinant, which turns the eigenbasis into a proper rotation.
    det = torch.det(eigenvectors)
    eigenvectors = torch.where(det.unsqueeze(-1).unsqueeze(-1) < 0, -eigenvectors, eigenvectors)
    return scales, _rotation_matrix_to_quaternion_batch(eigenvectors)


def _camera_metadata_elements(
    focal_length_px,
    image_shape,
    extrinsic_matrix,
    color_space_index,
) -> list:
    """Build the ``image_size``, ``intrinsic``, ``extrinsic`` and ``color_space`` PLY elements."""
    image_height, image_width = image_shape
    # Accept any (fx, fy) pair; a list used to fall through to float(list).
    if isinstance(focal_length_px, (tuple, list)):
        fx, fy = float(focal_length_px[0]), float(focal_length_px[1])
    else:
        fx = fy = float(focal_length_px)

    image_size_array = np.empty(2, dtype=[("image_size", "u4")])
    image_size_array[:] = np.array([image_width, image_height], dtype=np.uint32)

    # Row-major 3x3 intrinsics with the principal point at the image centre.
    intrinsic = np.array(
        [
            fx,
            0.0,
            image_width * 0.5,
            0.0,
            fy,
            image_height * 0.5,
            0.0,
            0.0,
            1.0,
        ],
        dtype=np.float32,
    )
    intrinsic_array = np.empty(9, dtype=[("intrinsic", "f4")])
    intrinsic_array[:] = intrinsic.flatten()

    if extrinsic_matrix is None:
        extrinsic_np = np.eye(4, dtype=np.float32)
    elif torch.is_tensor(extrinsic_matrix):
        extrinsic_np = extrinsic_matrix.detach().cpu().numpy().astype(np.float32)
    else:
        extrinsic_np = np.asarray(extrinsic_matrix, dtype=np.float32)
    if extrinsic_np.shape != (4, 4):
        raise ValueError(f"extrinsic_matrix must have shape (4,4), got {extrinsic_np.shape}")
    extrinsic_array = np.empty(16, dtype=[("extrinsic", "f4")])
    extrinsic_array[:] = extrinsic_np.flatten()

    color_space_array = np.empty(1, dtype=[("color_space", "u1")])
    color_space_array[:] = np.array([1 if color_space_index is None else color_space_index], dtype=np.uint8)

    return [
        PlyElement.describe(image_size_array, "image_size"),
        PlyElement.describe(intrinsic_array, "intrinsic"),
        PlyElement.describe(extrinsic_array, "extrinsic"),
        PlyElement.describe(color_space_array, "color_space"),
    ]


def export_ply(
    means: Float[Tensor, "gaussian 3"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, " gaussian"],
    path: str | Path,
    scales: Float[Tensor, "gaussian 3"] | None = None,
    rotations: Float[Tensor, "gaussian 4"] | None = None,
    covariances: Float[Tensor, "gaussian 3 3"] | None = None,  # Takes precedence over scales/rotations
    shift_and_scale: bool = True,
    save_sh_dc_only: bool = True,  # True writes only the degree-0 SH band; pass False to keep f_rest_*
    center_method: str = "mean",  # "mean", "median", or "bbox_center"
    apply_coordinate_transform: bool = True,  # Apply x90° rotation for viewer compatibility
    focal_length_px: float | tuple[float, float] | None = None,
    image_shape: tuple[int, int] | None = None,  # (height, width)
    extrinsic_matrix: np.ndarray | torch.Tensor | None = None,
    color_space_index: int | None = None,
):
    """Write a set of gaussians to a PLY file.

    Args:
        means: Gaussian centres.
        harmonics: SH coefficients; index 0 of the last axis is the DC band.
        opacities: Opacities, clipped to [1e-6, 1 - 1e-6] and stored as logits.
        path: Output file; missing parent directories are created.
        scales: Per-axis scales, stored as their natural log.
        rotations: Quaternions in (w, x, y, z) order.
        covariances: If given, scales and rotations are derived from an
            eigendecomposition and replace the ``scales``/``rotations`` arguments.
        shift_and_scale: Subtract the centre chosen by ``center_method`` from
            the means. No rescaling is performed.
        save_sh_dc_only: Write only ``f_dc_*`` and drop the ``f_rest_*`` columns.
        center_method: ``"mean"``, ``"median"`` or ``"bbox_center"``.
        apply_coordinate_transform: Rotate means and quaternions by 90 degrees
            about the X axis. SH coefficients are written unchanged.
        focal_length_px: One focal length or an ``(fx, fy)`` pair, in pixels.
        image_shape: ``(height, width)``. Camera metadata elements are written
            only when both this and ``focal_length_px`` are given.
        extrinsic_matrix: 4x4 matrix stored in the ``extrinsic`` element
            (identity when omitted).
        color_space_index: Value of the ``color_space`` element (1 when omitted).

    Raises:
        ValueError: If neither covariances nor both scales and rotations are
            given, if ``center_method`` is unknown, or if ``extrinsic_matrix``
            is not 4x4.
    """
    path = Path(path)

    if covariances is None and (scales is None or rotations is None):
        raise ValueError("Either provide covariances or both scales and rotations")

    if covariances is not None:
        scales, rotations = _covariances_to_scales_rotations(covariances)

    if shift_and_scale:
        if center_method == "mean":
            center = means.mean(dim=0)
        elif center_method == "median":
            center = means.median(dim=0).values
        elif center_method == "bbox_center":
            center = (means.min(dim=0).values + means.max(dim=0).values) / 2
        else:
            raise ValueError(f"Unknown center_method: {center_method}")
        means = means - center

    if apply_coordinate_transform:
        # 90 degree rotation about X, applied to positions as a matrix ...
        rot_x = torch.tensor(
            [
                [1, 0, 0],
                [0, 0, -1],
                [0, 1, 0],
            ],
            dtype=means.dtype,
            device=means.device,
        )
        means = means @ rot_x.T

        # ... and to orientations as a left quaternion product (same rotation).
        transform_quat = torch.tensor(
            [0.7071068, 0.7071068, 0.0, 0.0],
            dtype=rotations.dtype,
            device=rotations.device,
        )
        w1, x1, y1, z1 = transform_quat[0], transform_quat[1], transform_quat[2], transform_quat[3]
        w2, x2, y2, z2 = rotations[:, 0], rotations[:, 1], rotations[:, 2], rotations[:, 3]
        rotations = torch.stack(
            [
                w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,  # w
                w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,  # x
                w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,  # y
                w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,  # z
            ],
            dim=1,
        )

    means_np = means.detach().cpu().numpy()
    scales_np = scales.detach().cpu().numpy()
    rotations_np = rotations.detach().cpu().numpy()
    opacities_np = opacities.detach().cpu().numpy()
    harmonics_np = harmonics.detach().cpu().numpy()

    f_dc = harmonics_np[..., 0]
    f_rest = harmonics_np[..., 1:].reshape(harmonics_np.shape[0], -1)

    d_sh = harmonics_np.shape[-1]
    dtype_full = [
        (attribute, "f4")
        for attribute in _construct_attributes(1 if save_sh_dc_only else d_sh)
    ]
    elements = np.empty(means_np.shape[0], dtype=dtype_full)

    # Column order must match _construct_attributes.
    attributes = [
        means_np,
        np.zeros_like(means_np),  # normals
        f_dc,
    ]
    if not save_sh_dc_only:
        attributes.append(f_rest)

    # Store logit(opacity); the clip avoids log(0) and division by zero.
    opacities_clamped = np.clip(opacities_np, 1e-6, 1 - 1e-6)
    opacities_logit = np.log(opacities_clamped / (1 - opacities_clamped))

    attributes.extend(
        [
            opacities_logit.reshape(-1, 1),
            np.log(scales_np),
            rotations_np,
        ]
    )

    attributes = np.concatenate(attributes, axis=1)
    elements[:] = list(map(tuple, attributes))
    path.parent.mkdir(exist_ok=True, parents=True)
    ply_elements = [PlyElement.describe(elements, "vertex")]

    if focal_length_px is not None and image_shape is not None:
        ply_elements.extend(
            _camera_metadata_elements(focal_length_px, image_shape, extrinsic_matrix, color_space_index)
        )

    PlyData(ply_elements).write(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/gs/predictor.py
DELETED
|
@@ -1,139 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
|
| 6 |
-
from .adapter import GaussianAdapter, GaussianAdapterCfg
|
| 7 |
-
from .projection import sample_image_grid
|
| 8 |
-
from .types import Gaussians
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@dataclass
class GSPredictorCfg:
    """Hyper-parameters for ``GSPixelAlignPredictor``."""

    # Output channels of the RGB convolutional encoder.
    rgb_feature_dim: int = 64
    # Output channels of the depth convolutional encoder.
    depth_feature_dim: int = 32
    # Channels the token features are projected down to.
    dino_reduced_dim: int = 128
    # Width of the gaussian regressor convolutions.
    gaussian_regressor_channels: int = 64
    # NOTE(review): not read by the predictor code in this file — confirm usage.
    num_surfaces: int = 1
    # Forwarded to GaussianAdapterCfg.gaussian_scale_min.
    gaussian_scale_min: float = 1e-10
    # Forwarded to GaussianAdapterCfg.gaussian_scale_max.
    gaussian_scale_max: float = 5.0
    # Forwarded to GaussianAdapterCfg.sh_degree.
    sh_degree: int = 2
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class GSPixelAlignPredictor(nn.Module):
    """Predict one gaussian per pixel from an image, a depth map and ViT tokens."""

    def __init__(self, dino_feature_dim: int = 1024, cfg: GSPredictorCfg | None = None) -> None:
        """Build the encoders, regressor, head and adapter.

        Args:
            dino_feature_dim: Channel width of the incoming token features.
            cfg: Hyper-parameters; defaults to ``GSPredictorCfg()``.
        """
        super().__init__()
        self.cfg = cfg or GSPredictorCfg()
        cfg = self.cfg

        self.rgb_encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(32, cfg.rgb_feature_dim, 3, 1, 1),
            nn.GELU(),
        )
        self.depth_encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(16, cfg.depth_feature_dim, 3, 1, 1),
            nn.GELU(),
        )
        self.dino_projector = nn.Sequential(
            nn.Conv2d(dino_feature_dim, 256, 1),
            nn.GELU(),
            nn.Conv2d(256, cfg.dino_reduced_dim, 1),
        )

        reg_in = cfg.rgb_feature_dim + cfg.depth_feature_dim + cfg.dino_reduced_dim
        self.gaussian_regressor = nn.Sequential(
            nn.Conv2d(reg_in, cfg.gaussian_regressor_channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(cfg.gaussian_regressor_channels, cfg.gaussian_regressor_channels, 3, 1, 1),
        )

        self.gaussian_adapter = GaussianAdapter(
            GaussianAdapterCfg(
                gaussian_scale_min=cfg.gaussian_scale_min,
                gaussian_scale_max=cfg.gaussian_scale_max,
                sh_degree=cfg.sh_degree,
            )
        )

        # Adapter parameters plus a 2-channel xy offset and 1 opacity channel.
        num_gaussian_parameters = self.gaussian_adapter.d_in + 2 + 1
        head_in = cfg.gaussian_regressor_channels + cfg.rgb_feature_dim + cfg.dino_reduced_dim
        self.gaussian_head = nn.Sequential(
            nn.Conv2d(head_in, num_gaussian_parameters, 3, 1, 1, padding_mode="replicate"),
            nn.GELU(),
            nn.Conv2d(num_gaussian_parameters, num_gaussian_parameters, 3, 1, 1, padding_mode="replicate"),
        )

    @torch.no_grad()
    def load_from_infinisplat_checkpoint(self, checkpoint_path: str) -> None:
        """Copy matching ``encoder.*`` tensors from a checkpoint into this module.

        A tensor is loaded only when ``encoder.<name>`` exists in the
        checkpoint and has the same shape as this module's ``<name>``. All
        other parameters keep their current values. A warning is emitted
        when nothing matches, because the module would otherwise silently
        keep its initial weights.
        """
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)

        load_sd = {}
        for name, own_tensor in self.state_dict().items():
            prefixed = f"encoder.{name}"
            if prefixed in state_dict and state_dict[prefixed].shape == own_tensor.shape:
                load_sd[name] = state_dict[prefixed]

        if not load_sd:
            import warnings

            warnings.warn(
                f"No 'encoder.*' tensors in {checkpoint_path!r} matched this module; "
                "weights were left unchanged.",
                stacklevel=2,
            )
        self.load_state_dict(load_sd, strict=False)

    def _tokens_to_feature_map(self, dino_tokens: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Reshape the trailing patch tokens to a grid and upsample to ``(h, w)``.

        Args:
            dino_tokens: [B, n_all, C]; the last ``(h // 16) * (w // 16)``
                tokens are treated as patch tokens and any leading tokens
                are dropped.
            h: Target height in pixels.
            w: Target width in pixels.

        Returns:
            [B, C, h, w] bilinearly upsampled feature map.

        Raises:
            ValueError: If the image is smaller than one patch or fewer
                tokens than patches are supplied.
        """
        patch_size = 16
        b, n_all, c = dino_tokens.shape
        patch_h = h // patch_size
        patch_w = w // patch_size
        if patch_h == 0 or patch_w == 0:
            # An empty patch grid would otherwise surface as an opaque error
            # in the interpolate call below.
            raise ValueError(
                f"Image size ({h}, {w}) is smaller than one {patch_size}x{patch_size} patch"
            )
        n_patch = patch_h * patch_w
        if n_all < n_patch:
            raise ValueError(f"Invalid token count: got {n_all}, expected at least {n_patch}")
        n_reg = n_all - n_patch
        patch_tokens = dino_tokens[:, n_reg:, :]  # [B, patch_h*patch_w, C]
        patch_tokens = patch_tokens.reshape(b, patch_h, patch_w, c).permute(0, 3, 1, 2)
        return torch.nn.functional.interpolate(
            patch_tokens, size=(h, w), mode="bilinear", align_corners=False
        )

    def forward(
        self,
        image: torch.Tensor,
        depthmap: torch.Tensor,
        dino_tokens: torch.Tensor,
        intrinsics: torch.Tensor,
        extrinsics: torch.Tensor,
    ) -> Gaussians:
        """Predict one gaussian per pixel.

        Args:
            image: [B, 3, H, W] colour image.
            depthmap: depth fed to a 1-channel encoder; channel 0 supplies
                the per-pixel depth of each gaussian.
            dino_tokens: [B, n_all, C] token features (see
                ``_tokens_to_feature_map``).
            intrinsics: Passed through to the gaussian adapter.
            extrinsics: Passed through to the gaussian adapter.

        Returns:
            The ``Gaussians`` produced by the gaussian adapter.
        """
        b, _, h, w = image.shape
        dino_map = self._tokens_to_feature_map(dino_tokens, h, w)

        rgb_feat = self.rgb_encoder(image)
        depth_feat = self.depth_encoder(depthmap)
        dino_feat = self.dino_projector(dino_map)

        reg_input = torch.cat([rgb_feat, depth_feat, dino_feat], dim=1)
        reg_feat = self.gaussian_regressor(reg_input)
        head_input = torch.cat([reg_feat, rgb_feat, dino_feat], dim=1)
        raw = self.gaussian_head(head_input)  # [B, Cg, H, W]

        # Channel layout: [opacity | xy offset (2) | adapter parameters].
        raw = raw.permute(0, 2, 3, 1).reshape(b, h * w, -1)  # [B, HW, Cg]
        opacities = torch.sigmoid(raw[..., :1]).squeeze(-1)  # [B, HW]
        gaussian_core = raw[..., 1:]  # [B, HW, Cg-1]

        # One surface per pixel in this lightweight integration.
        offset_xy = torch.sigmoid(gaussian_core[..., :2])  # [B, HW, 2], in [0,1]
        raw_gaussians = gaussian_core[..., 2:]  # [B, HW, 7+3*d_sh]

        # Shift each pixel-centre coordinate by at most half a pixel.
        base = sample_image_grid(h, w, image.device).unsqueeze(0).expand(b, -1, -1)
        coords = base + (offset_xy - 0.5)
        # NOTE(review): pixel centres sit at +0.5, so the upper bounds w - 1 and
        # h - 1 pull every last-column / last-row point to the left / top edge
        # of its pixel. Confirm this matches training before changing it.
        coords[..., 0].clamp_(0.0, float(w - 1))
        coords[..., 1].clamp_(0.0, float(h - 1))

        depths = depthmap[:, 0].reshape(b, -1)
        return self.gaussian_adapter(
            image=image,
            extrinsics=extrinsics,
            intrinsics=intrinsics,
            coordinates_xy=coords,
            depths=depths,
            opacities=opacities,
            raw_gaussians=raw_gaussians,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/gs/projection.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def homogenize_points(points: torch.Tensor) -> torch.Tensor:
    """Append a homogeneous coordinate of 1 along the last dimension."""
    ones = torch.ones_like(points[..., :1])
    return torch.cat((points, ones), dim=-1)
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor:
    """Append a homogeneous coordinate of 0 along the last dimension.

    With a zero last component, the translation column of a 4x4 transform
    has no effect on the vector.
    """
    zeros = torch.zeros_like(vectors[..., :1])
    return torch.cat((vectors, zeros), dim=-1)
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def transform_cam2world(homogeneous: torch.Tensor, extrinsics: torch.Tensor) -> torch.Tensor:
    """Apply camera-to-world matrices to homogeneous coordinates.

    Args:
        homogeneous: [..., 4] homogeneous points or vectors, e.g. [B, N, 4].
        extrinsics: [..., 4, 4] matrices whose leading batch axes line up
            with the leading axes of ``homogeneous``, e.g. [B, 4, 4].

    Returns:
        Transformed coordinates with the broadcast shape of the inputs.
    """
    # torch.matmul broadcasts batch axes right-aligned, which would pair the
    # camera axis of a [B, 4, 4] matrix with the point axis of a [B, N, 4]
    # input (an error for B > 1, silently wrong when B == N). Insert singleton
    # axes before the matrix axes so the leading batch axes align instead.
    missing_axes = (homogeneous.dim() - 1) - (extrinsics.dim() - 2)
    if extrinsics.dim() > 2:
        for _ in range(max(missing_axes, 0)):
            extrinsics = extrinsics.unsqueeze(-3)
    return torch.matmul(extrinsics, homogeneous.unsqueeze(-1)).squeeze(-1)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def unproject(coordinates_xy: torch.Tensor, z: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
    """Lift pixel coordinates into camera space at depth ``z``.

    Args:
        coordinates_xy: [B, N, 2] pixel coordinates in (x, y) order.
        z: [B, N] depth of every point.
        intrinsics: [B, 3, 3] intrinsics in pixel units.

    Returns:
        [B, N, 3] camera-space points, ``K^-1 @ (x, y, 1)`` scaled by ``z``.
    """
    # Homogeneous pixel coordinates (x, y, 1), shape [B, N, 3].
    unit = torch.ones_like(coordinates_xy[..., :1])
    pixels_h = torch.cat([coordinates_xy, unit], dim=-1)

    # One inverse per camera, broadcast across the N points of that camera.
    inverse_k = torch.linalg.inv(intrinsics).unsqueeze(1)  # [B, 1, 3, 3]
    rays = torch.matmul(inverse_k, pixels_h.unsqueeze(-1)).squeeze(-1)
    return rays * z.unsqueeze(-1)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def get_world_rays(
    coordinates_xy: torch.Tensor,
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return world-space ray origins and directions for pixel coordinates.

    Args:
        coordinates_xy: [B, N, 2] pixel coordinates in (x, y) order.
        extrinsics: [B, 4, 4] camera-to-world matrices.
        intrinsics: [B, 3, 3] pixel intrinsics.

    Returns:
        ``(origins, directions)``, both [B, N, 3]. Directions are scaled so
        their camera-space z component is 1 (not unit length), so multiplying
        by a z-depth yields the point offset from the origin.
    """
    ones = torch.ones_like(coordinates_xy[..., 0])
    directions_cam = unproject(coordinates_xy, ones, intrinsics)
    # Normalise to z == 1; the clamp avoids dividing by a vanishing z.
    directions_cam = directions_cam / torch.clamp(directions_cam[..., 2:], min=1e-6)

    # Add an explicit point axis: matmul right-aligns batch axes, so a bare
    # [B, 4, 4] against [B, N, 4, 1] would pair B with N and fail for B > 1.
    per_point_extrinsics = extrinsics.unsqueeze(1)  # [B, 1, 4, 4]
    directions_world = transform_cam2world(
        homogenize_vectors(directions_cam), per_point_extrinsics
    )[..., :3]
    # Every ray of a camera starts at that camera's translation.
    origins_world = extrinsics[:, None, :3, 3].expand_as(directions_world)
    return origins_world, directions_world
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def sample_image_grid(h: int, w: int, device: torch.device) -> torch.Tensor:
    """Return the pixel-centre coordinates of an ``h`` x ``w`` image.

    The result has shape [H*W, 2] in (x, y) order, enumerated row by row,
    with centres at integer index + 0.5.
    """
    col_centres = torch.arange(w, device=device, dtype=torch.float32) + 0.5
    row_centres = torch.arange(h, device=device, dtype=torch.float32) + 0.5
    # Row-major enumeration: x cycles fastest, y is held for a whole row.
    xs = col_centres.repeat(h)
    ys = row_centres.repeat_interleave(w)
    return torch.stack((xs, ys), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/gs/types.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
from typing import Optional
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass
class Gaussians:
    """Container for a batch of 3D gaussians."""

    # Gaussian centres, [B, N, 3].
    means: torch.Tensor
    # Spherical-harmonic colour coefficients, [B, N, 3, d_sh].
    harmonics: torch.Tensor
    # Per-gaussian opacity, [B, N].
    opacities: torch.Tensor
    # Per-axis scales, [B, N, 3].
    scales: torch.Tensor
    # Rotation quaternions, [B, N, 4].
    rotations: torch.Tensor
    # Optional covariances; None when only scales and rotations are stored.
    covariances: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/utils/depth_video_utils.py
DELETED
|
@@ -1,250 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from typing import Optional
|
| 3 |
-
|
| 4 |
-
import cv2
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
from .inference_utils import default_dir_by_input_file, default_video_file_by_input
|
| 9 |
-
from .io_utils import filter_depth_noise_numpy
|
| 10 |
-
from .io_utils import save_sampled_point_clouds
|
| 11 |
-
from .moge_utils import estimate_metric_depth_with_moge2
|
| 12 |
-
from .sampling_utils import SAMPLING_METHODS
|
| 13 |
-
from .vis_utils import colorize_depth_maps
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def prepare_rgb_frame(
    frame_bgr: np.ndarray,
    input_size: tuple[int, int],
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, tuple[int, int]]:
    """Convert a BGR frame into full-size and resized RGB tensors.

    Args:
        frame_bgr: Frame in OpenCV BGR channel order.
        input_size: Target ``(height, width)``; reversed to ``(width, height)``
            for ``cv2.resize``.
        device: Device the resized tensor is moved to.

    Returns:
        ``(org_img, image, (org_h, org_w))``. ``org_img`` is the original
        resolution frame as a [1, 3, H, W] float tensor in [0, 1] and is not
        moved to ``device``. ``image`` is the resized frame on ``device``.
    """

    def _to_unit_tensor(rgb: np.ndarray) -> torch.Tensor:
        # HWC uint8 -> [1, C, H, W] float in [0, 1].
        return torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float() / 255.0

    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    original_hw = frame_rgb.shape[:2]
    org_img = _to_unit_tensor(frame_rgb)

    target_wh = input_size[::-1]
    resized_rgb = cv2.resize(frame_rgb, target_wh, interpolation=cv2.INTER_AREA)
    image = _to_unit_tensor(resized_rgb).to(device)

    return org_img, image, (original_hw[0], original_hw[1])
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def depth_frame_to_metric_depth(depth_frame: np.ndarray, depth_video_scale: float) -> np.ndarray:
    """Convert a raw depth frame to float32 depth by dividing by a scale.

    For a 3-dimensional frame only channel 0 is used, on the assumption that
    depth-like content is replicated across channels. The divisor is floored
    at 1e-8 so a zero scale cannot divide by zero.
    """
    depth_raw = depth_frame[:, :, 0] if depth_frame.ndim == 3 else depth_frame
    divisor = max(depth_video_scale, 1e-8)
    return depth_raw.astype(np.float32) / divisor
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def sample_sparse_prompt(
    depth: np.ndarray,
    depth_mask: np.ndarray,
    num_samples: int,
) -> np.ndarray:
    """Keep at most ``num_samples`` randomly chosen valid depth values.

    Args:
        depth: Depth map.
        depth_mask: Mask multiplied into ``depth``; 0 removes a pixel.
        num_samples: Maximum number of depth values to keep.

    Returns:
        An array shaped like ``depth``. If no more than ``num_samples``
        masked depths exceed 0.1, the masked depth is returned as is.
        Otherwise ``num_samples`` of those pixels, drawn with
        ``np.random.permutation``, keep their value and all others are zero.
    """
    valid_depth = depth * depth_mask

    # flatnonzero always returns a 1-D index array. The previous
    # nonzero/squeeze form collapsed to 0-d for a single candidate, which
    # made np.random.permutation raise.
    candidates = np.flatnonzero(valid_depth > 0.1)
    if candidates.size <= num_samples:
        return valid_depth

    keep = np.random.permutation(candidates)[:num_samples]
    # Scatter the selected values into an all-zero map.
    sparse = np.zeros(valid_depth.size, dtype=valid_depth.dtype)
    sparse[keep] = valid_depth.reshape(-1)[keep]
    return sparse.reshape(valid_depth.shape)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def prepare_prompt_from_depth_frame(
    depth_frame: np.ndarray,
    input_size: tuple[int, int],
    depth_video_scale: float,
    num_samples: int,
    min_prompt: float,
    max_prompt: float,
    enable_noise_filter: bool,
    filter_std_threshold: float,
    filter_median_threshold: float,
    filter_gradient_threshold: float,
    filter_min_neighbors: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build depth, sparse prompt and mask tensors from a raw depth frame.

    The frame is converted with ``depth_frame_to_metric_depth``, resized to
    ``input_size`` (height, width) with nearest-neighbour interpolation and
    masked to the open interval ``(min_prompt, max_prompt)``. With
    ``enable_noise_filter`` the depth and mask are additionally passed
    through ``filter_depth_noise_numpy``. The prompt is a random subset of
    at most ``num_samples`` valid depths.

    Returns:
        ``(depth, prompt_depth, depth_mask)``, each a [1, 1, H, W] float
        tensor on ``device``.
    """
    metric_depth = depth_frame_to_metric_depth(depth_frame, depth_video_scale)
    metric_depth = cv2.resize(metric_depth, input_size[::-1], interpolation=cv2.INTER_NEAREST)

    # The range mask is computed on the unfiltered depth in both modes.
    in_range = ((metric_depth > min_prompt) & (metric_depth < max_prompt)).astype(np.float32)
    if enable_noise_filter:
        metric_depth, depth_mask = filter_depth_noise_numpy(
            depth=metric_depth,
            depth_mask=in_range,
            std_threshold=filter_std_threshold,
            median_threshold=filter_median_threshold,
            gradient_threshold=filter_gradient_threshold,
            min_neighbors=filter_min_neighbors,
        )
    else:
        depth_mask = in_range

    prompt_depth = sample_sparse_prompt(metric_depth, depth_mask, num_samples=num_samples)

    def _as_batched_tensor(array: np.ndarray) -> torch.Tensor:
        # [H, W] -> [1, 1, H, W] float tensor on the target device.
        return torch.from_numpy(array).unsqueeze(0).unsqueeze(0).float().to(device)

    gt_depth_t = _as_batched_tensor(metric_depth)
    prompt_depth_t = _as_batched_tensor(prompt_depth)
    depth_mask_t = _as_batched_tensor(depth_mask)
    return gt_depth_t, prompt_depth_t, depth_mask_t
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def ensure_depth_map(pred_depth: torch.Tensor, h_sample: int, w_sample: int) -> torch.Tensor:
    """Return ``pred_depth`` laid out as ``(B, 1, h_sample, w_sample)``.

    Accepted inputs:
        * 4-D tensors whose last two dimensions already equal
          ``(h_sample, w_sample)`` (returned unchanged);
        * 3-D tensors shaped ``(B, h*w, 1)``, ``(B, 1, h*w)`` or ``(B, h, w)``.

    Raises:
        ValueError: If the shape matches none of the layouts above.
    """
    rank = pred_depth.ndim
    if rank == 4 and pred_depth.shape[-2:] == (h_sample, w_sample):
        return pred_depth

    if rank == 3:
        batch, rows, cols = pred_depth.shape
        num_pixels = h_sample * w_sample
        if (rows, cols) == (num_pixels, 1):
            # Flattened pixels with a trailing channel dimension.
            return pred_depth.permute(0, 2, 1).reshape(batch, 1, h_sample, w_sample)
        if (rows, cols) == (1, num_pixels):
            # Flattened pixels with a leading channel dimension.
            return pred_depth.reshape(batch, 1, h_sample, w_sample)
        if (rows, cols) == (h_sample, w_sample):
            # Spatial map without a channel dimension.
            return pred_depth.unsqueeze(1)

    raise ValueError(
        f"Unsupported pred_depth shape: {tuple(pred_depth.shape)} for target ({h_sample}, {w_sample})"
    )
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def build_query_coords(
    h_sample: int,
    w_sample: int,
    device: torch.device,
) -> torch.Tensor:
    """Build query coordinates with the ``"2d_uniform"`` sampling method.

    The sampler is called with ``(h_sample, w_sample)``; its result gets a
    leading singleton dimension and is moved to ``device``.
    """
    uniform_sampler = SAMPLING_METHODS["2d_uniform"]
    coords = uniform_sampler((h_sample, w_sample))
    return coords.unsqueeze(0).to(device)
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def resolve_video_output_paths(
    input_video_path: str,
    depth_output_video_path: Optional[str],
    pcd_output_dir: Optional[str],
    save_depth_video: bool,
    save_pcd: bool,
) -> tuple[str, str]:
    """Pick the depth-video path and point-cloud directory, creating folders.

    Explicit (truthy) paths are used as given; otherwise defaults are derived
    from ``input_video_path`` via ``default_video_file_by_input`` and
    ``default_dir_by_input_file``. Directories are only created for outputs
    that are actually enabled.

    Returns:
        ``(depth_video_path, pcd_output_dir)``.
    """
    if depth_output_video_path:
        resolved_depth_video_path = depth_output_video_path
    else:
        resolved_depth_video_path = default_video_file_by_input(
            input_video_path,
            "pred_depth_video",
            "pred_depth.mp4",
        )

    if pcd_output_dir:
        resolved_pcd_output_dir = pcd_output_dir
    else:
        resolved_pcd_output_dir = default_dir_by_input_file(input_video_path, "pred_pcd_frames")

    if save_depth_video:
        # A bare file name has no parent component; use the current directory.
        parent_dir = os.path.dirname(resolved_depth_video_path)
        os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)
    if save_pcd:
        os.makedirs(resolved_pcd_output_dir, exist_ok=True)

    return resolved_depth_video_path, resolved_pcd_output_dir
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
def prepare_video_prompt(
    depth_frame: Optional[np.ndarray],
    image: torch.Tensor,
    input_size: tuple[int, int],
    depth_video_scale: float,
    num_samples: int,
    min_prompt: float,
    max_prompt: float,
    enable_noise_filter: bool,
    filter_std_threshold: float,
    filter_median_threshold: float,
    filter_gradient_threshold: float,
    filter_min_neighbors: int,
    moge2_pretrained: str,
    moge2_use_fp16: bool,
    moge2_resolution_level: int,
    moge2_num_tokens: Optional[int],
    moge2_threshold: float,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Produce the depth prompt and its mask for one video frame.

    With a depth frame available the prompt comes from
    ``prepare_prompt_from_depth_frame``; without one, depth is estimated from
    ``image`` by ``estimate_metric_depth_with_moge2`` using the ``moge2_*``
    settings.

    Returns:
        ``(prompt_depth, depth_mask)`` on ``device``.
    """
    if depth_frame is None:
        # No sensor depth for this frame: fall back to the MoGe-2 estimate.
        estimated_depth, estimated_mask = estimate_metric_depth_with_moge2(
            image=image,
            pretrained_model_name_or_path=moge2_pretrained,
            use_fp16=moge2_use_fp16,
            resolution_level=moge2_resolution_level,
            num_tokens=moge2_num_tokens,
            threshold=moge2_threshold,
        )
        return estimated_depth.to(device), estimated_mask.to(device)

    # The dense depth (first element) is not needed for prompting.
    _, prompt_depth, depth_mask = prepare_prompt_from_depth_frame(
        depth_frame=depth_frame,
        input_size=input_size,
        depth_video_scale=depth_video_scale,
        num_samples=num_samples,
        min_prompt=min_prompt,
        max_prompt=max_prompt,
        enable_noise_filter=enable_noise_filter,
        filter_std_threshold=filter_std_threshold,
        filter_median_threshold=filter_median_threshold,
        filter_gradient_threshold=filter_gradient_threshold,
        filter_min_neighbors=filter_min_neighbors,
        device=device,
    )
    return prompt_depth, depth_mask
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def write_depth_video_frame(
    pred_depthmap: torch.Tensor,
    depth_writer: Optional[cv2.VideoWriter],
    writer_size: Optional[tuple[int, int]],
    depth_output_video_path: str,
    final_fps: float,
) -> tuple[cv2.VideoWriter, tuple[int, int]]:
    """Colorize one predicted depth map and append it to the depth video.

    The writer is created lazily on the first call (``depth_writer is None``)
    using the size of the first colorized frame; later frames whose size
    differs are resized to ``writer_size`` before being written.

    Args:
        pred_depthmap: Depth tensor; element ``[0, 0]`` is visualized.
        depth_writer: Existing writer, or ``None`` to create one.
        writer_size: ``(width, height)`` of the existing writer, or ``None``.
        depth_output_video_path: Destination path used when creating the writer.
        final_fps: Frame rate used when creating the writer.

    Returns:
        The (possibly newly created) writer and its frame size.

    Raises:
        RuntimeError: If the writer cannot be opened, or if a writer was
            supplied without its size.
    """
    depth_np = pred_depthmap[0, 0].detach().cpu().numpy()
    usable = np.isfinite(depth_np) & (depth_np > 0)
    if not np.any(usable):
        depth_min, depth_max = 0.0, 1.0
    else:
        # Percentile clipping keeps outliers from dominating the color range.
        depth_min, depth_max = np.percentile(depth_np[usable], [1.0, 99.0]).tolist()
        if depth_max <= depth_min:
            depth_max = depth_min + 1e-6

    vis_depth = colorize_depth_maps(depth_np, min_depth=depth_min, max_depth=depth_max, cmap="Spectral")
    vis_bgr = cv2.cvtColor(vis_depth, cv2.COLOR_RGB2BGR)
    frame_size = (vis_bgr.shape[1], vis_bgr.shape[0])

    if depth_writer is None:
        writer_size = frame_size
        depth_writer = cv2.VideoWriter(
            depth_output_video_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            float(final_fps),
            writer_size,
        )
        if not depth_writer.isOpened():
            raise RuntimeError(f"Failed to open depth video writer: {depth_output_video_path}")

    if writer_size is None:
        raise RuntimeError("writer_size should not be None after depth_writer initialization.")

    if frame_size != writer_size:
        vis_bgr = cv2.resize(vis_bgr, writer_size, interpolation=cv2.INTER_AREA)
    depth_writer.write(vis_bgr)
    return depth_writer, writer_size
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
def save_video_frame_point_cloud(
    query_2d_uniform_coord: torch.Tensor,
    pred_2d_uniform_depth: torch.Tensor,
    image: torch.Tensor,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    pcd_output_dir: str,
    frame_id: int,
    enable_filter_flying_points: bool,
) -> str:
    """Save the point cloud of one video frame as ``frame_<id>.ply``.

    The leading dimension of the coordinates, depths and image is squeezed
    away and the tensors are moved to the CPU before being handed to
    ``save_sampled_point_clouds``.

    Returns:
        Path of the written ``.ply`` file inside ``pcd_output_dir``.
    """
    pcd_save_path = os.path.join(pcd_output_dir, f"frame_{frame_id:06d}.ply")

    coords_cpu = query_2d_uniform_coord.squeeze(0).detach().cpu()
    depths_cpu = pred_2d_uniform_depth.squeeze(0).squeeze(-1).detach().cpu()
    image_cpu = image.squeeze(0).detach().cpu()

    save_sampled_point_clouds(
        sampled_coord=coords_cpu,
        sampled_depth=depths_cpu,
        rgb_image=image_cpu,
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        output_path=pcd_save_path,
        filter_flying_points=enable_filter_flying_points,
    )
    return pcd_save_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/utils/gs_utils.py
DELETED
|
@@ -1,289 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import os
|
| 3 |
-
from typing import Optional
|
| 4 |
-
|
| 5 |
-
import imageio.v2 as imageio
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from InfiniDepth.gs import Gaussians
|
| 11 |
-
from InfiniDepth.gs.projection import homogenize_points, transform_cam2world, unproject
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def _safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
| 15 |
-
return v / torch.clamp(torch.norm(v), min=eps)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def _look_at_c2w(position: torch.Tensor, target: torch.Tensor, up_hint: torch.Tensor) -> torch.Tensor:
    """Build a 4x4 camera-to-world matrix at ``position`` looking at ``target``.

    Columns 0-2 of the result hold the right, up and forward axes and
    column 3 holds ``position``.

    Args:
        position: Camera position, a 3-vector.
        target: Point the camera looks at, a 3-vector.
        up_hint: Approximate up direction used to orient the camera.

    Returns:
        A 4x4 matrix on the device and dtype of ``position``.

    Note:
        If ``target`` equals ``position`` the forward axis is zero and the
        resulting basis is degenerate; callers must avoid that case.
    """
    forward = _safe_normalize(target - position)
    # Camera basis is stored as [right, up, forward]. Using cross(forward, up)
    # flips the x-axis and produces a horizontally mirrored render. Keep the
    # original up hint and derive a right-handed basis instead.
    right = torch.cross(up_hint, forward, dim=0)
    if torch.norm(right) < 1e-6:
        # up_hint is (anti)parallel to forward. Try the world x-axis, then the
        # y-axis: a unit forward vector cannot be parallel to both, so one of
        # them always yields a usable right vector.
        for axis in ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0)):
            fallback = torch.tensor(axis, device=position.device, dtype=position.dtype)
            right = torch.cross(fallback, forward, dim=0)
            if torch.norm(right) >= 1e-6:
                break
    right = _safe_normalize(right)
    up = _safe_normalize(torch.cross(forward, right, dim=0))

    c2w = torch.eye(4, device=position.device, dtype=position.dtype)
    c2w[:3, 0] = right
    c2w[:3, 1] = up
    c2w[:3, 2] = forward
    c2w[:3, 3] = position
    return c2w
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def _build_orbit_poses(
    base_c2w: torch.Tensor,
    target: torch.Tensor,
    num_frames: int,
    radius: float,
    vertical: float,
    forward_amp: float,
) -> list[torch.Tensor]:
    """Generate a closed camera loop around the base pose, looking at ``target``.

    For frame ``i`` of ``n = max(2, num_frames)`` the camera is displaced from
    the base position by ``radius * sin(theta)`` along the base right axis,
    ``vertical * sin(2 * theta)`` along the base up axis and
    ``forward_amp * 0.5 * (1 - cos(theta))`` along the base forward axis,
    where ``theta = 2 * pi * i / n``. Each pose is re-oriented toward
    ``target`` with the base up axis as the up hint.
    """
    origin = base_c2w[:3, 3]
    right, up, forward = base_c2w[:3, 0], base_c2w[:3, 1], base_c2w[:3, 2]
    frame_count = max(2, int(num_frames))

    def _pose_at(step: int) -> torch.Tensor:
        theta = 2.0 * math.pi * float(step) / float(frame_count)
        sideways = right * (radius * math.sin(theta))
        lift = up * (vertical * math.sin(2.0 * theta))
        push = forward * (forward_amp * 0.5 * (1.0 - math.cos(theta)))
        return _look_at_c2w(origin + (sideways + lift + push), target, up)

    return [_pose_at(step) for step in range(frame_count)]
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def _build_swing_poses(
|
| 69 |
-
base_c2w: torch.Tensor,
|
| 70 |
-
num_frames: int,
|
| 71 |
-
radius: float,
|
| 72 |
-
forward_amp: float,
|
| 73 |
-
) -> list[torch.Tensor]:
|
| 74 |
-
base_pos = base_c2w[:3, 3]
|
| 75 |
-
right = base_c2w[:3, 0]
|
| 76 |
-
forward = base_c2w[:3, 2]
|
| 77 |
-
|
| 78 |
-
key_offsets = [
|
| 79 |
-
torch.zeros(3, device=base_pos.device, dtype=base_pos.dtype),
|
| 80 |
-
-right * radius,
|
| 81 |
-
right * radius,
|
| 82 |
-
forward * forward_amp,
|
| 83 |
-
torch.zeros(3, device=base_pos.device, dtype=base_pos.dtype),
|
| 84 |
-
]
|
| 85 |
-
|
| 86 |
-
poses: list[torch.Tensor] = []
|
| 87 |
-
seg_frames = max(1, int(num_frames) // (len(key_offsets) - 1))
|
| 88 |
-
for seg in range(len(key_offsets) - 1):
|
| 89 |
-
p0 = base_pos + key_offsets[seg]
|
| 90 |
-
p1 = base_pos + key_offsets[seg + 1]
|
| 91 |
-
for i in range(seg_frames):
|
| 92 |
-
alpha = 1.0 if seg_frames == 1 else float(i) / float(seg_frames - 1)
|
| 93 |
-
pos = (1.0 - alpha) * p0 + alpha * p1
|
| 94 |
-
pose = base_c2w.clone()
|
| 95 |
-
pose[:3, 3] = pos
|
| 96 |
-
if seg > 0 and i == 0:
|
| 97 |
-
continue
|
| 98 |
-
poses.append(pose)
|
| 99 |
-
return poses
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def _scale_intrinsics_for_render(
|
| 103 |
-
intrinsics: torch.Tensor,
|
| 104 |
-
src_h: int,
|
| 105 |
-
src_w: int,
|
| 106 |
-
dst_h: int,
|
| 107 |
-
dst_w: int,
|
| 108 |
-
) -> torch.Tensor:
|
| 109 |
-
scaled = intrinsics.clone()
|
| 110 |
-
sx = float(dst_w) / float(src_w)
|
| 111 |
-
sy = float(dst_h) / float(src_h)
|
| 112 |
-
scaled[0, 0] *= sx
|
| 113 |
-
scaled[1, 1] *= sy
|
| 114 |
-
scaled[0, 2] *= sx
|
| 115 |
-
scaled[1, 2] *= sy
|
| 116 |
-
return scaled
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def _render_gaussian_frame(
|
| 120 |
-
rasterization_fn,
|
| 121 |
-
means: torch.Tensor,
|
| 122 |
-
harmonics: torch.Tensor,
|
| 123 |
-
opacities: torch.Tensor,
|
| 124 |
-
scales: torch.Tensor,
|
| 125 |
-
rotations: torch.Tensor,
|
| 126 |
-
c2w: torch.Tensor,
|
| 127 |
-
intrinsics: torch.Tensor,
|
| 128 |
-
render_h: int,
|
| 129 |
-
render_w: int,
|
| 130 |
-
bg_color: tuple[float, float, float],
|
| 131 |
-
) -> np.ndarray:
|
| 132 |
-
xyzs = means.unsqueeze(0).float() # [1, N, 3]
|
| 133 |
-
opacitys = opacities.unsqueeze(0).float() # [1, N]
|
| 134 |
-
rotations_b = rotations.unsqueeze(0).float() # [1, N, 4]
|
| 135 |
-
scales_b = scales.unsqueeze(0).float() # [1, N, 3]
|
| 136 |
-
|
| 137 |
-
# [N, 3, d_sh] -> [1, N, d_sh, 3]
|
| 138 |
-
features = harmonics.unsqueeze(0).permute(0, 1, 3, 2).contiguous().float()
|
| 139 |
-
d_sh = features.shape[-2]
|
| 140 |
-
sh_degree = int(round(math.sqrt(float(d_sh)) - 1.0))
|
| 141 |
-
|
| 142 |
-
w2c = torch.linalg.inv(c2w).unsqueeze(0).unsqueeze(0).float() # [1, 1, 4, 4]
|
| 143 |
-
Ks = intrinsics.unsqueeze(0).unsqueeze(0).float() # [1, 1, 3, 3]
|
| 144 |
-
backgrounds = torch.tensor(bg_color, dtype=torch.float32, device=xyzs.device).view(1, 1, 3)
|
| 145 |
-
|
| 146 |
-
rendering, _, _ = rasterization_fn(
|
| 147 |
-
xyzs,
|
| 148 |
-
rotations_b,
|
| 149 |
-
scales_b,
|
| 150 |
-
opacitys,
|
| 151 |
-
features,
|
| 152 |
-
w2c,
|
| 153 |
-
Ks,
|
| 154 |
-
render_w,
|
| 155 |
-
render_h,
|
| 156 |
-
sh_degree=sh_degree,
|
| 157 |
-
render_mode="RGB+D",
|
| 158 |
-
packed=False,
|
| 159 |
-
backgrounds=backgrounds,
|
| 160 |
-
covars=None,
|
| 161 |
-
eps2d=1e-8,
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
rgb = rendering[0, 0, :, :, :3].clamp(0.0, 1.0)
|
| 165 |
-
return (rgb * 255.0).to(torch.uint8).cpu().numpy()
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
def _render_novel_video(
    means: torch.Tensor,
    harmonics: torch.Tensor,
    opacities: torch.Tensor,
    scales: torch.Tensor,
    rotations: torch.Tensor,
    base_c2w: torch.Tensor,
    intrinsics: torch.Tensor,
    render_h: int,
    render_w: int,
    video_path: str,
    trajectory: str,
    num_frames: int,
    fps: int,
    radius: float,
    vertical: float,
    forward_amp: float,
    bg_color: tuple[float, float, float],
) -> None:
    """Render the Gaussians along a camera trajectory and save the video.

    ``trajectory == "swing"`` selects the translate-only swing path; any other
    value selects the orbit path, which looks at the mean of ``means``. Frames
    are rendered one at a time and appended to an H.264 video at
    ``video_path`` (the parent directory is created if needed).

    Raises:
        RuntimeError: If ``gsplat`` is not installed, or if rendering or
            writing the video fails (the original error is chained).
    """
    try:
        from gsplat import rasterization as rasterization_fn
    except ImportError as exc:
        raise RuntimeError("Novel-view rendering requires gsplat. Please install gsplat first.") from exc

    target = means.mean(dim=0)
    poses = (
        _build_swing_poses(base_c2w, num_frames, radius, forward_amp)
        if trajectory == "swing"
        else _build_orbit_poses(base_c2w, target, num_frames, radius, vertical, forward_amp)
    )

    video_dir = os.path.dirname(video_path)
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    # Everything except the camera pose is shared by all frames.
    shared_frame_args = {
        "rasterization_fn": rasterization_fn,
        "means": means,
        "harmonics": harmonics,
        "opacities": opacities,
        "scales": scales,
        "rotations": rotations,
        "intrinsics": intrinsics,
        "render_h": render_h,
        "render_w": render_w,
        "bg_color": bg_color,
    }

    try:
        with imageio.get_writer(
            video_path,
            fps=float(max(1, fps)),
            codec="libx264",
            macro_block_size=1,
        ) as writer:
            for pose in poses:
                frame_rgb = _render_gaussian_frame(c2w=pose, **shared_frame_args)
                writer.append_data(frame_rgb)
    except Exception as exc:
        raise RuntimeError(f"Failed to write video with imageio: {video_path}") from exc
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def _build_sparse_uniform_gaussians(
    dense_gaussians,
    query_3d_uniform_coord: torch.Tensor,
    pred_depth_3d: torch.Tensor,
    intrinsics: torch.Tensor,
    extrinsics: torch.Tensor,
    h: int,
    w: int,
) -> Gaussians:
    """Convert dense per-pixel Gaussians into sparse Gaussians at query points.

    Means are obtained by unprojecting the query coordinates with the
    predicted depths and transforming them with ``extrinsics``. Harmonics,
    opacities, scales and rotations are bilinearly sampled from the dense
    ``h x w`` attribute maps at the same coordinates; rotations are
    re-normalized after interpolation.

    Args:
        dense_gaussians: Dense Gaussians with batch size 1.
        query_3d_uniform_coord: Normalized query coordinates, ordered
            ``[y, x]`` and indexed as ``[0]`` for the single batch item.
        pred_depth_3d: Predicted depth at the query points.
        intrinsics: Camera intrinsics passed to ``unproject``.
        extrinsics: Camera pose passed to ``transform_cam2world``.
        h: Height of the dense attribute maps.
        w: Width of the dense attribute maps.

    Returns:
        A ``Gaussians`` object without covariances.

    Raises:
        ValueError: If the batch size is not 1, or an attribute has an
            unsupported number of dimensions.
    """
    if dense_gaussians.means.shape[0] != 1:
        raise ValueError("Current strict-aligned sparse interpolation only supports batch size 1.")

    coords_yx = query_3d_uniform_coord[0]
    depths = pred_depth_3d[0]

    # Normalized [-1, 1] coordinates -> pixel coordinates (pixel centers at +0.5).
    pixel_y = ((coords_yx[:, 0] + 1.0) * (h / 2.0)) - 0.5
    pixel_x = ((coords_yx[:, 1] + 1.0) * (w / 2.0)) - 0.5
    pixel_xy = torch.stack([pixel_x, pixel_y], dim=-1)

    flat_depths = depths.squeeze(-1)
    cam_points = unproject(pixel_xy.unsqueeze(0), flat_depths.unsqueeze(0), intrinsics)[0]
    cam_points_hom = homogenize_points(cam_points)
    world_points = transform_cam2world(cam_points_hom.unsqueeze(0), extrinsics)[0]
    sparse_means = world_points[..., :3]

    # grid_sample expects (x, y) order, so swap the [y, x] query columns.
    grid = coords_yx[:, [1, 0]].unsqueeze(0).unsqueeze(0)

    def _bilinear(feature_map):
        return F.grid_sample(feature_map, grid, mode="bilinear", align_corners=False)

    def sample_attribute(attr):
        rank = attr.dim()
        if rank == 2:
            picked = _bilinear(attr.view(1, 1, h, w))
            return picked.squeeze(0).squeeze(0)
        if rank == 3:
            channels = attr.shape[-1]
            picked = _bilinear(attr.view(1, h, w, channels).permute(0, 3, 1, 2))
            return picked.squeeze(2).permute(0, 2, 1)
        if rank == 4:
            d1, d2 = attr.shape[-2:]
            # Flatten the two trailing dims into channels, sample, then restore.
            picked = _bilinear(attr.view(1, h, w, d1 * d2).permute(0, 3, 1, 2))
            return picked.squeeze(2).permute(0, 2, 1).view(1, -1, d1, d2)
        raise ValueError(f"Unsupported attribute dimension: {attr.dim()}")

    sparse_harmonics = sample_attribute(dense_gaussians.harmonics)
    sparse_opacities = sample_attribute(dense_gaussians.opacities)
    sparse_scales = sample_attribute(dense_gaussians.scales)
    sparse_rotations = sample_attribute(dense_gaussians.rotations)
    # Interpolated rotations are no longer unit length; re-normalize them.
    rotation_norms = torch.norm(sparse_rotations, dim=-1, keepdim=True) + 1e-8
    sparse_rotations = sparse_rotations / rotation_norms

    return Gaussians(
        means=sparse_means.unsqueeze(0),
        covariances=None,
        harmonics=sparse_harmonics,
        opacities=sparse_opacities,
        scales=sparse_scales,
        rotations=sparse_rotations,
    )
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
InfiniDepth/utils/inference_utils.py
CHANGED
|
@@ -6,8 +6,6 @@ import cv2
|
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
| 9 |
-
from InfiniDepth.gs import Gaussians
|
| 10 |
-
|
| 11 |
from .io_utils import load_depth
|
| 12 |
from .moge_utils import estimate_metric_depth_with_moge2
|
| 13 |
from .vis_utils import build_sky_model, run_skyseg
|
|
@@ -331,67 +329,4 @@ def build_camera_matrices(
|
|
| 331 |
device=device,
|
| 332 |
).unsqueeze(0).expand(batch, -1, -1)
|
| 333 |
extrinsics = torch.eye(4, dtype=torch.float32, device=device).unsqueeze(0).expand(batch, -1, -1)
|
| 334 |
-
return fx, fy, cx, cy, intrinsics, extrinsics
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
def filter_gaussians_by_depth_ratio(
    pixel_gaussians: Gaussians,
    extrinsics: torch.Tensor,
    keep_far_ratio: float,
) -> tuple[Gaussians, int, int, float, float]:
    """Drop Gaussians farther than a fraction of the maximum camera distance.

    Distances are measured from ``extrinsics[0, :3, 3]`` (used as the camera
    position -- presumably a camera-to-world matrix; confirm against the
    caller) to the means of the first batch item. Gaussians whose distance
    exceeds ``max_distance * keep_far_ratio`` are removed.

    Returns:
        ``(filtered_gaussians, num_filtered, num_kept, depth_threshold,
        max_depth)``. The filtered Gaussians carry no covariances.
    """
    cam_position = extrinsics[0, :3, 3]
    offsets = pixel_gaussians.means[0] - cam_position.unsqueeze(0)
    distances = torch.norm(offsets, dim=-1)

    farthest = distances.max()
    cutoff = farthest * keep_far_ratio
    keep = distances <= cutoff

    num_filtered = int((~keep).sum().item())
    num_kept = int(keep.sum().item())

    kept_gaussians = Gaussians(
        means=pixel_gaussians.means[:, keep, :],
        covariances=None,
        harmonics=pixel_gaussians.harmonics[:, keep, :, :],
        opacities=pixel_gaussians.opacities[:, keep],
        scales=pixel_gaussians.scales[:, keep, :],
        rotations=pixel_gaussians.rotations[:, keep, :],
    )
    return kept_gaussians, num_filtered, num_kept, float(cutoff.item()), float(farthest.item())
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
def filter_gaussians_by_min_opacity(pixel_gaussians: Gaussians, min_opacity: float) -> Gaussians:
    """Keep only Gaussians whose opacity is at least ``min_opacity``.

    A non-positive threshold disables filtering and returns the input object
    itself. Otherwise the selection is made from the opacities of the first
    batch item and the result carries no covariances.
    """
    if min_opacity <= 0.0:
        return pixel_gaussians

    selected = pixel_gaussians.opacities[0] >= min_opacity
    kept_fields = {
        "means": pixel_gaussians.means[:, selected, :],
        "covariances": None,
        "harmonics": pixel_gaussians.harmonics[:, selected, :, :],
        "opacities": pixel_gaussians.opacities[:, selected],
        "scales": pixel_gaussians.scales[:, selected, :],
        "rotations": pixel_gaussians.rotations[:, selected, :],
    }
    return Gaussians(**kept_fields)
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
def unpack_gaussians_for_export(
    pixel_gaussians: Gaussians,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the first batch item of each Gaussian attribute.

    Returns:
        ``(means, harmonics, opacities, scales, rotations)``, each indexed
        with ``[0]``.
    """
    means = pixel_gaussians.means[0]
    harmonics = pixel_gaussians.harmonics[0]
    opacities = pixel_gaussians.opacities[0]
    scales = pixel_gaussians.scales[0]
    rotations = pixel_gaussians.rotations[0]
    return means, harmonics, opacities, scales, rotations
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
def resolve_ply_output_path(
    input_image_path: str,
    model_type: str,
    output_ply_dir: Optional[str] = None,
    output_ply_name: Optional[str] = None,
) -> tuple[str, str]:
    """Choose (and create) the output directory and file path for a PLY export.

    Args:
        input_image_path: Input image path; its stem is used in the default
            file name.
        model_type: Prefix used in the default file name.
        output_ply_dir: Explicit output directory; when falsy the directory is
            derived via ``default_dir_by_input_file(input_image_path, "pred_gs")``.
        output_ply_name: Explicit file name; when falsy the name is
            ``"<model_type>_<stem>_gaussians.ply"``.

    Returns:
        ``(ply_dir, ply_path)``; ``ply_dir`` exists when the function returns.
    """
    if output_ply_dir:
        ply_dir = output_ply_dir
    else:
        ply_dir = default_dir_by_input_file(input_image_path, "pred_gs")
    os.makedirs(ply_dir, exist_ok=True)

    stem, _extension = os.path.splitext(os.path.basename(input_image_path))
    if output_ply_name:
        ply_name = output_ply_name
    else:
        ply_name = f"{model_type}_{stem}_gaussians.ply"
    return ply_dir, os.path.join(ply_dir, ply_name)
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from .io_utils import load_depth
|
| 10 |
from .moge_utils import estimate_metric_depth_with_moge2
|
| 11 |
from .vis_utils import build_sky_model, run_skyseg
|
|
|
|
| 329 |
device=device,
|
| 330 |
).unsqueeze(0).expand(batch, -1, -1)
|
| 331 |
extrinsics = torch.eye(4, dtype=torch.float32, device=device).unsqueeze(0).expand(batch, -1, -1)
|
| 332 |
+
return fx, fy, cx, cy, intrinsics, extrinsics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|