File size: 5,978 Bytes
909940e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import cv2
import torch
import numpy as np
import torch.nn.functional as F

from gswrapper import gaussiansplatting_render

def generate_2D_gaussian_splatting(kernel_size, sigma_x, sigma_y, rho, coords, 
        colours, image_size=(256, 256, 3), device="cuda"):

    batch_size = colours.shape[0]

    sigma_x = sigma_x.view(batch_size, 1, 1)
    sigma_y = sigma_y.view(batch_size, 1, 1)
    rho = rho.view(batch_size, 1, 1)

    covariance = torch.stack(
        [torch.stack([sigma_x**2, rho*sigma_x*sigma_y], dim=-1),
        torch.stack([rho*sigma_x*sigma_y, sigma_y**2], dim=-1)],
        dim=-2
    )

    # Check for positive semi-definiteness
    # determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y)**2
    # if (determinant <= 0).any():
    #     raise ValueError("Covariance matrix must be positive semi-definite")

    inv_covariance = torch.inverse(covariance)

    # Choosing quite a broad range for the distribution [-5,5] to avoid any clipping
    start = torch.tensor([-5.0], device=device).view(-1, 1)
    end = torch.tensor([5.0], device=device).view(-1, 1)
    base_linspace = torch.linspace(0, 1, steps=kernel_size, device=device)
    ax_batch = start + (end - start) * base_linspace

    # Expanding dims for broadcasting
    ax_batch_expanded_x = ax_batch.unsqueeze(-1).expand(-1, -1, kernel_size)
    ax_batch_expanded_y = ax_batch.unsqueeze(1).expand(-1, kernel_size, -1)

    # Creating a batch-wise meshgrid using broadcasting
    xx, yy = ax_batch_expanded_x, ax_batch_expanded_y # (batchsize, kernelsize, kernelsize)

    xy = torch.stack([xx, yy], dim=-1) # (batchsize, kernelsize, kernelsize, 2)
    z = torch.einsum('b...i,b...ij,b...j->b...', xy, -0.5 * inv_covariance, xy) # (batchsize, kernelsize, kernelsize, 2)
    kernel = torch.exp(z) / (2 * torch.tensor(np.pi, device=device) * torch.sqrt(torch.det(covariance)).view(batch_size, 1, 1)) # (batchsize, kernelsize, kernelsize)


    kernel_max_1, _ = kernel.max(dim=-1, keepdim=True)  # Find max along the last dimension
    kernel_max_2, _ = kernel_max_1.max(dim=-2, keepdim=True)  # Find max along the second-to-last dimension
    kernel_normalized = kernel / kernel_max_2 # (batchsize, kernelsize, kernelsize)


    kernel_reshaped = kernel_normalized.repeat(1, 3, 1).view(batch_size * 3, kernel_size, kernel_size)
    kernel_rgb = kernel_reshaped.unsqueeze(0).reshape(batch_size, 3, kernel_size, kernel_size)  # (batchsize, 3, kernelsize, kernelsize)

    # Calculating the padding needed to match the image size
    pad_h = image_size[0] - kernel_size
    pad_w = image_size[1] - kernel_size

    if pad_h < 0 or pad_w < 0:
        raise ValueError("Kernel size should be smaller or equal to the image size.")

    # Adding padding to make kernel size equal to the image size
    padding = (pad_w // 2, pad_w // 2 + pad_w % 2,  # padding left and right
               pad_h // 2, pad_h // 2 + pad_h % 2)  # padding top and bottom

    kernel_rgb_padded = torch.nn.functional.pad(kernel_rgb, padding, "constant", 0) # (batchsize, 3, h, w)

    # Extracting shape information
    b, c, h, w = kernel_rgb_padded.shape

    # Create a batch of 2D affine matrices
    theta = torch.zeros(b, 2, 3, dtype=torch.float32, device=device)
    theta[:, 0, 0] = 1.0
    theta[:, 1, 1] = 1.0
    theta[:, :, 2] = -coords # (b, 2) - the offset of gaussian splating

    # Creating grid and performing grid sampling
    grid = F.affine_grid(theta, size=(b, c, h, w), align_corners=True) # (b, 3, h, w)
    # grid_y = torch.linspace(-1, 1, steps=h, device=device).reshape(1, h, 1, 1).repeat(1, 1, w, 1)
    # grid_x = torch.linspace(-1, 1, steps=w, device=device).reshape(1, 1, w, 1).repeat(1, h, 1, 1)
    # grid = torch.cat([grid_x, grid_y], dim=-1)
    # grid = grid - coords.reshape(-1, 1, 1, 2)

    kernel_rgb_padded_translated = F.grid_sample(kernel_rgb_padded, grid, align_corners=True) # (b, 3, h, w)

    rgb_values_reshaped = colours.unsqueeze(-1).unsqueeze(-1)

    final_image_layers = rgb_values_reshaped * kernel_rgb_padded_translated
    final_image = final_image_layers.sum(dim=0)
    # final_image = torch.clamp(final_image, 0, 1)
    final_image = final_image.permute(1,2,0)

    return final_image


if __name__ == "__main__":
    # Benchmark harness: renders the same batch of Gaussians with the pure-torch
    # implementation above and with the CUDA kernel from gswrapper, line-profiles
    # both, and dumps each result image for visual comparison.
    from mylineprofiler import MyLineProfiler
    profiler_th = MyLineProfiler(cuda_sync=True)
    generate_2D_gaussian_splatting = profiler_th.decorate(generate_2D_gaussian_splatting)
    profiler_cuda = MyLineProfiler(cuda_sync=True)
    gaussiansplatting_render = profiler_cuda.decorate(gaussiansplatting_render)


    # --- test ---
    # s: number of Gaussians to splat.
    # s = 1000
    s = 5
    # image_size = (512, 512, 3)
    image_size = (511, 511, 3)
    # image_size = (256, 512, 3)
    # image_size = (256, 256, 3)

    # sigmas columns: [:, 0] and [:, 1] are the two std-devs (scaled x5),
    # [:, 2] is the correlation rho, kept in (0, 1) so the covariance
    # stays invertible.
    sigmas = 0.999*torch.rand(s, 3).to(torch.float32).to("cuda")
    sigmas[:,:2] = 5*sigmas[:, :2]
    # coords: Gaussian centres in normalised [-1, 1] image coordinates.
    coords = 2*torch.rand(s, 2).to(torch.float32).to("cuda")-1.0
    colors = torch.rand(s, 3).to(torch.float32).to("cuda")

    # --- torch version ---
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    # 20 repeats so the per-line profile averages over warm-up effects.
    # NOTE(review): column 1 is passed as sigma_x and column 0 as sigma_y —
    # presumably an intentional axis swap; verify against the CUDA call below.
    for _ in range(20):
        img = generate_2D_gaussian_splatting(101, sigmas[:,1], sigmas[:,0], sigmas[:,2], coords, colors, image_size)
    profiler_th.print("profile.log", "w")
    cv2.imwrite("th.png", 255.0 * img.detach().clamp(0, 1).cpu().numpy())
    # --- ends ---

    # --- cuda version ---
    # Rescale the sigmas from the torch renderer's units (kernel grid spans
    # [-5, 5] over 101 samples) into the CUDA renderer's units (image spans
    # [-1, 1] per axis) so both produce comparable splats. This mutates
    # `sigmas` in place, so it must run AFTER the torch benchmark above.
    _stepsize_of_gs_th = 10 / (101-1)
    _stepsize_of_gs_cuda_w = 2 / (image_size[1]-1)
    _stepsize_of_gs_cuda_h = 2 / (image_size[0]-1)
    sigmas[:, 0] = sigmas[:, 0] * _stepsize_of_gs_cuda_w / _stepsize_of_gs_th
    sigmas[:, 1] = sigmas[:, 1] * _stepsize_of_gs_cuda_h / _stepsize_of_gs_th
    # dmax: cutoff radius matching the torch kernel's half-width of 101 samples,
    # expressed in the CUDA renderer's normalised units.
    dmax = 101/2*_stepsize_of_gs_cuda_w
    gc.collect()
    torch.cuda.empty_cache()
    for _ in range(20):
        img = gaussiansplatting_render(sigmas, coords, colors, image_size, dmax)

    # Append the CUDA profile to the same log written by the torch run.
    profiler_cuda.print("profile.log", "a")
    cv2.imwrite("cuda.png", 255.0 * img.detach().clamp(0, 1).cpu().numpy())
    # --- ends ---