# File size: 9,821 Bytes
# af758d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.utils.checkpoint import _get_autocast_kwargs
from gsplat.rendering import rasterization

class DeferredBPPatch(torch.autograd.Function):
    """Patch-wise Gaussian-splatting renderer with deferred back-propagation.

    ``forward`` rasterizes every (batch, view) image tile by tile under
    ``torch.no_grad()``, so no autograd graph is kept. ``backward`` re-renders
    each tile with gradients enabled and immediately back-propagates that
    tile's slice of the upstream gradient, so only one tile's graph is alive
    at a time. Both passes go through the same tiling / intrinsics /
    rasterization helpers so the recomputation matches the forward render.
    """

    @staticmethod
    def _iter_patches(height, width, patch_size):
        """Yield ``(x0, y0, rows, cols)`` for every tile of a ``height x width`` image.

        ``patch_size`` is indexed as ``(patch_h, patch_w)``. ``x0``/``y0`` are the
        tile's top-left pixel offsets and ``rows``/``cols`` the matching slices.
        The outer loop runs over tile columns, the inner loop over tile rows.
        """
        patch_h, patch_w = patch_size[0], patch_size[1]
        for m in range(width // patch_w):
            x0 = m * patch_w
            for n in range(height // patch_h):
                y0 = n * patch_h
                yield x0, y0, slice(y0, y0 + patch_h), slice(x0, x0 + patch_w)

    @staticmethod
    def _patch_intrinsics(K, x0, y0, device):
        """Return float32 3x3 intrinsics for the tile whose top-left pixel is ``(x0, y0)``.

        Only fx = K[0, 0], fy = K[1, 1], cx = K[0, 2] and cy = K[1, 2] are read;
        the principal point is shifted by the tile offset and every other entry
        is set to the canonical ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]`` layout.
        The result is detached from ``K`` so no gradient flows to the intrinsics,
        and it is filled by tensor indexing rather than ``torch.tensor([...])``,
        which would convert each 0-d tensor entry to a Python scalar.
        """
        if isinstance(K, torch.Tensor):
            K = K.detach()
        patch_K = torch.zeros(3, 3, dtype=torch.float32, device=device)
        patch_K[0, 0] = K[0, 0]
        patch_K[1, 1] = K[1, 1]
        patch_K[0, 2] = K[0, 2] - x0
        patch_K[1, 2] = K[1, 2] - y0
        patch_K[2, 2] = 1.0
        return patch_K

    @staticmethod
    def _render_patch(xyz, features, scaling, rotation, opacity, viewmat, K, background,
                      patch_size, near_plane, far_plane, raster_kwargs):
        """Rasterize one tile of one view and return channels-first ``(image, alpha, depth)``.

        ``image`` is clamped to [0, 1]; ``alpha`` and ``depth`` are returned
        unclamped so each caller decides how to treat them.
        """
        rgbd, alpha_map, _ = rasterization(
            means=xyz,
            quats=rotation,
            scales=scaling,
            opacities=opacity.squeeze(-1),
            colors=features,
            # NOTE(review): gsplat documents `viewmats` as world-to-camera; the
            # matrix is passed through unchanged, so callers presumably supply
            # W2C despite the `C2W` argument name — confirm.
            viewmats=viewmat[None],
            Ks=K[None],
            width=patch_size[1],
            height=patch_size[0],
            near_plane=near_plane,
            far_plane=far_plane,
            backgrounds=None if background is None else background[None],
            render_mode="RGB+ED",  # RGB + depth in the last channel
            **raster_kwargs,
        )
        image = rgbd[0, :, :, :3].permute(2, 0, 1).clamp(0, 1)  # (3, h, w)
        alpha = alpha_map[0].permute(2, 0, 1)  # (1, h, w)
        depth = rgbd[0, :, :, 3:].permute(2, 0, 1)  # (1, h, w)
        return image, alpha, depth

    @staticmethod
    def forward(ctx, xyz, features, scaling, rotation, opacity, C2W, Ks, width, height, near_plane, far_plane, backgrounds, patch_size, raster_kwargs):
        """Render all views tile by tile without keeping an autograd graph.

        Args:
            xyz, features, scaling, rotation, opacity: Gaussian parameters,
                indexed by batch first (the first four must be 3-D).
            C2W: per-(batch, view) view matrices, indexed ``C2W[i, j]`` and fed
                to gsplat as ``viewmats`` (see the NOTE in ``_render_patch``).
            Ks: per-(batch, view) intrinsics; only fx, fy, cx, cy are used.
            width, height: full image size in pixels; must be divisible by
                ``patch_size[1]`` and ``patch_size[0]`` respectively.
            near_plane, far_plane: clipping planes forwarded to gsplat.
            backgrounds: per-(batch, view) backgrounds indexed ``[i, j]``, or ``None``.
            patch_size: tile size indexed as ``(patch_h, patch_w)``.
            raster_kwargs: extra keyword arguments for ``rasterization``, or ``None``.

        Returns:
            ``(colors, alphas, depths)`` of shapes ``(b, v, 3, H, W)``,
            ``(b, v, 1, H, W)`` and ``(b, v, 1, H, W)``; colors and alphas are
            clamped to [0, 1].
        """
        assert (xyz.dim() == 3) and (features.dim() == 3) and (scaling.dim() == 3) and (rotation.dim() == 3), (
            f"xyz: {xyz.shape}, features: {features.shape}, scaling: {scaling.shape}, rotation: {rotation.shape}, opacity: {opacity.shape}"
        )
        assert height % patch_size[0] == 0 and width % patch_size[1] == 0, (
            f"H ({height}) and W ({width}) must be divisible by patch_size ({patch_size[0]}, {patch_size[1]})!"
        )
        if raster_kwargs is None:
            raster_kwargs = {}

        # Only the Gaussian parameters receive gradients; the remaining inputs
        # are kept on ctx as-is for the recomputation in backward.
        ctx.save_for_backward(xyz, features, scaling, rotation, opacity)
        ctx.height = height
        ctx.width = width
        ctx.C2W = C2W
        ctx.Ks = Ks
        ctx.patch_size = patch_size
        ctx.backgrounds = backgrounds
        ctx.near_plane = near_plane
        ctx.far_plane = far_plane
        ctx.raster_kwargs = raster_kwargs
        # Capture the caller's autocast state so backward re-renders under the same settings.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()

        device = C2W.device
        b, v = C2W.shape[:2]
        with torch.no_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
            colors = torch.zeros(b, v, 3, height, width, device=device)
            alphas = torch.zeros(b, v, 1, height, width, device=device)
            depths = torch.zeros(b, v, 1, height, width, device=device)

            for i in range(b):
                for j in range(v):
                    background = None if backgrounds is None else backgrounds[i, j]
                    for x0, y0, rows, cols in DeferredBPPatch._iter_patches(height, width, patch_size):
                        K = DeferredBPPatch._patch_intrinsics(Ks[i, j], x0, y0, device)
                        image, alpha, depth = DeferredBPPatch._render_patch(
                            xyz[i], features[i], scaling[i], rotation[i], opacity[i],
                            C2W[i, j], K, background, patch_size, near_plane, far_plane, raster_kwargs,
                        )
                        colors[i, j, :, rows, cols] = image
                        alphas[i, j, :, rows, cols] = alpha.clamp(0, 1)
                        depths[i, j, :, rows, cols] = depth

        return colors, alphas, depths

    @staticmethod
    def backward(ctx, grad_colors, grad_alphas, grad_depths):
        """Re-render each tile with gradients enabled and back-propagate its gradient slice.

        Returns gradients for the five Gaussian parameter tensors and ``None``
        for the remaining nine ``forward`` inputs.
        """
        # Fresh leaves sharing storage with the saved tensors; every tile's
        # gradient accumulates into their ``.grad``. No clone is needed because
        # nothing below modifies them in place.
        xyz, features, scaling, rotation, opacity = (
            t.detach().requires_grad_(True) for t in ctx.saved_tensors
        )
        # Detach the non-differentiable inputs so the per-tile ``.backward()``
        # below cannot push gradients into the caller's graph or ``.grad`` fields.
        C2W = ctx.C2W.detach()
        backgrounds = ctx.backgrounds
        if isinstance(backgrounds, torch.Tensor):
            backgrounds = backgrounds.detach()

        device = C2W.device
        b, v = C2W.shape[:2]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
            for i in range(b):
                for j in range(v):
                    background = None if backgrounds is None else backgrounds[i, j]
                    for x0, y0, rows, cols in DeferredBPPatch._iter_patches(ctx.height, ctx.width, ctx.patch_size):
                        K = DeferredBPPatch._patch_intrinsics(ctx.Ks[i, j], x0, y0, device)
                        image, alpha, depth = DeferredBPPatch._render_patch(
                            xyz[i], features[i], scaling[i], rotation[i], opacity[i],
                            C2W[i, j], K, background, ctx.patch_size, ctx.near_plane, ctx.far_plane, ctx.raster_kwargs,
                        )
                        # Unlike forward, alpha is not clamped here, so alpha values
                        # outside [0, 1] (if any) still receive gradient.
                        rendered = torch.cat([image, alpha, depth], dim=0)  # (5, h, w)
                        upstream = torch.cat([
                            grad_colors[i, j, :, rows, cols],
                            grad_alphas[i, j, :, rows, cols],
                            grad_depths[i, j, :, rows, cols],
                        ], dim=0)  # same shape as `rendered`
                        # Releases this tile's graph before the next tile is rendered.
                        rendered.backward(upstream)

        # One entry per forward input: 5 Gaussian grads, then 9 non-differentiable inputs.
        return xyz.grad, features.grad, scaling.grad, rotation.grad, opacity.grad, None, None, None, None, None, None, None, None, None