File size: 6,583 Bytes
c8df52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import time
from xml.dom.minidom import Notation
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gsplat import rasterization

# torch.backends.cuda.preferred_linalg_library(backend="magma")

""""
modified from https://github.com/arthurhero/Long-LRM/blob/main/model/llrm.py
"""
class GaussianRendererWithCheckpoint(torch.autograd.Function):
    """Gradient-checkpointed Gaussian-splat renderer.

    Forward renders each view under ``torch.no_grad`` so no rasterization
    graph is retained; backward re-renders each view with grad enabled and
    accumulates parameter gradients immediately, keeping at most one
    per-view graph alive at a time. Modified from
    https://github.com/arthurhero/Long-LRM/blob/main/model/llrm.py
    """

    @staticmethod
    def render(xyz, feature, scale, rotation, opacity, test_c2w, test_intr,
               W, H, sh_degree, near_plane, far_plane, backgrounds):
        """Rasterize a single view.

        Args:
            xyz, feature, scale, rotation, opacity: per-Gaussian parameters.
            test_c2w: (4, 4) camera-to-world matrix for this view.
            test_intr: (4,) intrinsics packed as (fx, fy, cx, cy).
            W, H: output resolution.
            sh_degree: spherical-harmonics degree of ``feature``.
            near_plane, far_plane: clipping planes.
            backgrounds: (3,) background color for this view.

        Returns:
            (1, H, W, 4) tensor — RGB in channels 0-2, depth in channel 3
            (``render_mode="RGB+D"``).
        """
        test_w2c = test_c2w.float().inverse().unsqueeze(0)  # (1, 4, 4)
        # Assemble the (1, 3, 3) pinhole intrinsics matrix from (fx, fy, cx, cy).
        intr_mat = torch.zeros(3, 3, device=test_intr.device)
        intr_mat[0, 0] = test_intr[0]
        intr_mat[1, 1] = test_intr[1]
        intr_mat[0, 2] = test_intr[2]
        intr_mat[1, 2] = test_intr[3]
        intr_mat[2, 2] = 1
        intr_mat = intr_mat.unsqueeze(0)  # (1, 3, 3)
        rendering, alpha, _ = rasterization(xyz, rotation, scale, opacity, feature,
                                            test_w2c, intr_mat, W, H, sh_degree=sh_degree,
                                            near_plane=near_plane, far_plane=far_plane,
                                            render_mode="RGB+D",
                                            backgrounds=backgrounds[None],
                                            rasterize_mode='classic')  # (1, H, W, 4)
        # rendering[..., 3:] = rendering[..., 3:] + far_plane * (1 - alpha)
        return rendering

    @staticmethod
    def forward(ctx, xyz, feature, scale, rotation, opacity, test_c2ws, test_intr,
                W, H, sh_degree, near_plane, far_plane, backgrounds):
        """Render all V views without building an autograd graph."""
        ctx.save_for_backward(xyz, feature, scale, rotation, opacity,
                              test_c2ws, test_intr, backgrounds)
        ctx.W = W
        ctx.H = H
        ctx.sh_degree = sh_degree
        ctx.near_plane = near_plane
        ctx.far_plane = far_plane
        with torch.no_grad():
            V, _ = test_intr.shape
            # Allocate directly on the Gaussians' device (avoids a CPU
            # allocation followed by a copy).
            renderings = torch.zeros(V, H, W, 4, device=xyz.device)
            # One view at a time: no graph is retained under no_grad.
            for iv in range(V):
                renderings[iv:iv + 1] = GaussianRendererWithCheckpoint.render(
                    xyz, feature, scale, rotation, opacity,
                    test_c2ws[iv], test_intr[iv], W, H,
                    sh_degree, near_plane, far_plane, backgrounds[iv])

        # Mark the output as requiring grad so backward() gets invoked.
        renderings = renderings.requires_grad_()
        return renderings

    @staticmethod
    def backward(ctx, grad_output):
        """Re-render each view with grad enabled and accumulate grads."""
        xyz, feature, scale, rotation, opacity, test_c2ws, test_intr, backgrounds = ctx.saved_tensors
        # Detach and re-mark as leaves so each re-render builds a fresh
        # graph whose .grad accumulates across views.
        xyz = xyz.detach().requires_grad_()
        feature = feature.detach().requires_grad_()
        scale = scale.detach().requires_grad_()
        rotation = rotation.detach().requires_grad_()
        opacity = opacity.detach().requires_grad_()
        W = ctx.W
        H = ctx.H
        sh_degree = ctx.sh_degree
        near_plane = ctx.near_plane
        far_plane = ctx.far_plane
        with torch.enable_grad():
            V, _ = test_intr.shape
            for iv in range(V):
                # The per-view graph is freed right after backward(),
                # bounding peak memory to a single view.
                rendering = GaussianRendererWithCheckpoint.render(
                    xyz, feature, scale, rotation, opacity,
                    test_c2ws[iv], test_intr[iv], W, H,
                    sh_degree, near_plane, far_plane, backgrounds[iv])
                rendering.backward(grad_output[iv:iv + 1])

        # Gradients for the five Gaussian-parameter tensors; the remaining
        # eight forward() inputs (cameras, intrinsics, scalars, backgrounds)
        # receive None.
        return (xyz.grad, feature.grad, scale.grad, rotation.grad, opacity.grad,
                None, None, None, None, None, None, None, None)

