|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from gsplat.rendering import rasterization |
|
|
import kiui |
|
|
import torch.nn.functional as F |
|
|
import einops |
|
|
|
|
|
from src.models.utils.render import downscale_intrinsics |
|
|
from src.rendering.gs_deferred_patch import DeferredBPPatch |
|
|
|
|
|
class DeferredBP(torch.autograd.Function):
    """Memory-efficient deferred backprop through gsplat rasterization.

    ``forward`` renders every (batch, view) pair under ``torch.no_grad`` so no
    rasterization autograd graph is kept alive.  ``backward`` re-renders each
    view with grad enabled and immediately back-propagates the incoming
    gradient for that view, so at most one view's graph exists at a time —
    trading extra compute for a much smaller peak memory footprint.
    """

    @staticmethod
    def render(xyz, feature, scale, rotation, opacity, test_w2c, test_intr,
               W, H, near_plane, far_plane, backgrounds, raster_kwargs):
        """Rasterize one set of Gaussians for the given camera(s).

        Returns ``(image, alpha, depth)``.  ``render_mode="RGB+ED"`` makes the
        rasterizer pack expected depth into the 4th output channel, which is
        split off here.
        """
        rgbd, alpha, _ = rasterization(
            means=xyz,
            quats=rotation,
            scales=scale,
            opacities=opacity,
            colors=feature,
            viewmats=test_w2c,
            Ks=test_intr,
            width=W,
            height=H,
            near_plane=near_plane,
            far_plane=far_plane,
            backgrounds=backgrounds,
            render_mode="RGB+ED",
            **raster_kwargs,
        )
        # Split RGB and expected-depth channels.
        image, depth = rgbd[..., :3], rgbd[..., 3:]
        return image, alpha, depth

    @staticmethod
    def forward(ctx, xyz, feature, scale, rotation, opacity, test_w2cs, test_intr,
                W, H, near_plane, far_plane, backgrounds, raster_kwargs):
        """Render all views without building a graph; save inputs for backward."""
        ctx.save_for_backward(xyz, feature, scale, rotation, opacity, test_w2cs, test_intr, backgrounds)
        ctx.W = W
        ctx.H = H
        ctx.near_plane = near_plane
        ctx.far_plane = far_plane
        ctx.raster_kwargs = raster_kwargs
        with torch.no_grad():
            B, V = test_intr.shape[:2]
            # Allocate output buffers directly on the target device instead of
            # creating CPU tensors and copying them over with ``.to()``.
            images = torch.zeros(B, V, H, W, 3, device=xyz.device)
            alphas = torch.zeros(B, V, H, W, 1, device=xyz.device)
            depths = torch.zeros(B, V, H, W, 1, device=xyz.device)
            # Render one view at a time so no rasterization graph survives.
            for ib in range(B):
                for iv in range(V):
                    image, alpha, depth = DeferredBP.render(
                        xyz[ib], feature[ib], scale[ib], rotation[ib], opacity[ib],
                        test_w2cs[ib, iv:iv+1], test_intr[ib, iv:iv+1],
                        W, H, near_plane, far_plane, backgrounds[ib, iv:iv+1],
                        raster_kwargs
                    )
                    images[ib, iv:iv+1] = image
                    alphas[ib, iv:iv+1] = alpha
                    depths[ib, iv:iv+1] = depth
        # Mark outputs as requiring grad so downstream losses backprop into
        # this Function even though they were produced under no_grad.
        images = images.requires_grad_()
        alphas = alphas.requires_grad_()
        depths = depths.requires_grad_()
        return images, alphas, depths

    @staticmethod
    def backward(ctx, images_grad, alphas_grad, depths_grad):
        """Re-render each view with grad enabled and accumulate input grads."""
        xyz, feature, scale, rotation, opacity, test_w2cs, test_intr, backgrounds = ctx.saved_tensors
        # Detach and re-mark as leaves so gradients from the re-rendered
        # graphs accumulate into ``.grad`` fields we can return below.
        xyz = xyz.detach().requires_grad_()
        feature = feature.detach().requires_grad_()
        scale = scale.detach().requires_grad_()
        rotation = rotation.detach().requires_grad_()
        opacity = opacity.detach().requires_grad_()
        W = ctx.W
        H = ctx.H
        near_plane = ctx.near_plane
        far_plane = ctx.far_plane
        raster_kwargs = ctx.raster_kwargs
        with torch.enable_grad():
            B, V = test_intr.shape[:2]
            for ib in range(B):
                for iv in range(V):
                    # Re-render a single view and backprop through it right
                    # away, so only one view's graph is alive at any time.
                    image, alpha, depth = DeferredBP.render(
                        xyz[ib], feature[ib], scale[ib], rotation[ib], opacity[ib],
                        test_w2cs[ib, iv:iv+1], test_intr[ib, iv:iv+1],
                        W, H, near_plane, far_plane, backgrounds[ib, iv:iv+1],
                        raster_kwargs,
                    )
                    # Concatenate the three outputs (and their incoming grads)
                    # so a single backward() call covers image, alpha, depth.
                    render_split = torch.cat([image, alpha, depth], dim=-1)
                    grad_split = torch.cat([images_grad[ib, iv:iv+1], alphas_grad[ib, iv:iv+1], depths_grad[ib, iv:iv+1]], dim=-1)
                    render_split.backward(grad_split)

        # One slot per forward() argument; non-tensor args get None.
        return xyz.grad, feature.grad, scale.grad, rotation.grad, opacity.grad, None, None, None, None, None, None, None, None
