Our3D / lib /cuda /render_utils_kernel.cu
yansong1616's picture
Upload 384 files
b177539 verified
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
/*
Points sampling helper functions.
*/
template <typename scalar_t>
__global__ void infer_t_minmax_cuda_kernel(
scalar_t* __restrict__ rays_o,
scalar_t* __restrict__ rays_d,
scalar_t* __restrict__ xyz_min,
scalar_t* __restrict__ xyz_max,
const float near, const float far, const int n_rays,
scalar_t* __restrict__ t_min,
scalar_t* __restrict__ t_max) {
const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
if(i_ray<n_rays) {
const int offset = i_ray * 3;
float vx = ((rays_d[offset ]==0) ? 1e-6 : rays_d[offset ]);
float vy = ((rays_d[offset+1]==0) ? 1e-6 : rays_d[offset+1]);
float vz = ((rays_d[offset+2]==0) ? 1e-6 : rays_d[offset+2]);
float ax = (xyz_max[0] - rays_o[offset ]) / vx;
float ay = (xyz_max[1] - rays_o[offset+1]) / vy;
float az = (xyz_max[2] - rays_o[offset+2]) / vz;
float bx = (xyz_min[0] - rays_o[offset ]) / vx;
float by = (xyz_min[1] - rays_o[offset+1]) / vy;
float bz = (xyz_min[2] - rays_o[offset+2]) / vz;
t_min[i_ray] = max(min(max(max(min(ax, bx), min(ay, by)), min(az, bz)), far), near);
t_max[i_ray] = max(min(min(min(max(ax, bx), max(ay, by)), max(az, bz)), far), near);
}
}
template <typename scalar_t>
__global__ void infer_n_samples_cuda_kernel(
scalar_t* __restrict__ rays_d,
scalar_t* __restrict__ t_min,
scalar_t* __restrict__ t_max,
const float stepdist,
const int n_rays,
int64_t* __restrict__ n_samples) {
const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
if(i_ray<n_rays) {
const int offset = i_ray * 3;
const float rnorm = sqrt(
rays_d[offset ]*rays_d[offset ] +\
rays_d[offset+1]*rays_d[offset+1] +\
rays_d[offset+2]*rays_d[offset+2]);
// at least 1 point for easier implementation in the later sample_pts_on_rays_cuda
n_samples[i_ray] = max(ceil((t_max[i_ray]-t_min[i_ray]) * rnorm / stepdist), 1.);
}
}
template <typename scalar_t>
__global__ void infer_ray_start_dir_cuda_kernel(
scalar_t* __restrict__ rays_o,
scalar_t* __restrict__ rays_d,
scalar_t* __restrict__ t_min,
const int n_rays,
scalar_t* __restrict__ rays_start,
scalar_t* __restrict__ rays_dir) {
const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
if(i_ray<n_rays) {
const int offset = i_ray * 3;
const float rnorm = sqrt(
rays_d[offset ]*rays_d[offset ] +\
rays_d[offset+1]*rays_d[offset+1] +\
rays_d[offset+2]*rays_d[offset+2]);
rays_start[offset ] = rays_o[offset ] + rays_d[offset ] * t_min[i_ray];
rays_start[offset+1] = rays_o[offset+1] + rays_d[offset+1] * t_min[i_ray];
rays_start[offset+2] = rays_o[offset+2] + rays_d[offset+2] * t_min[i_ray];
rays_dir [offset ] = rays_d[offset ] / rnorm;
rays_dir [offset+1] = rays_d[offset+1] / rnorm;
rays_dir [offset+2] = rays_d[offset+2] / rnorm;
}
}
std::vector<torch::Tensor> infer_t_minmax_cuda(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far) {
const int n_rays = rays_o.size(0);
auto t_min = torch::empty({n_rays}, rays_o.options());
auto t_max = torch::empty({n_rays}, rays_o.options());
const int threads = 256;
const int blocks = (n_rays + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "infer_t_minmax_cuda", ([&] {
infer_t_minmax_cuda_kernel<scalar_t><<<blocks, threads>>>(
rays_o.data<scalar_t>(),
rays_d.data<scalar_t>(),
xyz_min.data<scalar_t>(),
xyz_max.data<scalar_t>(),
near, far, n_rays,
t_min.data<scalar_t>(),
t_max.data<scalar_t>());
}));
return {t_min, t_max};
}
torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
const int n_rays = t_min.size(0);
auto n_samples = torch::empty({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
const int threads = 256;
const int blocks = (n_rays + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(t_min.type(), "infer_n_samples_cuda", ([&] {
infer_n_samples_cuda_kernel<scalar_t><<<blocks, threads>>>(
rays_d.data<scalar_t>(),
t_min.data<scalar_t>(),
t_max.data<scalar_t>(),
stepdist,
n_rays,
n_samples.data<int64_t>());
}));
return n_samples;
}
std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
const int n_rays = rays_o.size(0);
const int threads = 256;
const int blocks = (n_rays + threads - 1) / threads;
auto rays_start = torch::empty_like(rays_o);
auto rays_dir = torch::empty_like(rays_o);
AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "infer_ray_start_dir_cuda", ([&] {
infer_ray_start_dir_cuda_kernel<scalar_t><<<blocks, threads>>>(
rays_o.data<scalar_t>(),
rays_d.data<scalar_t>(),
t_min.data<scalar_t>(),
n_rays,
rays_start.data<scalar_t>(),
rays_dir.data<scalar_t>());
}));
return {rays_start, rays_dir};
}
/*
Sampling query points on rays.
*/
__global__ void __set_1_at_ray_seg_start(
int64_t* __restrict__ ray_id,
int64_t* __restrict__ N_steps_cumsum,
const int n_rays) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if(0<idx && idx<n_rays) {
ray_id[N_steps_cumsum[idx-1]] = 1;
}
}
__global__ void __set_step_id(
int64_t* __restrict__ step_id,
int64_t* __restrict__ ray_id,
int64_t* __restrict__ N_steps_cumsum,
const int total_len) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if(idx<total_len) {
const int rid = ray_id[idx];
step_id[idx] = idx - ((rid!=0) ? N_steps_cumsum[rid-1] : 0);
}
}
template <typename scalar_t>
__global__ void sample_pts_on_rays_cuda_kernel(
scalar_t* __restrict__ rays_start,
scalar_t* __restrict__ rays_dir,
scalar_t* __restrict__ xyz_min,
scalar_t* __restrict__ xyz_max,
int64_t* __restrict__ ray_id,
int64_t* __restrict__ step_id,
const float stepdist, const int total_len,
scalar_t* __restrict__ rays_pts,
bool* __restrict__ mask_outbbox) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if(idx<total_len) {
const int i_ray = ray_id[idx];
const int i_step = step_id[idx];
const int offset_p = idx * 3;
const int offset_r = i_ray * 3;
const float dist = stepdist * i_step;
const float px = rays_start[offset_r ] + rays_dir[offset_r ] * dist;
const float py = rays_start[offset_r+1] + rays_dir[offset_r+1] * dist;
const float pz = rays_start[offset_r+2] + rays_dir[offset_r+2] * dist;
rays_pts[offset_p ] = px;
rays_pts[offset_p+1] = py;
rays_pts[offset_p+2] = pz;
mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) | \
(xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
}
}
std::vector<torch::Tensor> sample_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d,
torch::Tensor xyz_min, torch::Tensor xyz_max,
const float near, const float far, const float stepdist) {
const int threads = 256;
const int n_rays = rays_o.size(0);
// Compute ray-bbox intersection
auto t_minmax = infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
auto t_min = t_minmax[0];
auto t_max = t_minmax[1];
// Compute the number of points required.
// Assign ray index and step index to each.
auto N_steps = infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
auto N_steps_cumsum = N_steps.cumsum(0);
const int total_len = N_steps.sum().item<int>();
auto ray_id = torch::zeros({total_len}, torch::dtype(torch::kInt64).device(torch::kCUDA));
__set_1_at_ray_seg_start<<<(n_rays+threads-1)/threads, threads>>>(
ray_id.data<int64_t>(), N_steps_cumsum.data<int64_t>(), n_rays);
ray_id.cumsum_(0);
auto step_id = torch::empty({total_len}, ray_id.options());
__set_step_id<<<(total_len+threads-1)/threads, threads>>>(
step_id.data<int64_t>(), ray_id.data<int64_t>(), N_steps_cumsum.data<int64_t>(), total_len);
// Compute the global xyz of each point
auto rays_start_dir = infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
auto rays_start = rays_start_dir[0];
auto rays_dir = rays_start_dir[1];
auto rays_pts = torch::empty({total_len, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
auto mask_outbbox = torch::empty({total_len}, torch::dtype(torch::kBool).device(torch::kCUDA));
AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "sample_pts_on_rays_cuda", ([&] {
sample_pts_on_rays_cuda_kernel<scalar_t><<<(total_len+threads-1)/threads, threads>>>(
rays_start.data<scalar_t>(),
rays_dir.data<scalar_t>(),
xyz_min.data<scalar_t>(),
xyz_max.data<scalar_t>(),
ray_id.data<int64_t>(),
step_id.data<int64_t>(),
stepdist, total_len,
rays_pts.data<scalar_t>(),
mask_outbbox.data<bool>());
}));
return {rays_pts, mask_outbbox, ray_id, step_id, N_steps, t_min, t_max};
}
template <typename scalar_t>
__global__ void sample_ndc_pts_on_rays_cuda_kernel(
const scalar_t* __restrict__ rays_o,
const scalar_t* __restrict__ rays_d,
const scalar_t* __restrict__ xyz_min,
const scalar_t* __restrict__ xyz_max,
const int N_samples, const int n_rays,
scalar_t* __restrict__ rays_pts,
bool* __restrict__ mask_outbbox) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if(idx<N_samples*n_rays) {
const int i_ray = idx / N_samples;
const int i_step = idx % N_samples;
const int offset_p = idx * 3;
const int offset_r = i_ray * 3;
const float dist = ((float)i_step) / (N_samples-1);
const float px = rays_o[offset_r ] + rays_d[offset_r ] * dist;
const float py = rays_o[offset_r+1] + rays_d[offset_r+1] * dist;
const float pz = rays_o[offset_r+2] + rays_d[offset_r+2] * dist;
rays_pts[offset_p ] = px;
rays_pts[offset_p+1] = py;
rays_pts[offset_p+2] = pz;
mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) | \
(xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
}
}
std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d,
torch::Tensor xyz_min, torch::Tensor xyz_max,
const int N_samples) {
const int threads = 256;
const int n_rays = rays_o.size(0);
auto rays_pts = torch::empty({n_rays, N_samples, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
auto mask_outbbox = torch::empty({n_rays, N_samples}, torch::dtype(torch::kBool).device(torch::kCUDA));
AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "sample_ndc_pts_on_rays_cuda", ([&] {
sample_ndc_pts_on_rays_cuda_kernel<scalar_t><<<(n_rays*N_samples+threads-1)/threads, threads>>>(
rays_o.data<scalar_t>(),
rays_d.data<scalar_t>(),
xyz_min.data<scalar_t>(),
xyz_max.data<scalar_t>(),
N_samples, n_rays,
rays_pts.data<scalar_t>(),
mask_outbbox.data<bool>());
}));
return {rays_pts, mask_outbbox};
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t norm3(const scalar_t x, const scalar_t y, const scalar_t z) {
return sqrt(x*x + y*y + z*z);
}
template <typename scalar_t>
__global__ void sample_bg_pts_on_rays_cuda_kernel(
const scalar_t* __restrict__ rays_o,
const scalar_t* __restrict__ rays_d,
const scalar_t* __restrict__ t_max,
const float bg_preserve,
const int N_samples, const int n_rays,
scalar_t* __restrict__ rays_pts) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if(idx<N_samples*n_rays) {
const int i_ray = idx / N_samples;
const int i_step = idx % N_samples;
const int offset_p = idx * 3;
const int offset_r = i_ray * 3;
/* Original pytorch implementation
ori_t_outer = t_max[:,None] - 1 + 1 / torch.linspace(1, 0, N_outer+1)[:-1]
ori_ray_pts_outer = (rays_o[:,None,:] + rays_d[:,None,:] * ori_t_outer[:,:,None]).reshape(-1,3)
t_outer = ori_ray_pts_outer.norm(dim=-1)
R_outer = t_outer / ori_ray_pts_outer.abs().amax(1)
# r = R * R / t
o2i_p = R_outer.pow(2) / t_outer.pow(2) * (1-self.bg_preserve) + R_outer / t_outer * self.bg_preserve
ray_pts_outer = (ori_ray_pts_outer * o2i_p[:,None]).reshape(len(rays_o), -1, 3)
*/
const float t_inner = t_max[i_ray];
const float ori_t_outer = t_inner - 1. + 1. / (1. - ((float)i_step) / N_samples);
const float ori_ray_pts_x = rays_o[offset_r ] + rays_d[offset_r ] * ori_t_outer;
const float ori_ray_pts_y = rays_o[offset_r+1] + rays_d[offset_r+1] * ori_t_outer;
const float ori_ray_pts_z = rays_o[offset_r+2] + rays_d[offset_r+2] * ori_t_outer;
const float t_outer = norm3(ori_ray_pts_x, ori_ray_pts_y, ori_ray_pts_z);
const float ori_ray_pts_m = max(abs(ori_ray_pts_x), max(abs(ori_ray_pts_y), abs(ori_ray_pts_z)));
const float R_outer = t_outer / ori_ray_pts_m;
const float o2i_p = R_outer*R_outer / (t_outer*t_outer) * (1.-bg_preserve) + R_outer / t_outer * bg_preserve;
const float px = ori_ray_pts_x * o2i_p;
const float py = ori_ray_pts_y * o2i_p;
const float pz = ori_ray_pts_z * o2i_p;
rays_pts[offset_p ] = px;
rays_pts[offset_p+1] = py;
rays_pts[offset_p+2] = pz;
}
}
torch::Tensor sample_bg_pts_on_rays_cuda(
torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
const float bg_preserve, const int N_samples) {
const int threads = 256;
const int n_rays = rays_o.size(0);
auto rays_pts = torch::empty({n_rays, N_samples, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
AT_DISPATCH_FLOATING_TYPES(rays_o.type(), "sample_bg_pts_on_rays_cuda", ([&] {
sample_bg_pts_on_rays_cuda_kernel<scalar_t><<<(n_rays*N_samples+threads-1)/threads, threads>>>(
rays_o.data<scalar_t>(),
rays_d.data<scalar_t>(),
t_max.data<scalar_t>(),
bg_preserve,
N_samples, n_rays,
rays_pts.data<scalar_t>());
}));
return rays_pts;
}
/*
MaskCache lookup to skip known freespace.
*/
static __forceinline__ __device__
bool check_xyz(int i, int j, int k, int sz_i, int sz_j, int sz_k) {
return (0 <= i) && (i < sz_i) && (0 <= j) && (j < sz_j) && (0 <= k) && (k < sz_k);
}
template <typename scalar_t>
__global__ void maskcache_lookup_cuda_kernel(
bool* __restrict__ world,
scalar_t* __restrict__ xyz,
bool* __restrict__ out,
scalar_t* __restrict__ xyz2ijk_scale,
scalar_t* __restrict__ xyz2ijk_shift,
const int sz_i, const int sz_j, const int sz_k, const int n_pts) {
const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
if(i_pt<n_pts) {
const int offset = i_pt * 3;
const int i = round(xyz[offset ] * xyz2ijk_scale[0] + xyz2ijk_shift[0]);
const int j = round(xyz[offset+1] * xyz2ijk_scale[1] + xyz2ijk_shift[1]);
const int k = round(xyz[offset+2] * xyz2ijk_scale[2] + xyz2ijk_shift[2]);
if(check_xyz(i, j, k, sz_i, sz_j, sz_k)) {
out[i_pt] = world[i*sz_j*sz_k + j*sz_k + k];
}
}
}
torch::Tensor maskcache_lookup_cuda(
torch::Tensor world,
torch::Tensor xyz,
torch::Tensor xyz2ijk_scale,
torch::Tensor xyz2ijk_shift) {
const int sz_i = world.size(0);
const int sz_j = world.size(1);
const int sz_k = world.size(2);
const int n_pts = xyz.size(0);
auto out = torch::zeros({n_pts}, torch::dtype(torch::kBool).device(torch::kCUDA));
if(n_pts==0) {
return out;
}
const int threads = 256;
const int blocks = (n_pts + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(xyz.type(), "maskcache_lookup_cuda", ([&] {
maskcache_lookup_cuda_kernel<scalar_t><<<blocks, threads>>>(
world.data<bool>(),
xyz.data<scalar_t>(),
out.data<bool>(),
xyz2ijk_scale.data<scalar_t>(),
xyz2ijk_shift.data<scalar_t>(),
sz_i, sz_j, sz_k, n_pts);
}));
return out;
}
/*
Ray marching helper function.
*/
template <typename scalar_t>
__global__ void raw2alpha_cuda_kernel(
scalar_t* __restrict__ density,
const float shift, const float interval, const int n_pts,
scalar_t* __restrict__ exp_d,
scalar_t* __restrict__ alpha) {
const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
if(i_pt<n_pts) {
const scalar_t e = exp(density[i_pt] + shift); // can be inf
exp_d[i_pt] = e;
alpha[i_pt] = 1 - pow((1 + e), (-interval));
}
}
template <typename scalar_t>
__global__ void raw2alpha_nonuni_cuda_kernel(
scalar_t* __restrict__ density,
const float shift, scalar_t* __restrict__ interval, const int n_pts,
scalar_t* __restrict__ exp_d,
scalar_t* __restrict__ alpha) {
const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
if(i_pt<n_pts) {
const scalar_t e = exp(density[i_pt] + shift); // can be inf
exp_d[i_pt] = e;
alpha[i_pt] = 1 - pow((1 + e), (-interval[i_pt]));
}
}
std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval) {
const int n_pts = density.size(0);
auto exp_d = torch::empty_like(density);
auto alpha = torch::empty_like(density);
if(n_pts==0) {
return {exp_d, alpha};
}
const int threads = 256;
const int blocks = (n_pts + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(density.type(), "raw2alpha_cuda", ([&] {
raw2alpha_cuda_kernel<scalar_t><<<blocks, threads>>>(
density.data<scalar_t>(),
shift, interval, n_pts,
exp_d.data<scalar_t>(),
alpha.data<scalar_t>());
}));
return {exp_d, alpha};
}
std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval) {
const int n_pts = density.size(0);
auto exp_d = torch::empty_like(density);
auto alpha = torch::empty_like(density);
if(n_pts==0) {
return {exp_d, alpha};
}
const int threads = 256;
const int blocks = (n_pts + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(density.type(), "raw2alpha_cuda", ([&] {
raw2alpha_nonuni_cuda_kernel<scalar_t><<<blocks, threads>>>(
density.data<scalar_t>(),
shift, interval.data<scalar_t>(), n_pts,
exp_d.data<scalar_t>(),
alpha.data<scalar_t>());
}));
return {exp_d, alpha};
}
template <typename scalar_t>
__global__ void raw2alpha_backward_cuda_kernel(
scalar_t* __restrict__ exp_d,
scalar_t* __restrict__ grad_back,
const float interval, const int n_pts,
scalar_t* __restrict__ grad) {
const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
if(i_pt<n_pts) {
grad[i_pt] = min(exp_d[i_pt], 1e10) * pow((1+exp_d[i_pt]), (-interval-1)) * interval * grad_back[i_pt];
}
}
template <typename scalar_t>
__global__ void raw2alpha_nonuni_backward_cuda_kernel(
scalar_t* __restrict__ exp_d,
scalar_t* __restrict__ grad_back,
scalar_t* __restrict__ interval, const int n_pts,
scalar_t* __restrict__ grad) {
const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
if(i_pt<n_pts) {
grad[i_pt] = min(exp_d[i_pt], 1e10) * pow((1+exp_d[i_pt]), (-interval[i_pt]-1)) * interval[i_pt] * grad_back[i_pt];
}
}
torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, const float interval) {
const int n_pts = exp_d.size(0);
auto grad = torch::empty_like(exp_d);
if(n_pts==0) {
return grad;
}
const int threads = 256;
const int blocks = (n_pts + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(exp_d.type(), "raw2alpha_backward_cuda", ([&] {
raw2alpha_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
exp_d.data<scalar_t>(),
grad_back.data<scalar_t>(),
interval, n_pts,
grad.data<scalar_t>());
}));
return grad;
}
torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, torch::Tensor interval) {
const int n_pts = exp_d.size(0);
auto grad = torch::empty_like(exp_d);
if(n_pts==0) {
return grad;
}
const int threads = 256;
const int blocks = (n_pts + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(exp_d.type(), "raw2alpha_backward_cuda", ([&] {
raw2alpha_nonuni_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
exp_d.data<scalar_t>(),
grad_back.data<scalar_t>(),
interval.data<scalar_t>(), n_pts,
grad.data<scalar_t>());
}));
return grad;
}
template <typename scalar_t>
__global__ void alpha2weight_cuda_kernel(
scalar_t* __restrict__ alpha,
const int n_rays,
scalar_t* __restrict__ weight,
scalar_t* __restrict__ T,
scalar_t* __restrict__ alphainv_last,
int64_t* __restrict__ i_start,
int64_t* __restrict__ i_end) {
const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
if(i_ray<n_rays) {
const int i_s = i_start[i_ray];
const int i_e_max = i_end[i_ray];
float T_cum = 1.;
int i;
for(i=i_s; i<i_e_max; ++i) {
T[i] = T_cum;
weight[i] = T_cum * alpha[i];
T_cum *= (1. - alpha[i]);
if(T_cum<1e-3) {
i+=1;
break;
}
}
i_end[i_ray] = i;
alphainv_last[i_ray] = T_cum;
}
}
__global__ void __set_i_for_segment_start_end(
int64_t* __restrict__ ray_id,
const int n_pts,
int64_t* __restrict__ i_start,
int64_t* __restrict__ i_end) {
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if(0<index && index<n_pts && ray_id[index]!=ray_id[index-1]) {
i_start[ray_id[index]] = index;
i_end[ray_id[index-1]] = index;
}
}
std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
const int n_pts = alpha.size(0);
const int threads = 256;
auto weight = torch::zeros_like(alpha);
auto T = torch::ones_like(alpha);
auto alphainv_last = torch::ones({n_rays}, alpha.options());
auto i_start = torch::zeros({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto i_end = torch::zeros({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
if(n_pts==0) {
return {weight, T, alphainv_last, i_start, i_end};
}
__set_i_for_segment_start_end<<<(n_pts+threads-1)/threads, threads>>>(
ray_id.data<int64_t>(), n_pts, i_start.data<int64_t>(), i_end.data<int64_t>());
i_end[ray_id[n_pts-1]] = n_pts;
const int blocks = (n_rays + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(alpha.type(), "alpha2weight_cuda", ([&] {
alpha2weight_cuda_kernel<scalar_t><<<blocks, threads>>>(
alpha.data<scalar_t>(),
n_rays,
weight.data<scalar_t>(),
T.data<scalar_t>(),
alphainv_last.data<scalar_t>(),
i_start.data<int64_t>(),
i_end.data<int64_t>());
}));
return {weight, T, alphainv_last, i_start, i_end};
}
template <typename scalar_t>
__global__ void alpha2weight_backward_cuda_kernel(
scalar_t* __restrict__ alpha,
scalar_t* __restrict__ weight,
scalar_t* __restrict__ T,
scalar_t* __restrict__ alphainv_last,
int64_t* __restrict__ i_start,
int64_t* __restrict__ i_end,
const int n_rays,
scalar_t* __restrict__ grad_weights,
scalar_t* __restrict__ grad_last,
scalar_t* __restrict__ grad) {
const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
if(i_ray<n_rays) {
const int i_s = i_start[i_ray];
const int i_e = i_end[i_ray];
float back_cum = grad_last[i_ray] * alphainv_last[i_ray];
for(int i=i_e-1; i>=i_s; --i) {
grad[i] = grad_weights[i] * T[i] - back_cum / (1-alpha[i] + 1e-10);
back_cum += grad_weights[i] * weight[i];
}
}
}
torch::Tensor alpha2weight_backward_cuda(
torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
torch::Tensor grad_weights, torch::Tensor grad_last) {
auto grad = torch::zeros_like(alpha);
if(n_rays==0) {
return grad;
}
const int threads = 256;
const int blocks = (n_rays + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(alpha.type(), "alpha2weight_backward_cuda", ([&] {
alpha2weight_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
alpha.data<scalar_t>(),
weight.data<scalar_t>(),
T.data<scalar_t>(),
alphainv_last.data<scalar_t>(),
i_start.data<int64_t>(),
i_end.data<int64_t>(),
n_rays,
grad_weights.data<scalar_t>(),
grad_last.data<scalar_t>(),
grad.data<scalar_t>());
}));
return grad;
}