File size: 4,387 Bytes
cda88e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <torch/extension.h>

#include <vector>
#include <stdio.h>

// CUDA forward declarations
std::vector<torch::Tensor> hshadow_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor mask_bb, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor light_pos);
std::vector<torch::Tensor> reflect_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor thresholds);
std::vector<torch::Tensor> glossy_reflect_render_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, const int sample_n, const float glossy);
torch::Tensor ray_intersect_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor rd_map);
torch::Tensor ray_scene_intersect_cuda_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor hmap, torch::Tensor ro, torch::Tensor rd, float dh);

// C++ interface
//
// Argument-validation helpers used by every binding below.
//
// TORCH_CHECK is PyTorch's macro for validating caller-supplied arguments: on
// failure it throws a c10::Error carrying the given message. The previous
// AT_ASSERTM is deprecated and, in current PyTorch, expands to
// TORCH_INTERNAL_ASSERT, whose error text blames an internal PyTorch bug,
// which is misleading for a user passing a CPU or non-contiguous tensor.
// Tensor::type() is deprecated as well; is_cuda() is queried on the tensor directly.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// do { } while (0) makes the two checks a single statement, so CHECK_INPUT is
// safe inside an unbraced if/else (the old form only guarded the first check).
#define CHECK_INPUT(x)        \
    do {                      \
        CHECK_CUDA(x);        \
        CHECK_CONTIGUOUS(x);  \
    } while (0)

/*  Heightmap Shadow Rendering

    Checks that every tensor argument is a contiguous CUDA tensor (in argument
    order; the first failing check throws) and then delegates to
    hshadow_render_cuda_forward, returning its outputs unchanged.

    rgb:        B x 3 x H x W

    mask:       B x 1 x H x W

    bb:         mask bounding box (declared as `mask_bb` in the CUDA entry
                point). Earlier notes list it as B x 1 — TODO confirm the
                actual shape against the kernel.

    hmap:       B x 1 x H x W

    rechmap:    B x 1 x H x W

    light_pos:  light position given as (x, y, h). Earlier notes list it as
                B x 1, but three components suggest B x 3 — TODO confirm
                against the kernel.
*/
std::vector<torch::Tensor> hshadow_render_forward(torch::Tensor rgb, torch::Tensor mask, torch::Tensor bb, torch::Tensor hmap, torch::Tensor rechmap, torch::Tensor light_pos) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(bb);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    CHECK_INPUT(light_pos);

    return hshadow_render_cuda_forward(rgb, mask, bb, hmap, rechmap, light_pos);
}

// Reflection rendering entry point.
// Each tensor argument is checked to be a contiguous CUDA tensor, in argument
// order, so the first offending tensor is the one reported. The outputs of
// reflect_render_cuda_forward are then returned unchanged.
std::vector<torch::Tensor> reflect_render_forward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor hmap,
    torch::Tensor rechmap,
    torch::Tensor thresholds) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    CHECK_INPUT(thresholds);

    std::vector<torch::Tensor> rendered =
        reflect_render_cuda_forward(rgb, mask, hmap, rechmap, thresholds);
    return rendered;
}


// Glossy reflection rendering entry point.
// Only the four tensor arguments are validated (contiguous CUDA tensors, in
// argument order). sample_n and glossy are forwarded as-is; no range check is
// done here. Returns the outputs of glossy_reflect_render_cuda_forward.
std::vector<torch::Tensor> glossy_reflect_render_forward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor hmap,
    torch::Tensor rechmap,
    int sample_n,
    float glossy) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);

    auto rendered = glossy_reflect_render_cuda_forward(
        rgb, mask, hmap, rechmap, sample_n, glossy);
    return rendered;
}


/*  Ray Intersection

    Checks that every tensor argument is a contiguous CUDA tensor (in argument
    order; the first failing check throws) and then delegates to
    ray_intersect_cuda_forward, returning its result unchanged.

    Shapes are not documented in this file. They are presumably the same
    B x C x H x W layout as the other entry points — TODO confirm against
    the kernel.

    The "foward" spelling is kept because PYBIND11_MODULE binds this symbol by
    that name.
*/
torch::Tensor ray_intersect_foward(torch::Tensor rgb,
                                   torch::Tensor mask,
                                   torch::Tensor hmap,
                                   torch::Tensor rechmap,
                                   torch::Tensor rd_map) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(rechmap);
    // rd_map goes to the CUDA implementation just like the other tensors, so
    // it gets the same device/contiguity check. Previously it was passed
    // through unvalidated.
    CHECK_INPUT(rd_map);

    return ray_intersect_cuda_forward(rgb, mask, hmap, rechmap, rd_map);
}

// Ray/scene intersection entry point.
// The scene tensors (rgb, mask, hmap) and the ray tensors (ro, rd) are each
// checked to be contiguous CUDA tensors, in argument order. dh is a plain
// scalar passed through unchanged. Returns the tensor produced by
// ray_scene_intersect_cuda_forward.
// NOTE(review): ro/rd presumably hold ray origins/directions — confirm shapes
// against the CUDA implementation.
torch::Tensor ray_scene_intersect_foward(
    torch::Tensor rgb,
    torch::Tensor mask,
    torch::Tensor hmap,
    torch::Tensor ro,
    torch::Tensor rd,
    float dh) {
    CHECK_INPUT(rgb);
    CHECK_INPUT(mask);
    CHECK_INPUT(hmap);
    CHECK_INPUT(ro);
    CHECK_INPUT(rd);

    torch::Tensor result =
        ray_scene_intersect_cuda_forward(rgb, mask, hmap, ro, rd, dh);
    return result;
}

// Python bindings. Each m.def maps a Python-visible name to one of the C++
// wrappers in this file. The third argument becomes that function's Python
// docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // The Python name "forward" refers to the heightmap shadow renderer. It is
    // part of the module's public interface, so it is not renamed.
    m.def("forward", &hshadow_render_forward, "Heightmap Shadow Rendering Forward (CUDA)");
    m.def("reflection", &reflect_render_forward, "Reflection Rendering Forward (CUDA)");
    m.def("glossy_reflection", &glossy_reflect_render_forward, "Glossy Reflection Rendering Forward (CUDA)");
    // These two used to share the docstring "Ray scene intersection", which
    // hid the difference between them in help(). They now get distinct
    // descriptions in the same style as the entries above.
    m.def("ray_intersect", &ray_intersect_foward, "Ray Intersection Forward (CUDA)");
    m.def("ray_scene_intersect", &ray_scene_intersect_foward, "Ray Scene Intersection Forward (CUDA)");
}