|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
from math import isqrt |
|
|
from typing import Literal, Optional |
|
|
|
|
|
import torch |
|
|
from einops import rearrange, repeat |
|
|
from tqdm import tqdm |
|
|
|
|
|
from depth_anything_3.specs import Gaussians |
|
|
from depth_anything_3.utils.camera_trj_helpers import ( |
|
|
interpolate_extrinsics, |
|
|
interpolate_intrinsics, |
|
|
render_dolly_zoom_path, |
|
|
render_stabilization_path, |
|
|
render_wander_path, |
|
|
render_wobble_inter_path, |
|
|
) |
|
|
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov |
|
|
from depth_anything_3.utils.logger import logger |
|
|
|
|
|
try:
    from gsplat import rasterization
except ImportError:
    # Import failure is only warned about here; rendering will fail with a
    # NameError on `rasterization` if it is actually invoked without gsplat.
    # `warning` is the canonical spelling; `warn` is a deprecated alias on
    # stdlib-style loggers.
    logger.warning(
        "Dependency `gsplat` is required for rendering 3DGS. "
        "Install via: pip install git+https://github.com/nerfstudio-project/"
        "gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"
    )
|
|
|
|
|
|
|
|
def render_3dgs(
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
    gaussian: Gaussians,
    background_color: Optional[torch.Tensor] = None,
    use_sh: bool = True,
    num_view: int = 1,
    color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D",
    **kwargs,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
]:
    """Rasterize 3D Gaussians into RGB images and depth maps via gsplat.

    The batch dimension of `extrinsics` is interpreted as
    (num_scenes * num_view), with the views of each scene contiguous; each
    scene is rasterized in one gsplat call over its `num_view` cameras.

    Args:
        extrinsics: World-to-camera view matrices, shape (b, 4, 4) with
            b = num_scenes * num_view.
        intrinsics: Normalized intrinsics; only used to derive per-view FoV
            (and thus pixel focal lengths) via `get_fov`.
        image_shape: Output (height, width) in pixels.
        gaussian: Per-scene Gaussian primitives (means, scales, rotations,
            opacities, harmonics).
        background_color: Optional (b, 3) background; defaults to black.
        use_sh: If True, `gaussian.harmonics` holds SH coefficients with
            (degree + 1)^2 terms per channel; otherwise raw pre-activation
            colors squashed by a sigmoid.
        num_view: Number of views per scene within the batch dimension.
        color_mode: gsplat render mode; the last output channel carries the
            (expected) depth in both supported modes.
        **kwargs: Accepted for caller convenience; unused here.

    Returns:
        Tuple of (images, depths) with shapes (b, 3, h, w) and (b, h, w).
    """
    gaussian_means = gaussian.means
    gaussian_scales = gaussian.scales
    gaussian_quats = gaussian.rotations
    gaussian_opacities = gaussian.opacities
    gaussian_sh_coefficients = gaussian.harmonics
    b, _, _ = extrinsics.shape

    if background_color is None:
        # Black background, matched to the harmonics' device/dtype.
        background_color = repeat(torch.tensor([0.0, 0.0, 0.0]), "c -> b c", b=b).to(
            gaussian_sh_coefficients
        )

    degree = None
    if use_sh:
        _, _, _, n = gaussian_sh_coefficients.shape
        # n = (degree + 1)^2 SH coefficients per channel.
        degree = isqrt(n) - 1
        shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
    else:
        # Raw (pre-activation) colors: squash into [0, 1].
        shs = gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous()

    h, w = image_shape

    # Recover pixel-space focal lengths from the field of view.
    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()
    focal_length_x = w / (2 * tan_fov_x)
    focal_length_y = h / (2 * tan_fov_y)

    view_matrix = extrinsics.float()

    all_images = []
    all_depths = []

    batch_scene = b // num_view

    for i in range(batch_scene):
        # Pixel-space K for every view of scene i; principal point at center.
        K = repeat(
            torch.tensor(
                [
                    [0, 0, w / 2.0],
                    [0, 0, h / 2.0],
                    [0, 0, 1],
                ]
            ),
            "i j -> v i j",
            v=num_view,
        ).to(gaussian_means)
        K[:, 0, 0] = focal_length_x.reshape(batch_scene, num_view)[i]
        K[:, 1, 1] = focal_length_y.reshape(batch_scene, num_view)[i]

        # Select scene i's primitives and its num_view cameras/backgrounds.
        i_means = gaussian_means[i]
        i_scales = gaussian_scales[i]
        i_quats = gaussian_quats[i]
        i_opacities = gaussian_opacities[i]
        i_colors = shs[i]
        i_viewmats = rearrange(view_matrix, "(b v) ... -> b v ...", v=num_view)[i]
        i_backgrounds = rearrange(background_color, "(b v) ... -> b v ...", v=num_view)[
            i
        ]

        render_colors, _, info = rasterization(
            means=i_means,
            quats=i_quats,
            scales=i_scales,
            opacities=i_opacities,
            colors=i_colors,
            viewmats=i_viewmats,
            Ks=K,
            backgrounds=i_backgrounds,
            render_mode=color_mode,
            width=w,
            height=h,
            packed=False,
            sh_degree=degree if use_sh else None,
        )
        # Last channel is depth in both "RGB+D" and "RGB+ED" modes.
        depth = render_colors[..., -1].unbind(dim=0)

        image = rearrange(render_colors[..., :3], "v h w c -> v c h w").unbind(dim=0)
        try:
            # Best-effort: keep grads on projected 2D means — presumably for
            # densification heuristics upstream (TODO confirm). Harmless to
            # skip when autograd is disabled.
            info["means2d"].retain_grad()
        except Exception:
            pass
        all_images.extend(image)
        all_depths.extend(depth)

    return torch.stack(all_images), torch.stack(all_depths)
