| namespace at::native { | |
| enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; | |
| using fused_adam_fn = void (*)( | |
| const at::Tensor& param, | |
| const at::Tensor& grad, | |
| const at::Tensor& exp_avg, | |
| const at::Tensor& exp_avg_sq, | |
| const at::Tensor& max_exp_avg_sq, | |
| const at::Tensor& state_step, | |
| const double lr, | |
| const double beta1, | |
| const double beta2, | |
| const double weight_decay, | |
| const double eps, | |
| const bool amsgrad, | |
| const bool maximize, | |
| const float* grad_scale_ptr, | |
| const ADAM_MODE); | |
| DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub) | |
| } // namespace at::native | |