Spaces:
Paused
Paused
File size: 7,763 Bytes
b177539 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
// Each *_cuda function below is only declared here; its definition lives in a
// separate CUDA translation unit that is not part of this file. The C++
// wrappers further down validate their tensor arguments and then call these.

// Ray sampling: ray/bbox intersection range, per-ray sample counts, ray start
// points/directions, and point sampling (regular, NDC, background).
std::vector<torch::Tensor> infer_t_minmax_cuda(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far);
torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist);
std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min);
std::vector<torch::Tensor> sample_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d,
torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far, const float stepdist);
std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d,
torch::Tensor xyz_min, torch::Tensor xyz_max,
const int N_samples);
torch::Tensor sample_bg_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
const float bg_preserve, const int N_samples);
// Mask-cache lookup of `world` at points `xyz`, using the xyz->ijk affine
// scale/shift tensors.
torch::Tensor maskcache_lookup_cuda(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift);
// Raw density -> alpha, forward and backward. The *_nonuni variants take the
// step interval as a tensor instead of a single float.
std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval);
std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval);
torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, const float interval);
torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp, torch::Tensor grad_back, torch::Tensor interval);
// Per-point alpha -> accumulated blending weight, forward and backward.
std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays);
torch::Tensor alpha2weight_backward_cuda(
torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
torch::Tensor grad_weights, torch::Tensor grad_last);
// C++ interface
// Argument-validation helpers used by every wrapper below. Each one fails
// through TORCH_CHECK (which throws) rather than aborting the process.
// Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so `CHECK_INPUT(x);` expands to exactly one statement
// and stays correct inside an unbraced if/else.
#define CHECK_INPUT(x)     \
  do {                     \
    CHECK_CUDA(x);         \
    CHECK_CONTIGUOUS(x);   \
  } while (0)
