Lyra / src /rendering /gs_deferred_patch.py
Muhammad Taqi Raza
adding lyra files
af758d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.utils.checkpoint import _get_autocast_kwargs
from gsplat.rendering import rasterization
class DeferredBPPatch(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz, features, scaling, rotation, opacity, C2W, Ks, width, height, near_plane, far_plane, backgrounds, patch_size, raster_kwargs):
"""
Forward rendering with the addition of near_plane and far_plane.
"""
assert (xyz.dim() == 3) and (
features.dim() == 3
) and (scaling.dim() == 3) and (rotation.dim() == 3), f"xyz: {xyz.shape}, features: {features.shape}, scaling: {scaling.shape}, rotation: {rotation.shape}, opacity: {opacity.shape}"
assert height % patch_size[0] == 0 and width % patch_size[1] == 0, f'patch_size must be divisible by H ({height} / {patch_size[0]}) and W ({width} / {patch_size[1]})!'
ctx.save_for_backward(xyz, features, scaling, rotation, opacity) # save tensors for backward
ctx.height = height
ctx.width = width
ctx.C2W = C2W
ctx.Ks = Ks
ctx.patch_size = patch_size
ctx.backgrounds = backgrounds
ctx.near_plane = near_plane
ctx.far_plane = far_plane
ctx.raster_kwargs = raster_kwargs
ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
ctx.manual_seeds = []
with torch.no_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
device = C2W.device
b, v = C2W.shape[:2]
colors = torch.zeros(b, v, 3, height, width, device=device)
alphas = torch.zeros(b, v, 1, height, width, device=device)
depths = torch.zeros(b, v, 1, height, width, device=device) # We will store depth here
for i in range(b):
ctx.manual_seeds.append([])
for j in range(v):
Ks_ij = Ks[i, j]
fx, fy, cx, cy = Ks_ij[0, 0], Ks_ij[1, 1], Ks_ij[0, 2], Ks_ij[1, 2]
for m in range(0, ctx.width // ctx.patch_size[1]):
for n in range(0, ctx.height // ctx.patch_size[0]):
seed = torch.randint(0, 2**32, (1,)).long().item()
ctx.manual_seeds[-1].append(seed)
new_fx = fx
new_fy = fy
new_cx = cx - m * ctx.patch_size[1]
new_cy = cy - n * ctx.patch_size[0]
new_K = torch.tensor([[new_fx, 0., new_cx], [0., new_fy, new_cy], [0., 0., 1.]], dtype=torch.float32, device=device)
rgbd, alpha, _ = rasterization(
means=xyz[i],
quats=rotation[i],
scales=scaling[i],
opacities=opacity[i].squeeze(-1),
colors=features[i],
viewmats=C2W[i, j][None],
Ks=new_K[None],
width=ctx.patch_size[1],
height=ctx.patch_size[0],
near_plane=ctx.near_plane, # Use near_plane here
far_plane=ctx.far_plane, # Use far_plane here
backgrounds=ctx.backgrounds[i, j][None],
render_mode="RGB+ED", # RGB + Depth (last channel)
**raster_kwargs,
)
# Permute and clamp the rendered image and alpha
rendered_image = rgbd[0, :, :, :3].permute(2, 0, 1).clamp(0, 1) # (1, 3, H, W)
rendered_alpha = alpha[0].permute(2, 0, 1).clamp(0, 1) # (1, 1, H, W)
rendered_depth = rgbd[0, :, :, 3:].permute(2, 0, 1) # Depth is the last channel
# Store the results in the final output tensors
colors[i, j, :, n * ctx.patch_size[0]:(n + 1) * ctx.patch_size[0], m * ctx.patch_size[1]:(m + 1) * ctx.patch_size[1]] = rendered_image
alphas[i, j, :, n * ctx.patch_size[0]:(n + 1) * ctx.patch_size[0], m * ctx.patch_size[1]:(m + 1) * ctx.patch_size[1]] = rendered_alpha
depths[i, j, :, n * ctx.patch_size[0]:(n + 1) * ctx.patch_size[0], m * ctx.patch_size[1]:(m + 1) * ctx.patch_size[1]] = rendered_depth
return colors, alphas, depths
@staticmethod
def backward(ctx, grad_colors, grad_alphas, grad_depths):
"""
Backward process.
"""
xyz, features, scaling, rotation, opacity = ctx.saved_tensors
raster_kwargs = ctx.raster_kwargs
xyz_nosync = xyz.detach().clone()
xyz_nosync.requires_grad = True
xyz_nosync.grad = None
features_nosync = features.detach().clone()
features_nosync.requires_grad = True
features_nosync.grad = None
scaling_nosync = scaling.detach().clone()
scaling_nosync.requires_grad = True
scaling_nosync.grad = None
rotation_nosync = rotation.detach().clone()
rotation_nosync.requires_grad = True
rotation_nosync.grad = None
opacity_nosync = opacity.detach().clone()
opacity_nosync.requires_grad = True
opacity_nosync.grad = None
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
device = ctx.C2W.device
dtype = ctx.C2W.dtype
b, v = ctx.C2W.shape[:2]
for i in range(b):
ctx.manual_seeds.append([])
for j in range(v):
Ks_ij = ctx.Ks[i, j]
fx, fy, cx, cy = Ks_ij[0, 0], Ks_ij[1, 1], Ks_ij[0, 2], Ks_ij[1, 2]
for m in range(0, ctx.width // ctx.patch_size[1]):
for n in range(0, ctx.height // ctx.patch_size[0]):
grad_colors_split = grad_colors[i, j, :, n * ctx.patch_size[0]:(n + 1) * ctx.patch_size[0], m * ctx.patch_size[1]:(m + 1) * ctx.patch_size[1]]
grad_alphas_split = grad_alphas[i, j, :, n * ctx.patch_size[0]:(n + 1) * ctx.patch_size[0], m * ctx.patch_size[1]:(m + 1) * ctx.patch_size[1]]
grad_depths_split = grad_depths[i, j, :, n * ctx.patch_size[0]:(n + 1) * ctx.patch_size[0], m * ctx.patch_size[1]:(m + 1) * ctx.patch_size[1]]
seed = torch.randint(0, 2**32, (1,)).long().item()
ctx.manual_seeds[-1].append(seed)
new_fx = fx
new_fy = fy
new_cx = cx - m * ctx.patch_size[1]
new_cy = cy - n * ctx.patch_size[0]
new_K = torch.tensor([[new_fx, 0., new_cx], [0., new_fy, new_cy], [0., 0., 1.]], dtype=dtype, device=device)
rgbd, alpha, _ = rasterization(
means=xyz_nosync[i],
quats=rotation_nosync[i],
scales=scaling_nosync[i],
opacities=opacity_nosync[i].squeeze(-1),
colors=features_nosync[i],
viewmats=ctx.C2W[i, j][None],
Ks=new_K[None],
width=ctx.patch_size[1],
height=ctx.patch_size[0],
near_plane=ctx.near_plane,
far_plane=ctx.far_plane,
backgrounds=ctx.backgrounds[i, j][None],
render_mode="RGB+ED",
**raster_kwargs,
)
# Permute and clamp the rendered image and alpha
rendered_image = rgbd[0, :, :, :3].permute(2, 0, 1)
rendered_image = rendered_image.clamp(0, 1)
rendered_alpha = alpha[0].permute(2, 0, 1) #.clamp(0, 1)
rendered_depth = rgbd[0, :, :, 3:].permute(2, 0, 1)
# Concatenate rendered output and gradients
render_split = torch.cat([rendered_image, rendered_alpha, rendered_depth], dim=0) # (1, H, W, 5)
grad_split = torch.cat([grad_colors_split, grad_alphas_split, grad_depths_split], dim=0) # Same shape as render_split
render_split.backward(grad_split)
# Return the gradients for the inputs that were used in forward pass
return xyz_nosync.grad, features_nosync.grad, scaling_nosync.grad, rotation_nosync.grad, opacity_nosync.grad, None, None, None, None, None, None, None, None, None