|
|
|
|
|
class GaussianRendererDeferred:
    """Batched Gaussian-splat renderer built on gsplat's ``rasterization``.

    Offers a standard (full autograd graph) path and a memory-saving deferred
    backprop path, plus an automatic fallback that retries rendering at
    progressively lower resolution when CUDA runs out of memory.
    """

    def __init__(self, opt):
        self.opt = opt
        # Deferred backprop re-renders views inside backward() to keep peak
        # memory low; otherwise keep the whole rasterization graph.
        if self.opt.deferred_bp:
            self.render_func = self.render_deferred
        else:
            self.render_func = self.render_standard
        # Resolution divisors tried in order when rendering hits CUDA OOM.
        self.oom_downscale_factors = [1, 2, 4, 8]
        self.use_3dgut = self.opt.get('use_3dgut', False)
        if self.use_3dgut:
            # 3DGUT path: unscented-transform projection + exact 3D evaluation.
            self.raster_kwargs = {'with_ut': True, 'with_eval3d': True, 'packed': False}
        else:
            self.raster_kwargs = {'with_ut': False, 'with_eval3d': False, 'packed': False}

    def render(self, gaussians, cam_view, bg_color=None, intrinsics=None, patch_size=None):
        """Render packed Gaussians for B batches x V views.

        Args:
            gaussians: [B, N, 11+C] per-point attributes laid out as
                [xyz(3) | opacity(1) | scale(3) | quat(4) | rgb features(C)].
            cam_view: [B, V, 4, 4] camera matrices (transposed w2c; see below).
            bg_color: optional background color broadcast to every view.
            intrinsics: per-view (fx, fy, cx, cy) tuples, shape [B, V, 4].
            patch_size: optional patch size for patch-based deferred backprop.

        Returns:
            dict with 'images_pred', 'alphas_pred', 'depths_pred',
            each shaped [B, V, C, H, W].
        """
        B, V = cam_view.shape[:2]

        # Unpack per-point attributes from the flat gaussian tensor.
        means3D = gaussians[..., 0:3].contiguous().float()
        opacity = gaussians[..., 3:4].contiguous().float().squeeze(-1)
        scales = gaussians[..., 4:7].contiguous().float()
        rotations = gaussians[..., 7:11].contiguous().float()
        rgbs = gaussians[..., 11:].contiguous().float()

        # gsplat expects world-to-camera matrices; cam_view stores the
        # transpose, so flip the last two dims.
        viewmat = cam_view.float().transpose(3, 2)
        # Assemble 3x3 intrinsic matrices from per-view (fx, fy, cx, cy).
        Ks = torch.tensor([[[[view_intrinsic[0], 0., view_intrinsic[2]], [0., view_intrinsic[1], view_intrinsic[3]], [0., 0., 1.]] for view_intrinsic in batch_intrinsic] for batch_intrinsic in intrinsics], dtype=means3D.dtype, device=means3D.device)
        # Per-(batch, view) background color; defaults to white.
        backgrounds = torch.tensor([[bg_color for _ in range(V)] for _ in range(B)], dtype=means3D.dtype, device=means3D.device) if bg_color is not None else torch.ones(B, V, 3, dtype=means3D.dtype, device=means3D.device)

        H, W = self.opt.img_size
        near_plane, far_plane = self.opt.znear, self.opt.zfar

        # BUGFIX: the previous loop never used ``downscale_factor`` and never
        # caught OOM — it re-rendered at full resolution on every iteration
        # and returned the last (identical) result.  Now: try full resolution
        # first, and on CUDA OOM retry at progressively coarser resolutions
        # via render_downscale (which upsamples outputs back to (H, W)).
        out_dict = None
        for factor_idx, downscale_factor in enumerate(self.oom_downscale_factors):
            try:
                if downscale_factor == 1:
                    out_dict = self.render_func(means3D, opacity, scales, rotations, rgbs, viewmat, Ks, backgrounds, H, W, near_plane, far_plane, patch_size)
                else:
                    out_dict = self.render_downscale(means3D, opacity, scales, rotations, rgbs, viewmat, Ks, backgrounds, H, W, near_plane, far_plane, patch_size, B, downscale_factor)
                break
            except torch.cuda.OutOfMemoryError:
                # Free cached blocks before retrying at lower resolution.
                torch.cuda.empty_cache()
                if factor_idx == len(self.oom_downscale_factors) - 1:
                    # Even the coarsest factor failed; propagate the error.
                    raise
        return out_dict

    def render_downscale(self, means3D, opacity, scales, rotations, rgbs, viewmat, Ks, backgrounds, H, W, near_plane, far_plane, patch_size, B, downscale_factor):
        """Render at (H, W) / downscale_factor and upsample back to (H, W).

        Used as an OOM fallback: intrinsics are scaled to match the reduced
        resolution, and the rendered maps are nearest-upsampled afterwards.
        """
        print(f"Cuda Error for rendering on {means3D.device}! Switch to {downscale_factor}x low res")
        Ks_resized = downscale_intrinsics(Ks.clone(), factor=downscale_factor)
        H_resized, W_resized = H // downscale_factor, W // downscale_factor
        out_dict = self.render_func(means3D, opacity, scales, rotations, rgbs, viewmat, Ks_resized, backgrounds, H_resized, W_resized, near_plane, far_plane, patch_size)
        for k in ["images_pred", "alphas_pred", "depths_pred"]:
            # F.interpolate wants 4D input, so fold the view dim into batch.
            out_dict[k] = einops.rearrange(out_dict[k], 'b v c h w -> (b v) c h w')
            out_dict[k] = F.interpolate(out_dict[k], size=(H, W), mode='nearest')
            out_dict[k] = einops.rearrange(out_dict[k], '(b v) c h w -> b v c h w', b=B)
        return out_dict

    def render_deferred(self, means3D, opacity, scales, rotations, rgbs, viewmat, Ks, backgrounds, H, W, near_plane, far_plane, patch_size=None):
        """Render through the deferred-backprop autograd Functions.

        Without ``patch_size`` the whole image is rendered per view
        (DeferredBP, which returns channel-last tensors that are permuted to
        [B, V, C, H, W] here); with ``patch_size`` rendering is tiled into
        patches (DeferredBPPatch, which already returns channel-first).
        """
        if patch_size is None:
            images, alphas, depths = DeferredBP.apply(
                means3D, rgbs, scales, rotations, opacity,
                viewmat, Ks, W, H, near_plane, far_plane,
                backgrounds, self.raster_kwargs,
            )
            return {
                # [B, V, H, W, C] -> [B, V, C, H, W]
                "images_pred": images.permute(0, 1, 4, 2, 3),
                "alphas_pred": alphas.permute(0, 1, 4, 2, 3),
                "depths_pred": depths.permute(0, 1, 4, 2, 3),
            }
        else:
            images, alphas, depths = DeferredBPPatch.apply(
                means3D, rgbs, scales, rotations, opacity,
                viewmat, Ks, W, H, near_plane, far_plane,
                backgrounds, patch_size, self.raster_kwargs,
            )
            return {
                "images_pred": images,
                "alphas_pred": alphas,
                "depths_pred": depths,
            }

    def render_standard(self, means3D, opacity, scales, rotations, rgbs, viewmat, Ks, backgrounds, H, W, near_plane, far_plane, patch_size=None):
        """Render with the full autograd graph (no deferral).

        ``patch_size`` is accepted for signature parity with render_deferred
        but unused.  Returns the same dict layout as render_deferred.
        """
        B, V = Ks.shape[:2]

        images, alphas, depths = [], [], []
        for b in range(B):
            # One rasterization call covers all V views of this batch item.
            rendered_image_all, rendered_alpha_all, _ = rasterization(
                means=means3D[b],
                quats=rotations[b],
                scales=scales[b],
                opacities=opacity[b],
                colors=rgbs[b],
                viewmats=viewmat[b],
                Ks=Ks[b],
                width=W,
                height=H,
                near_plane=near_plane,
                far_plane=far_plane,
                backgrounds=backgrounds[b],
                render_mode="RGB+ED",
                **self.raster_kwargs,
            )
            for rendered_image, rendered_alpha in zip(rendered_image_all, rendered_alpha_all):
                # "RGB+ED" packs expected depth as the 4th channel.
                depths.append(rendered_image[..., 3:].permute(2, 0, 1))
                images.append(rendered_image[..., :3].permute(2, 0, 1))
                alphas.append(rendered_alpha.permute(2, 0, 1))

        images, alphas, depths = torch.stack(images), torch.stack(alphas), torch.stack(depths)
        # Un-flatten the (B*V) stacking back into [B, V, C, H, W].
        images, alphas, depths = images.view(B, V, *images.shape[1:]), alphas.view(B, V, *alphas.shape[1:]), depths.view(B, V, *depths.shape[1:])

        return {
            "images_pred": images,
            "alphas_pred": alphas,
            "depths_pred": depths,
        }

    def save_ply(self, gaussians, path, compatible=True):
        """Write Gaussians [1, N, 14] to a 3DGS-style PLY file.

        Near-transparent points (opacity < 0.005) are dropped.  With
        ``compatible=True`` the activations are inverted (inverse sigmoid on
        opacity, log on scales, SH DC de-normalization) so standard 3DGS
        viewers can load the file.
        """
        assert gaussians.shape[0] == 1, 'only support batch size 1'

        from plyfile import PlyData, PlyElement

        means3D = gaussians[0, :, 0:3].contiguous().float()
        opacity = gaussians[0, :, 3:4].contiguous().float()
        scales = gaussians[0, :, 4:7].contiguous().float()
        rotations = gaussians[0, :, 7:11].contiguous().float()
        shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float()

        # Prune points that contribute almost nothing.
        mask = opacity.squeeze(-1) >= 0.005
        means3D = means3D[mask]
        opacity = opacity[mask]
        scales = scales[mask]
        rotations = rotations[mask]
        shs = shs[mask]

        if compatible:
            # Invert activations so stock 3DGS tooling re-applies them on load.
            opacity = kiui.op.inverse_sigmoid(opacity)
            scales = torch.log(scales + 1e-8)
            shs = (shs - 0.5) / 0.28209479177387814

        xyzs = means3D.detach().cpu().numpy()
        f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = opacity.detach().cpu().numpy()
        scales = scales.detach().cpu().numpy()
        rotations = rotations.detach().cpu().numpy()

        # Build the per-vertex property list in 3DGS attribute order.
        attr_names = ['x', 'y', 'z']
        for i in range(f_dc.shape[1]):
            attr_names.append('f_dc_{}'.format(i))
        attr_names.append('opacity')
        for i in range(scales.shape[1]):
            attr_names.append('scale_{}'.format(i))
        for i in range(rotations.shape[1]):
            attr_names.append('rot_{}'.format(i))

        dtype_full = [(attribute, 'f4') for attribute in attr_names]

        elements = np.empty(xyzs.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')

        PlyData([el]).write(path)

    def load_ply(self, path, compatible=True):
        """Load a 3DGS-style PLY file back into a flat gaussian tensor.

        Returns [N, 14] float tensor laid out as
        [xyz(3) | opacity(1) | scale(3) | quat(4) | rgb(3)].  With
        ``compatible=True`` the stored raw values are re-activated
        (sigmoid, exp, SH DC normalization) — the inverse of save_ply.
        """
        from plyfile import PlyData, PlyElement

        plydata = PlyData.read(path)

        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])), axis=1)
        print("Number of points at loading : ", xyz.shape[0])

        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]

        shs = np.zeros((xyz.shape[0], 3))
        shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
        shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])

        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])

        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

        gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
        gaussians = torch.from_numpy(gaussians).float()

        if compatible:
            # Re-apply the activations save_ply inverted.
            gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
            gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
            gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5

        return gaussians
|
|
|