Spaces:
Runtime error
Runtime error
| // CUDA forward declarations | |
| at::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first); | |
| // C++ interface | |
| at::Tensor forget_mult_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) { | |
| CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output); | |
| return forget_mult_cuda_forward(x, f, output, batch_first); | |
| } | |
| std::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f, at::Tensor output, | |
| at::Tensor grad_output, bool batch_first); | |
| std::vector<at::Tensor> forget_mult_backward(at::Tensor x, at::Tensor f, at::Tensor output, | |
| at::Tensor grad_output, bool batch_first) { | |
| CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output); | |
| return forget_mult_cuda_backward(x, f, output, grad_output, batch_first); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def("forward", &forget_mult_forward, "ForgetMult forward (CUDA)"); | |
| m.def("backward", &forget_mult_backward, "ForgetMult backward (CUDA)"); | |
| } | |