File size: 2,391 Bytes
909940e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include "gs.h"
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

// Fails with a TORCH_CHECK error unless the tensor `x` reports a CUDA device.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
// Fails with a TORCH_CHECK error unless the tensor `x` is contiguous.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both checks above. Wrapped in do/while(0) so the macro expands to a single
// statement and stays fully guarded when used as the body of an unbraced `if`/`else`.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

/**
 * Forward entry point exposed to Python: validates the tensors, then calls _gs_render.
 *
 * @param sigmas        input tensor, handed to _gs_render as `const float *`.
 * @param coords        input tensor, handed to _gs_render as `const float *`.
 * @param colors        input tensor, handed to _gs_render as `const float *`.
 * @param rendered_img  tensor handed to _gs_render as a writable `float *`.
 * @param s, h, w, c    integers forwarded unchanged to _gs_render.
 *                      NOTE(review): presumably sample count and image height/width/channels —
 *                      confirm against gs.h.
 * @param dmax          float forwarded unchanged to _gs_render (meaning defined by the kernel).
 *
 * Raises a TORCH_CHECK error if any tensor is not a CUDA tensor, is not contiguous,
 * is not float32, or lives on a different device than `sigmas`.
 * Tensor shapes are NOT validated here.
 */
void gs_render(
        torch::Tensor &sigmas,
        torch::Tensor &coords,
        torch::Tensor &colors,
        torch::Tensor &rendered_img,
        const int s,
        const int h,
        const int w,
        const int c,
        const float dmax
        ){

        CHECK_INPUT(sigmas);
        CHECK_INPUT(coords);
        CHECK_INPUT(colors);
        CHECK_INPUT(rendered_img);

        // The raw buffers are handed over as float pointers and the device guard below
        // follows `sigmas`, so reject any other dtype or any tensor on another device
        // instead of silently reinterpreting / dereferencing foreign memory.
        const auto require_float_on_sigmas_device =
                [&sigmas](const torch::Tensor &t, const char *name) {
                        TORCH_CHECK(t.scalar_type() == at::kFloat,
                                    name, " must be a float32 tensor");
                        TORCH_CHECK(t.device() == sigmas.device(),
                                    name, " must be on the same device as sigmas");
                };
        require_float_on_sigmas_device(sigmas, "sigmas");
        require_float_on_sigmas_device(coords, "coords");
        require_float_on_sigmas_device(colors, "colors");
        require_float_on_sigmas_device(rendered_img, "rendered_img");

        // Make the device that holds the inputs the current CUDA device for this scope.
        const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));

        // Typed accessors replace the C-style casts from void*.
        const float *sigmas_ptr = sigmas.data_ptr<float>();
        const float *coords_ptr = coords.data_ptr<float>();
        const float *colors_ptr = colors.data_ptr<float>();
        float *rendered_img_ptr = rendered_img.data_ptr<float>();

        _gs_render(
            sigmas_ptr,
            coords_ptr,
            colors_ptr,
            rendered_img_ptr,
            s, h, w, c, dmax);
}

/**
 * Backward entry point exposed to Python: validates the tensors, then calls
 * _gs_render_backward.
 *
 * @param sigmas        input tensor, handed to the launcher as `const float *`.
 * @param coords        input tensor, handed to the launcher as `const float *`.
 * @param colors        input tensor, handed to the launcher as `const float *`.
 * @param grads         input tensor, handed to the launcher as `const float *`.
 *                      NOTE(review): presumably the gradient w.r.t. the rendered image — confirm.
 * @param grads_sigmas  tensor handed to the launcher as a writable `float *`.
 * @param grads_coords  tensor handed to the launcher as a writable `float *`.
 * @param grads_colors  tensor handed to the launcher as a writable `float *`.
 * @param s, h, w, c    integers forwarded unchanged to _gs_render_backward.
 *                      NOTE(review): presumably sample count and image height/width/channels —
 *                      confirm against gs.h.
 * @param dmax          float forwarded unchanged to _gs_render_backward.
 *
 * Raises a TORCH_CHECK error if any tensor is not a CUDA tensor, is not contiguous,
 * is not float32, or lives on a different device than `sigmas`.
 * Tensor shapes are NOT validated here.
 */
void gs_render_backward(
        torch::Tensor &sigmas,
        torch::Tensor &coords,
        torch::Tensor &colors,
        torch::Tensor &grads,
        torch::Tensor &grads_sigmas,
        torch::Tensor &grads_coords,
        torch::Tensor &grads_colors,
        const int s,
        const int h,
        const int w,
        const int c,
        const float dmax
        ){

        CHECK_INPUT(sigmas);
        CHECK_INPUT(coords);
        CHECK_INPUT(colors);
        CHECK_INPUT(grads);
        CHECK_INPUT(grads_sigmas);
        CHECK_INPUT(grads_coords);
        CHECK_INPUT(grads_colors);

        // The raw buffers are handed over as float pointers and the device guard below
        // follows `sigmas`, so reject any other dtype or any tensor on another device
        // instead of silently reinterpreting / dereferencing foreign memory.
        const auto require_float_on_sigmas_device =
                [&sigmas](const torch::Tensor &t, const char *name) {
                        TORCH_CHECK(t.scalar_type() == at::kFloat,
                                    name, " must be a float32 tensor");
                        TORCH_CHECK(t.device() == sigmas.device(),
                                    name, " must be on the same device as sigmas");
                };
        require_float_on_sigmas_device(sigmas, "sigmas");
        require_float_on_sigmas_device(coords, "coords");
        require_float_on_sigmas_device(colors, "colors");
        require_float_on_sigmas_device(grads, "grads");
        require_float_on_sigmas_device(grads_sigmas, "grads_sigmas");
        require_float_on_sigmas_device(grads_coords, "grads_coords");
        require_float_on_sigmas_device(grads_colors, "grads_colors");

        // Make the device that holds the inputs the current CUDA device for this scope.
        const at::cuda::OptionalCUDAGuard device_guard(device_of(sigmas));

        // Typed accessors replace the C-style casts from void*.
        const float *sigmas_ptr = sigmas.data_ptr<float>();
        const float *coords_ptr = coords.data_ptr<float>();
        const float *colors_ptr = colors.data_ptr<float>();
        const float *grads_ptr = grads.data_ptr<float>();
        float *grads_sigmas_ptr = grads_sigmas.data_ptr<float>();
        float *grads_coords_ptr = grads_coords.data_ptr<float>();
        float *grads_colors_ptr = grads_colors.data_ptr<float>();

        _gs_render_backward(
            sigmas_ptr,
            coords_ptr,
            colors_ptr,
            grads_ptr,
            grads_sigmas_ptr,
            grads_coords_ptr,
            grads_colors_ptr,
            s, h, w, c, dmax);
}

// Python bindings: registers the two host-side wrappers defined in this file
// under the module name supplied by the build through TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        // Forward wrapper.
        m.def("gs_render", &gs_render, "cuda forward wrapper");

        // Backward wrapper.
        m.def("gs_render_backward", &gs_render_backward, "cuda backward wrapper");
}