| | #ifndef CAFFE_SGD_SOLVERS_HPP_ |
| | #define CAFFE_SGD_SOLVERS_HPP_ |
| |
|
| | #include <string> |
| | #include <vector> |
| |
|
| | #include "caffe/solver.hpp" |
| |
|
| | namespace caffe { |
| |
|
| | |
| | |
| | |
| | |
| | template <typename Dtype> |
| | class SGDSolver : public Solver<Dtype> { |
| | public: |
| | explicit SGDSolver(const SolverParameter& param) |
| | : Solver<Dtype>(param) { PreSolve(); } |
| | explicit SGDSolver(const string& param_file) |
| | : Solver<Dtype>(param_file) { PreSolve(); } |
| | virtual inline const char* type() const { return "SGD"; } |
| |
|
| | const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; } |
| |
|
| | virtual void ApplyUpdate(); |
| | Dtype GetLearningRate(); |
| |
|
| | protected: |
| | void PreSolve(); |
| | virtual void Normalize(int param_id); |
| | virtual void Regularize(int param_id); |
| | virtual void ComputeUpdateValue(int param_id, Dtype rate); |
| | virtual void ClipGradients(); |
| | virtual void SnapshotSolverState(const string& model_filename); |
| | virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); |
| | virtual void SnapshotSolverStateToHDF5(const string& model_filename); |
| | virtual void RestoreSolverStateFromHDF5(const string& state_file); |
| | virtual void RestoreSolverStateFromBinaryProto(const string& state_file); |
| | |
| | |
| | |
| | |
| | vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_; |
| |
|
| | DISABLE_COPY_AND_ASSIGN(SGDSolver); |
| | }; |
| |
|
| | template <typename Dtype> |
| | class NesterovSolver : public SGDSolver<Dtype> { |
| | public: |
| | explicit NesterovSolver(const SolverParameter& param) |
| | : SGDSolver<Dtype>(param) {} |
| | explicit NesterovSolver(const string& param_file) |
| | : SGDSolver<Dtype>(param_file) {} |
| | virtual inline const char* type() const { return "Nesterov"; } |
| |
|
| | protected: |
| | virtual void ComputeUpdateValue(int param_id, Dtype rate); |
| |
|
| | DISABLE_COPY_AND_ASSIGN(NesterovSolver); |
| | }; |
| |
|
| | template <typename Dtype> |
| | class AdaGradSolver : public SGDSolver<Dtype> { |
| | public: |
| | explicit AdaGradSolver(const SolverParameter& param) |
| | : SGDSolver<Dtype>(param) { constructor_sanity_check(); } |
| | explicit AdaGradSolver(const string& param_file) |
| | : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } |
| | virtual inline const char* type() const { return "AdaGrad"; } |
| |
|
| | protected: |
| | virtual void ComputeUpdateValue(int param_id, Dtype rate); |
| | void constructor_sanity_check() { |
| | CHECK_EQ(0, this->param_.momentum()) |
| | << "Momentum cannot be used with AdaGrad."; |
| | } |
| |
|
| | DISABLE_COPY_AND_ASSIGN(AdaGradSolver); |
| | }; |
| |
|
| |
|
| | template <typename Dtype> |
| | class RMSPropSolver : public SGDSolver<Dtype> { |
| | public: |
| | explicit RMSPropSolver(const SolverParameter& param) |
| | : SGDSolver<Dtype>(param) { constructor_sanity_check(); } |
| | explicit RMSPropSolver(const string& param_file) |
| | : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } |
| | virtual inline const char* type() const { return "RMSProp"; } |
| |
|
| | protected: |
| | virtual void ComputeUpdateValue(int param_id, Dtype rate); |
| | void constructor_sanity_check() { |
| | CHECK_EQ(0, this->param_.momentum()) |
| | << "Momentum cannot be used with RMSProp."; |
| | CHECK_GE(this->param_.rms_decay(), 0) |
| | << "rms_decay should lie between 0 and 1."; |
| | CHECK_LT(this->param_.rms_decay(), 1) |
| | << "rms_decay should lie between 0 and 1."; |
| | } |
| |
|
| | DISABLE_COPY_AND_ASSIGN(RMSPropSolver); |
| | }; |
| |
|
| | template <typename Dtype> |
| | class AdaDeltaSolver : public SGDSolver<Dtype> { |
| | public: |
| | explicit AdaDeltaSolver(const SolverParameter& param) |
| | : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); } |
| | explicit AdaDeltaSolver(const string& param_file) |
| | : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); } |
| | virtual inline const char* type() const { return "AdaDelta"; } |
| |
|
| | protected: |
| | void AdaDeltaPreSolve(); |
| | virtual void ComputeUpdateValue(int param_id, Dtype rate); |
| |
|
| | DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); |
| | }; |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | template <typename Dtype> |
| | class AdamSolver : public SGDSolver<Dtype> { |
| | public: |
| | explicit AdamSolver(const SolverParameter& param) |
| | : SGDSolver<Dtype>(param) { AdamPreSolve();} |
| | explicit AdamSolver(const string& param_file) |
| | : SGDSolver<Dtype>(param_file) { AdamPreSolve(); } |
| | virtual inline const char* type() const { return "Adam"; } |
| |
|
| | protected: |
| | void AdamPreSolve(); |
| | virtual void ComputeUpdateValue(int param_id, Dtype rate); |
| |
|
| | DISABLE_COPY_AND_ASSIGN(AdamSolver); |
| | }; |
| |
|
| | } |
| |
|
| | #endif |
| |
|