| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #pragma once |
| | #include <torch/extension.h> |
| | #include <cstdio> |
| | #include <tuple> |
| | #include <string> |
| | |
| | std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> |
| | RasterizeGaussiansCUDA( |
| | const torch::Tensor& background, |
| | const torch::Tensor& means3D, |
| | const torch::Tensor& colors, |
| | const torch::Tensor& opacity, |
| | const torch::Tensor& scales, |
| | const torch::Tensor& rotations, |
| | const float scale_modifier, |
| | const torch::Tensor& cov3D_precomp, |
| | const torch::Tensor& viewmatrix, |
| | const torch::Tensor& projmatrix, |
| | const float tan_fovx, |
| | const float tan_fovy, |
| | const int image_height, |
| | const int image_width, |
| | const torch::Tensor& sh, |
| | const int degree, |
| | const torch::Tensor& campos, |
| | const bool prefiltered, |
| | const bool debug); |
| |
|
| | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> |
| | RasterizeGaussiansBackwardCUDA( |
| | const torch::Tensor& background, |
| | const torch::Tensor& means3D, |
| | const torch::Tensor& radii, |
| | const torch::Tensor& colors, |
| | const torch::Tensor& scales, |
| | const torch::Tensor& rotations, |
| | const float scale_modifier, |
| | const torch::Tensor& cov3D_precomp, |
| | const torch::Tensor& viewmatrix, |
| | const torch::Tensor& projmatrix, |
| | const float tan_fovx, |
| | const float tan_fovy, |
| | const torch::Tensor& dL_dout_color, |
| | const torch::Tensor& dL_dout_depth, |
| | const torch::Tensor& dL_dout_alpha, |
| | const torch::Tensor& sh, |
| | const int degree, |
| | const torch::Tensor& campos, |
| | const torch::Tensor& geomBuffer, |
| | const int R, |
| | const torch::Tensor& binningBuffer, |
| | const torch::Tensor& imageBuffer, |
| | const torch::Tensor& alpha, |
| | const bool debug); |
| | |
| | torch::Tensor markVisible( |
| | torch::Tensor& means3D, |
| | torch::Tensor& viewmatrix, |
| | torch::Tensor& projmatrix); |