Spaces:
Paused
Paused
| // CUDA forward declarations | |
| void total_variation_add_grad_cuda(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode); | |
| // C++ interface | |
| void total_variation_add_grad(torch::Tensor param, torch::Tensor grad, float wx, float wy, float wz, bool dense_mode) { | |
| CHECK_INPUT(param); | |
| CHECK_INPUT(grad); | |
| total_variation_add_grad_cuda(param, grad, wx, wy, wz, dense_mode); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def("total_variation_add_grad", &total_variation_add_grad, "Add total variation grad"); | |
| } | |