| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | #include <tiny-cuda-nn/optimizer.h> |
| |
|
| | #include <tiny-cuda-nn/optimizers/adam.h> |
| | #include <tiny-cuda-nn/optimizers/average.h> |
| | #include <tiny-cuda-nn/optimizers/batched.h> |
| | #include <tiny-cuda-nn/optimizers/composite.h> |
| | #include <tiny-cuda-nn/optimizers/ema.h> |
| | #include <tiny-cuda-nn/optimizers/exponential_decay.h> |
| | #include <tiny-cuda-nn/optimizers/lookahead.h> |
| | #include <tiny-cuda-nn/optimizers/novograd.h> |
| | #include <tiny-cuda-nn/optimizers/sgd.h> |
| |
|
| | #ifdef TCNN_SHAMPOO |
| | #include <tiny-cuda-nn/optimizers/shampoo.h> |
| | #endif |
| |
|
| |
|
| | TCNN_NAMESPACE_BEGIN |
| |
|
| | template <typename T> |
| | Optimizer<T>* create_optimizer(const json& optimizer) { |
| | std::string optimizer_type = optimizer.value("otype", "Adam"); |
| |
|
| | if (equals_case_insensitive(optimizer_type, "Adam")) { |
| | return new AdamOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Average")) { |
| | return new AverageOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Batched")) { |
| | return new BatchedOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Composite")) { |
| | return new CompositeOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Ema")) { |
| | return new EmaOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "ExponentialDecay")) { |
| | return new ExponentialDecayOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Lookahead")) { |
| | return new LookaheadOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Novograd")) { |
| | return new NovogradOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "SGD")) { |
| | return new SGDOptimizer<T>{optimizer}; |
| | } else if (equals_case_insensitive(optimizer_type, "Shampoo")) { |
| | #ifdef TCNN_SHAMPOO |
| | return new ShampooOptimizer<T>{optimizer}; |
| | #else |
| | throw std::runtime_error{"The Shampoo optimizer is only available when compiling with CUDA 11 or higher."}; |
| | #endif |
| | } else { |
| | throw std::runtime_error{fmt::format("Invalid optimizer type: {}", optimizer_type)}; |
| | } |
| | } |
| |
|
| | template Optimizer<float>* create_optimizer(const json& optimizer); |
| | template Optimizer<__half>* create_optimizer(const json& optimizer); |
| |
|
| | TCNN_NAMESPACE_END |
| |
|