| namespace at { namespace native { | |
| void _fused_adam_cuda_impl_( | |
| at::TensorList params, | |
| at::TensorList grads, | |
| at::TensorList exp_avgs, | |
| at::TensorList exp_avg_sqs, | |
| at::TensorList state_steps, | |
| const double lr, | |
| const double beta1, | |
| const double beta2, | |
| const double weight_decay, | |
| const double eps, | |
| const bool amsgrad, | |
| const bool maximize, | |
| const c10::optional<at::Tensor>& grad_scale, | |
| const c10::optional<at::Tensor>& found_inf | |
| ); | |
| } } // namespace at::native | |