Spaces:
Paused
Paused
| // CUDA forward declarations | |
| void adam_upd_cuda( | |
| torch::Tensor param, | |
| torch::Tensor grad, | |
| torch::Tensor exp_avg, | |
| torch::Tensor exp_avg_sq, | |
| int step, float beta1, float beta2, float lr, float eps); | |
| void masked_adam_upd_cuda( | |
| torch::Tensor param, | |
| torch::Tensor grad, | |
| torch::Tensor exp_avg, | |
| torch::Tensor exp_avg_sq, | |
| int step, float beta1, float beta2, float lr, float eps); | |
| void adam_upd_with_perlr_cuda( | |
| torch::Tensor param, | |
| torch::Tensor grad, | |
| torch::Tensor exp_avg, | |
| torch::Tensor exp_avg_sq, | |
| torch::Tensor perlr, | |
| int step, float beta1, float beta2, float lr, float eps); | |
| // C++ interface | |
| void adam_upd( | |
| torch::Tensor param, | |
| torch::Tensor grad, | |
| torch::Tensor exp_avg, | |
| torch::Tensor exp_avg_sq, | |
| int step, float beta1, float beta2, float lr, float eps) { | |
| CHECK_INPUT(param); | |
| CHECK_INPUT(grad); | |
| CHECK_INPUT(exp_avg); | |
| CHECK_INPUT(exp_avg_sq); | |
| adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, | |
| step, beta1, beta2, lr, eps); | |
| } | |
| void masked_adam_upd( | |
| torch::Tensor param, | |
| torch::Tensor grad, | |
| torch::Tensor exp_avg, | |
| torch::Tensor exp_avg_sq, | |
| int step, float beta1, float beta2, float lr, float eps) { | |
| CHECK_INPUT(param); | |
| CHECK_INPUT(grad); | |
| CHECK_INPUT(exp_avg); | |
| CHECK_INPUT(exp_avg_sq); | |
| masked_adam_upd_cuda(param, grad, exp_avg, exp_avg_sq, | |
| step, beta1, beta2, lr, eps); | |
| } | |
| void adam_upd_with_perlr( | |
| torch::Tensor param, | |
| torch::Tensor grad, | |
| torch::Tensor exp_avg, | |
| torch::Tensor exp_avg_sq, | |
| torch::Tensor perlr, | |
| int step, float beta1, float beta2, float lr, float eps) { | |
| CHECK_INPUT(param); | |
| CHECK_INPUT(grad); | |
| CHECK_INPUT(exp_avg); | |
| CHECK_INPUT(exp_avg_sq); | |
| adam_upd_with_perlr_cuda(param, grad, exp_avg, exp_avg_sq, perlr, | |
| step, beta1, beta2, lr, eps); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def("adam_upd", &adam_upd, | |
| "Adam update"); | |
| m.def("masked_adam_upd", &masked_adam_upd, | |
| "Adam update ignoring zero grad"); | |
| m.def("adam_upd_with_perlr", &adam_upd_with_perlr, | |
| "Adam update ignoring zero grad with per-voxel lr"); | |
| } | |