#ifndef CAFFE_SGD_SOLVERS_HPP_ #define CAFFE_SGD_SOLVERS_HPP_ #include #include #include "caffe/solver.hpp" namespace caffe { /** * @brief Optimizes the parameters of a Net using * stochastic gradient descent (SGD) with momentum. */ template class SGDSolver : public Solver { public: explicit SGDSolver(const SolverParameter& param) : Solver(param) { PreSolve(); } explicit SGDSolver(const string& param_file) : Solver(param_file) { PreSolve(); } virtual inline const char* type() const { return "SGD"; } const vector > >& history() { return history_; } virtual void ApplyUpdate(); Dtype GetLearningRate(); protected: void PreSolve(); virtual void Normalize(int param_id); virtual void Regularize(int param_id); virtual void ComputeUpdateValue(int param_id, Dtype rate); virtual void ClipGradients(); virtual void SnapshotSolverState(const string& model_filename); virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); virtual void SnapshotSolverStateToHDF5(const string& model_filename); virtual void RestoreSolverStateFromHDF5(const string& state_file); virtual void RestoreSolverStateFromBinaryProto(const string& state_file); // history maintains the historical momentum data. // update maintains update related data and is not needed in snapshots. // temp maintains other information that might be needed in computation // of gradients/updates and is not needed in snapshots vector > > history_, update_, temp_; DISABLE_COPY_AND_ASSIGN(SGDSolver); }; template class NesterovSolver : public SGDSolver { public: explicit NesterovSolver(const SolverParameter& param) : SGDSolver(param) {} explicit NesterovSolver(const string& param_file) : SGDSolver(param_file) {} virtual inline const char* type() const { return "Nesterov"; } protected: virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(NesterovSolver); }; template class AdaGradSolver : public SGDSolver { public: explicit AdaGradSolver(const SolverParameter& param) : SGDSolver(param) { constructor_sanity_check(); } explicit AdaGradSolver(const string& param_file) : SGDSolver(param_file) { constructor_sanity_check(); } virtual inline const char* type() const { return "AdaGrad"; } protected: virtual void ComputeUpdateValue(int param_id, Dtype rate); void constructor_sanity_check() { CHECK_EQ(0, this->param_.momentum()) << "Momentum cannot be used with AdaGrad."; } DISABLE_COPY_AND_ASSIGN(AdaGradSolver); }; template class RMSPropSolver : public SGDSolver { public: explicit RMSPropSolver(const SolverParameter& param) : SGDSolver(param) { constructor_sanity_check(); } explicit RMSPropSolver(const string& param_file) : SGDSolver(param_file) { constructor_sanity_check(); } virtual inline const char* type() const { return "RMSProp"; } protected: virtual void ComputeUpdateValue(int param_id, Dtype rate); void constructor_sanity_check() { CHECK_EQ(0, this->param_.momentum()) << "Momentum cannot be used with RMSProp."; CHECK_GE(this->param_.rms_decay(), 0) << "rms_decay should lie between 0 and 1."; CHECK_LT(this->param_.rms_decay(), 1) << "rms_decay should lie between 0 and 1."; } DISABLE_COPY_AND_ASSIGN(RMSPropSolver); }; template class AdaDeltaSolver : public SGDSolver { public: explicit AdaDeltaSolver(const SolverParameter& param) : SGDSolver(param) { AdaDeltaPreSolve(); } explicit AdaDeltaSolver(const string& param_file) : SGDSolver(param_file) { AdaDeltaPreSolve(); } virtual inline const char* type() const { return "AdaDelta"; } protected: void AdaDeltaPreSolve(); virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); }; /** * @brief AdamSolver, an algorithm for first-order gradient-based optimization * of stochastic objective functions, based on adaptive estimates of * lower-order moments. Described in [1]. * * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization." * arXiv preprint arXiv:1412.6980v8 (2014). */ template class AdamSolver : public SGDSolver { public: explicit AdamSolver(const SolverParameter& param) : SGDSolver(param) { AdamPreSolve();} explicit AdamSolver(const string& param_file) : SGDSolver(param_file) { AdamPreSolve(); } virtual inline const char* type() const { return "Adam"; } protected: void AdamPreSolve(); virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(AdamSolver); }; } // namespace caffe #endif // CAFFE_SGD_SOLVERS_HPP_