std::vector<torch::Tensor> infer_t_minmax(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far) {
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(xyz_min);
CHECK_INPUT(xyz_max);
return infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
}
torch::Tensor infer_n_samples(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
return infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
}
std::vector<torch::Tensor> infer_ray_start_dir(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
return infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
}
// Samples points on rays; forwards every argument unchanged to
// sample_pts_on_rays_cuda after validation.
//
// Requirements: all tensors are contiguous CUDA tensors, and rays_o has shape
// (N, 3).
//
// The shape checks use TORCH_CHECK instead of assert(): assert() is compiled
// out when NDEBUG is defined (common for extension release builds), letting a
// malformed tensor reach the kernel. When it is active, assert() aborts the
// whole Python process instead of raising a catchable error.
std::vector<torch::Tensor> sample_pts_on_rays(
    torch::Tensor rays_o, torch::Tensor rays_d,
    torch::Tensor xyz_min, torch::Tensor xyz_max,
    const float near, const float far, const float stepdist) {
  CHECK_INPUT(rays_o);
  CHECK_INPUT(rays_d);
  CHECK_INPUT(xyz_min);
  CHECK_INPUT(xyz_max);
  TORCH_CHECK(rays_o.dim() == 2, "rays_o must be a 2D tensor, got ", rays_o.dim(), "D");
  TORCH_CHECK(rays_o.size(1) == 3, "rays_o must have shape (N, 3), got ", rays_o.sizes());
  return sample_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far, stepdist);
}
// NDC-space counterpart of sample_pts_on_rays: takes a fixed sample count
// N_samples instead of near/far/stepdist. Forwards to
// sample_ndc_pts_on_rays_cuda after validation.
//
// Requirements: all tensors are contiguous CUDA tensors, and rays_o has shape
// (N, 3).
//
// TORCH_CHECK is used instead of assert() so the shape checks survive NDEBUG
// builds and raise an error rather than aborting the interpreter.
std::vector<torch::Tensor> sample_ndc_pts_on_rays(
    torch::Tensor rays_o, torch::Tensor rays_d,
    torch::Tensor xyz_min, torch::Tensor xyz_max,
    const int N_samples) {
  CHECK_INPUT(rays_o);
  CHECK_INPUT(rays_d);
  CHECK_INPUT(xyz_min);
  CHECK_INPUT(xyz_max);
  TORCH_CHECK(rays_o.dim() == 2, "rays_o must be a 2D tensor, got ", rays_o.dim(), "D");
  TORCH_CHECK(rays_o.size(1) == 3, "rays_o must have shape (N, 3), got ", rays_o.sizes());
  return sample_ndc_pts_on_rays_cuda(rays_o, rays_d, xyz_min, xyz_max, N_samples);
}
torch::Tensor sample_bg_pts_on_rays(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
const float bg_preserve, const int N_samples) {
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_max);
return sample_bg_pts_on_rays_cuda(rays_o, rays_d, t_max, bg_preserve, N_samples);
}
// Mask-cache lookup used to skip known free space: samples the `world` grid at
// the points `xyz`, using the xyz->ijk affine transform given by xyz2ijk_scale
// and xyz2ijk_shift. Forwards to maskcache_lookup_cuda after validation.
//
// Requirements: all tensors are contiguous CUDA tensors, `world` is 3D, and
// `xyz` has shape (N, 3).
//
// TORCH_CHECK replaces assert(), which is a no-op under NDEBUG and otherwise
// aborts the process instead of raising.
torch::Tensor maskcache_lookup(torch::Tensor world, torch::Tensor xyz, torch::Tensor xyz2ijk_scale, torch::Tensor xyz2ijk_shift) {
  CHECK_INPUT(world);
  CHECK_INPUT(xyz);
  CHECK_INPUT(xyz2ijk_scale);
  CHECK_INPUT(xyz2ijk_shift);
  TORCH_CHECK(world.dim() == 3, "world must be a 3D tensor, got ", world.dim(), "D");
  TORCH_CHECK(xyz.dim() == 2, "xyz must be a 2D tensor, got ", xyz.dim(), "D");
  TORCH_CHECK(xyz.size(1) == 3, "xyz must have shape (N, 3), got ", xyz.sizes());
  return maskcache_lookup_cuda(world, xyz, xyz2ijk_scale, xyz2ijk_shift);
}
// Converts raw density values ([-inf, inf]) to alpha ([0, 1]) using a single
// scalar step `interval` and a `shift`. Forwards to raw2alpha_cuda.
//
// Requirements: `density` is a contiguous, 1D CUDA tensor. The rank check uses
// TORCH_CHECK rather than assert() so it is still enforced under NDEBUG and
// raises instead of aborting.
std::vector<torch::Tensor> raw2alpha(torch::Tensor density, const float shift, const float interval) {
  CHECK_INPUT(density);
  TORCH_CHECK(density.dim() == 1, "density must be a 1D tensor, got ", density.dim(), "D");
  return raw2alpha_cuda(density, shift, interval);
}
// Non-uniform counterpart of raw2alpha: the step interval is given as a tensor
// instead of a single float. Forwards to raw2alpha_nonuni_cuda.
//
// Requirements: `density` is a contiguous, 1D CUDA tensor and `interval` is a
// contiguous CUDA tensor.
//
// Notes:
//  - `interval` is handed to the CUDA implementation just like `density`, but
//    it was previously the only forwarded tensor in this file without a
//    CHECK_INPUT, so a CPU or strided tensor could reach the kernel.
//    NOTE(review): assumes the kernel reads `interval` through a flat device
//    pointer; confirm against raw2alpha_nonuni_cuda.
//  - The rank check uses TORCH_CHECK rather than assert() so it survives NDEBUG
//    builds and raises instead of aborting.
std::vector<torch::Tensor> raw2alpha_nonuni(torch::Tensor density, const float shift, torch::Tensor interval) {
  CHECK_INPUT(density);
  CHECK_INPUT(interval);
  TORCH_CHECK(density.dim() == 1, "density must be a 1D tensor, got ", density.dim(), "D");
  return raw2alpha_nonuni_cuda(density, shift, interval);
}
torch::Tensor raw2alpha_backward(torch::Tensor exp, torch::Tensor grad_back, const float interval) {
CHECK_INPUT(exp);
CHECK_INPUT(grad_back);
return raw2alpha_backward_cuda(exp, grad_back, interval);
}
// Backward pass of raw2alpha_nonuni, where the step interval is a tensor.
// Forwards to raw2alpha_nonuni_backward_cuda.
//
// Requirements: `exp`, `grad_back` and `interval` are contiguous CUDA tensors.
// `interval` now gets the same CHECK_INPUT as the other tensor arguments; it was
// previously forwarded to the CUDA implementation unchecked.
// NOTE(review): assumes the kernel reads `interval` through a flat device
// pointer; confirm against raw2alpha_nonuni_backward_cuda.
torch::Tensor raw2alpha_nonuni_backward(torch::Tensor exp, torch::Tensor grad_back, torch::Tensor interval) {
  CHECK_INPUT(exp);
  CHECK_INPUT(grad_back);
  CHECK_INPUT(interval);
  return raw2alpha_nonuni_backward_cuda(exp, grad_back, interval);
}
// Converts per-point alpha to accumulated blending weights. `ray_id` gives the
// ray each point belongs to, and `n_rays` is the number of rays. Forwards to
// alpha2weight_cuda.
//
// Requirements: `alpha` and `ray_id` are contiguous, 1D CUDA tensors of the
// same shape. The shape checks use TORCH_CHECK rather than assert() so they are
// enforced under NDEBUG and raise instead of aborting the process.
std::vector<torch::Tensor> alpha2weight(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
  CHECK_INPUT(alpha);
  CHECK_INPUT(ray_id);
  TORCH_CHECK(alpha.dim() == 1, "alpha must be a 1D tensor, got ", alpha.dim(), "D");
  TORCH_CHECK(ray_id.dim() == 1, "ray_id must be a 1D tensor, got ", ray_id.dim(), "D");
  TORCH_CHECK(alpha.sizes() == ray_id.sizes(),
              "alpha and ray_id must have the same shape, got ",
              alpha.sizes(), " and ", ray_id.sizes());
  return alpha2weight_cuda(alpha, ray_id, n_rays);
}
torch::Tensor alpha2weight_backward(
torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
torch::Tensor grad_weights, torch::Tensor grad_last) {
CHECK_INPUT(alpha);
CHECK_INPUT(weight);
CHECK_INPUT(T);
CHECK_INPUT(alphainv_last);
CHECK_INPUT(i_start);
CHECK_INPUT(i_end);
CHECK_INPUT(grad_weights);
CHECK_INPUT(grad_last);
return alpha2weight_backward_cuda(
alpha, weight, T, alphainv_last,
i_start, i_end, n_rays,
grad_weights, grad_last);
}
// Python bindings for the wrappers above. The third argument of each m.def is
// the docstring exposed on the Python side.
// Docstring changes:
//  - Fixed the "skip know freespace" typo.
//  - Distinguished the NDC sampler and the non-uniform raw2alpha pair, whose
//    docstrings previously duplicated their siblings'.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("infer_t_minmax", &infer_t_minmax, "Inference t_min and t_max of ray-bbox intersection");
  m.def("infer_n_samples", &infer_n_samples, "Inference the number of points to sample on each ray");
  m.def("infer_ray_start_dir", &infer_ray_start_dir, "Inference the starting point and shooting direction of each ray");
  m.def("sample_pts_on_rays", &sample_pts_on_rays, "Sample points on rays");
  m.def("sample_ndc_pts_on_rays", &sample_ndc_pts_on_rays, "Sample points on rays in NDC space");
  m.def("sample_bg_pts_on_rays", &sample_bg_pts_on_rays, "Sample points on bg");
  m.def("maskcache_lookup", &maskcache_lookup, "Lookup to skip known freespace.");
  m.def("raw2alpha", &raw2alpha, "Raw values [-inf, inf] to alpha [0, 1].");
  m.def("raw2alpha_backward", &raw2alpha_backward, "Backward pass of the raw to alpha");
  m.def("raw2alpha_nonuni", &raw2alpha_nonuni, "Raw values [-inf, inf] to alpha [0, 1] with non-uniform intervals.");
  m.def("raw2alpha_nonuni_backward", &raw2alpha_nonuni_backward, "Backward pass of the raw to alpha with non-uniform intervals");
  m.def("alpha2weight", &alpha2weight, "Per-point alpha to accumulated blending weight");
  m.def("alpha2weight_backward", &alpha2weight_backward, "Backward pass of alpha2weight");
}
|