|
|
|
|
|
|
|
|
def run_renderer_in_chunk_w_trj_mode(
    gaussians: Gaussians,
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
    chunk_size: Optional[int] = 8,
    trj_mode: Literal[
        "original",
        "smooth",
        "interpolate",
        "interpolate_smooth",
        "wander",
        "dolly_zoom",
        "extend",
        "wobble_inter",
    ] = "smooth",
    input_shape: Optional[tuple[int, int]] = None,
    enable_tqdm: Optional[bool] = False,
    **kwargs,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
]:
    """Render novel views along a camera trajectory, in memory-bounded chunks.

    Builds a target camera path from the input cameras according to
    `trj_mode`, then invokes `render_3dgs` on at most `chunk_size` views per
    call to bound peak memory.

    Args:
        gaussians: Scene Gaussians, forwarded to `render_3dgs`.
        extrinsics: World-to-camera matrices, shape (b, v, ...), inverted to
            camera-to-world internally.
        intrinsics: Intrinsics in `input_shape` pixel units; normalized by
            the input resolution before trajectory synthesis.
        image_shape: Output (height, width) for rendering.
        chunk_size: Views rendered per rasterization call; None renders all
            target views in a single call.
        trj_mode: Trajectory strategy. "wander"/"dolly_zoom" synthesize a
            path from the first view (required when v == 1); the interpolate
            variants resample between input views; "extend" additionally
            inserts wander and dolly-zoom segments at the path midpoint
            (batch size 1 only); "wobble_inter" delegates to
            `render_wobble_inter_path`.
        input_shape: (height, width) the intrinsics are expressed in;
            defaults to `image_shape`.
        enable_tqdm: Show a progress bar over chunks.
        **kwargs: Forwarded to `render_3dgs` (e.g. `background_color`).

    Returns:
        Tuple of (colors, depths) with shapes (b, n_frames, 3, h, w) and
        (b, n_frames, h, w).

    Raises:
        NotImplementedError: If `trj_mode` is not one of the supported modes.
    """
    cam2world = affine_inverse(as_homogeneous(extrinsics))
    if input_shape is not None:
        in_h, in_w = input_shape
    else:
        in_h, in_w = image_shape
    # Normalize intrinsics by the input resolution so they are
    # resolution-independent for path synthesis.
    intr_normed = intrinsics.clone().detach()
    intr_normed[..., 0, :] /= in_w
    intr_normed[..., 1, :] /= in_h
    if extrinsics.shape[1] <= 1:
        assert trj_mode in [
            "wander",
            "dolly_zoom",
        ], "Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1"

    def _smooth_trj_fn_batch(raw_c2ws, k_size=50):
        # Stabilize each batch element's path; fall back to the raw path on
        # any failure (best-effort smoothing).
        try:
            smooth_c2ws = torch.stack(
                [render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws],
                dim=0,
            )
        except Exception as e:
            # Use the module logger rather than bare print.
            logger.warning(f"Path smoothing failed with error: {e}.")
            smooth_c2ws = raw_c2ws
        return smooth_c2ws

    if trj_mode == "original":
        tgt_c2w = cam2world
        tgt_intr = intr_normed
    elif trj_mode == "smooth":
        tgt_c2w = _smooth_trj_fn_batch(cam2world)
        tgt_intr = intr_normed
    elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]:
        # Aim for 8 in-between frames per segment, clamped so the total frame
        # count stays roughly between 24*2 and 24*18 (about 2-18 s at 24 fps).
        inter_len = 8
        total_len = (cam2world.shape[1] - 1) * inter_len
        if total_len > 24 * 18:
            inter_len = max(1, 24 * 10 // (cam2world.shape[1] - 1))
        if total_len < 24 * 2:
            inter_len = max(1, 24 * 2 // (cam2world.shape[1] - 1))

        if inter_len > 2:
            t = torch.linspace(0, 1, inter_len, dtype=torch.float32, device=cam2world.device)
            # Cosine easing: maps linear t to an ease-in/ease-out ramp in [0, 1].
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
            tgt_c2w_b = []
            tgt_intr_b = []
            for b_idx in range(cam2world.shape[0]):
                tgt_c2w = []
                tgt_intr = []
                for cur_idx in range(cam2world.shape[1] - 1):
                    # Skip each segment's first sample (except the very first
                    # segment) so consecutive segments join without duplicates.
                    tgt_c2w.append(
                        interpolate_extrinsics(
                            cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t
                        )[(0 if cur_idx == 0 else 1) :]
                    )
                    tgt_intr.append(
                        interpolate_intrinsics(
                            intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t
                        )[(0 if cur_idx == 0 else 1) :]
                    )
                tgt_c2w_b.append(torch.cat(tgt_c2w))
                tgt_intr_b.append(torch.cat(tgt_intr))
            tgt_c2w = torch.stack(tgt_c2w_b)
            tgt_intr = torch.stack(tgt_intr_b)
        else:
            # Too few frames to interpolate meaningfully; keep input cameras.
            tgt_c2w = cam2world
            tgt_intr = intr_normed
        if trj_mode in ["interpolate_smooth", "extend"]:
            tgt_c2w = _smooth_trj_fn_batch(tgt_c2w)
        if trj_mode == "extend":
            assert cam2world.shape[0] == 1, "extend only supports for batch_size=1 currently."
            # Splice a wander segment and a dolly-zoom segment into the path
            # at its midpoint camera.
            mid_idx = tgt_c2w.shape[1] // 2
            c2w_wd, intr_wd = render_wander_path(
                tgt_c2w[0, mid_idx],
                tgt_intr[0, mid_idx],
                h=in_h,
                w=in_w,
                num_frames=max(36, min(60, mid_idx // 2)),
                max_disp=24.0,
            )
            c2w_dz, intr_dz = render_dolly_zoom_path(
                tgt_c2w[0, mid_idx],
                tgt_intr[0, mid_idx],
                h=in_h,
                w=in_w,
                num_frames=max(36, min(60, mid_idx // 2)),
            )
            tgt_c2w = torch.cat(
                [
                    tgt_c2w[:, :mid_idx],
                    c2w_wd.unsqueeze(0),
                    c2w_dz.unsqueeze(0),
                    tgt_c2w[:, mid_idx:],
                ],
                dim=1,
            )
            tgt_intr = torch.cat(
                [
                    tgt_intr[:, :mid_idx],
                    intr_wd.unsqueeze(0),
                    intr_dz.unsqueeze(0),
                    tgt_intr[:, mid_idx:],
                ],
                dim=1,
            )
    elif trj_mode in ["wander", "dolly_zoom"]:
        if trj_mode == "wander":
            render_fn = render_wander_path
            extra_kwargs = {"max_disp": 24.0}
        else:
            render_fn = render_dolly_zoom_path
            extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0}
        # Synthesize a path around each batch element's first view.
        tgt_c2w = []
        tgt_intr = []
        for b_idx in range(cam2world.shape[0]):
            c2w_i, intr_i = render_fn(
                cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs
            )
            tgt_c2w.append(c2w_i)
            tgt_intr.append(intr_i)
        tgt_c2w = torch.stack(tgt_c2w)
        tgt_intr = torch.stack(tgt_intr)
    elif trj_mode == "wobble_inter":
        tgt_c2w, tgt_intr = render_wobble_inter_path(
            cam2world=cam2world,
            intr_normed=intr_normed,
            inter_len=10,
            n_skip=3,
        )
    else:
        # NotImplementedError subclasses Exception, so existing broad handlers
        # still catch it.
        raise NotImplementedError(f"trj mode [{trj_mode}] is not implemented.")

    _, v = tgt_c2w.shape[:2]
    # Back to world-to-camera for the rasterizer.
    tgt_extr = affine_inverse(tgt_c2w)
    if chunk_size is None:
        chunk_size = v
    chunk_size = min(v, chunk_size)
    all_colors = []
    all_depths = []
    for chunk_idx in tqdm(
        range(math.ceil(v / chunk_size)),
        desc="Rendering novel views",
        disable=(not enable_tqdm),
        leave=False,
    ):
        s = int(chunk_idx * chunk_size)
        e = int((chunk_idx + 1) * chunk_size)
        # Last chunk may be short; slicing clamps automatically.
        cur_n_view = tgt_extr[:, s:e].shape[1]
        color, depth = render_3dgs(
            extrinsics=rearrange(tgt_extr[:, s:e], "b v ... -> (b v) ..."),
            intrinsics=rearrange(tgt_intr[:, s:e], "b v ... -> (b v) ..."),
            image_shape=image_shape,
            gaussian=gaussians,
            num_view=cur_n_view,
            **kwargs,
        )
        all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view))
        all_depths.append(rearrange(depth, "(b v) ... -> b v ...", v=cur_n_view))
    all_colors = torch.cat(all_colors, dim=1)
    all_depths = torch.cat(all_depths, dim=1)

    return all_colors, all_depths
|
|
|