def gaussian_render(gaussian_params, test_c2ws, test_intr, W, H, near_plane=0.01, far_plane=1000, use_checkpoint=False, sh_degree=0, bg_mode='random'):
    """Render batched Gaussian splats into RGB and depth images.

    Args:
        gaussian_params: (B, N, 11 + (sh_degree+1)^2 * 3) per-Gaussian
            parameters packed as xyz(3), opacity(1), scale(3), rotation(4),
            then SH color features.
        test_c2ws: (B, V, 4, 4) OpenGL-convention camera-to-world poses.
            Not modified (converted to COLMAP convention on a clone).
        test_intr: (B, V, 4) intrinsics packed as (fx, fy, cx, cy).
        W, H: output resolution.
        near_plane, far_plane: clipping planes.
        use_checkpoint: route through the gradient-checkpointed
            autograd.Function to bound rendering memory; forced off when
            grad is globally disabled (checkpointing would be pointless).
        sh_degree: spherical-harmonics degree of the color features.
        bg_mode: 'random' | 'white' | 'black' per-view background color.

    Returns:
        rgb: (B, V, 3, H, W) tensor in [-1, 1].
        depth: (B, V, 1, H, W) tensor.

    Raises:
        ValueError: if ``bg_mode`` is not one of the supported modes.
    """
    if not torch.is_grad_enabled():
        use_checkpoint = False

    # opengl2colmap, see https://github.com/imlixinyang/Director3D/blob/main/modules/renderers/gaussians_renderer.py
    # Clone before negating: the previous in-place `*= -1` mutated the
    # caller's tensor, so rendering the same cameras twice flipped the
    # convention back and forth.
    test_c2ws = test_c2ws.clone()
    test_c2ws[:, :, :3, 1:3] *= -1

    device = test_intr.device
    B, V, _ = test_intr.shape

    renderings = []

    for ib in range(B):
        if bg_mode == 'random':
            backgrounds = torch.rand(V, 3, device=device)
        elif bg_mode == 'white':
            backgrounds = torch.ones(V, 3, device=device)
        elif bg_mode == 'black':
            backgrounds = torch.zeros(V, 3, device=device)
        else:
            raise ValueError(f"Invalid background mode: {bg_mode}")

        # Unpack the flat parameter vector into its components.
        xyz_i, opacity_i, scale_i, rotation_i, feature_i = gaussian_params[ib].float().split(
            [3, 1, 3, 4, (sh_degree + 1) ** 2 * 3], dim=-1)

        opacity_i = opacity_i.squeeze(-1)
        feature_i = feature_i.reshape(-1, (sh_degree + 1) ** 2, 3)

        if use_checkpoint:
            renderings.append(GaussianRendererWithCheckpoint.apply(
                xyz_i, feature_i, scale_i, rotation_i, opacity_i,
                test_c2ws[ib], test_intr[ib], W, H,
                sh_degree, near_plane, far_plane, backgrounds))
        else:
            rendering = torch.zeros(V, H, W, 4, device=device)
            for iv in range(V):
                rendering[iv:iv + 1] = GaussianRendererWithCheckpoint.render(
                    xyz_i, feature_i, scale_i, rotation_i, opacity_i,
                    test_c2ws[ib][iv], test_intr[ib][iv], W, H,
                    sh_degree, near_plane, far_plane, backgrounds[iv])
            renderings.append(rendering)

    # (B, V, H, W, 4) -> (B, V, 4, H, W); channel axis moves to dim 2.
    renderings = torch.stack(renderings, dim=0).permute(0, 1, 4, 2, 3).contiguous()
    # Map RGB from [0, 1] to [-1, 1] in place (depth channel untouched).
    rgb = renderings[:, :, :3].mul_(2).add_(-1).clamp(-1, 1)
    depth = renderings[:, :, 3:]
    return rgb, depth