|
|
#ifndef CAFFE_SOLVER_HPP_ |
|
|
#define CAFFE_SOLVER_HPP_ |
|
|
#include <boost/function.hpp> |
|
|
#include <string> |
|
|
#include <vector> |
|
|
|
|
|
#include "caffe/net.hpp" |
|
|
#include "caffe/solver_factory.hpp" |
|
|
#include "caffe/util/benchmark.hpp" |
|
|
|
|
|
namespace caffe { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace SolverAction { |
|
|
enum Enum { |
|
|
NONE = 0, |
|
|
STOP = 1, |
|
|
|
|
|
SNAPSHOT = 2 |
|
|
}; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
typedef boost::function<SolverAction::Enum()> ActionCallback; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Dtype> |
|
|
class Solver { |
|
|
public: |
|
|
explicit Solver(const SolverParameter& param); |
|
|
explicit Solver(const string& param_file); |
|
|
void Init(const SolverParameter& param); |
|
|
void InitTrainNet(); |
|
|
void InitTestNets(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void SetActionFunction(ActionCallback func); |
|
|
SolverAction::Enum GetRequestedAction(); |
|
|
|
|
|
|
|
|
virtual void Solve(const char* resume_file = NULL); |
|
|
inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); } |
|
|
void Step(int iters); |
|
|
|
|
|
|
|
|
|
|
|
void Restore(const char* resume_file); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Snapshot(); |
|
|
virtual ~Solver() {} |
|
|
inline const SolverParameter& param() const { return param_; } |
|
|
inline shared_ptr<Net<Dtype> > net() { return net_; } |
|
|
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() { |
|
|
return test_nets_; |
|
|
} |
|
|
int iter() const { return iter_; } |
|
|
|
|
|
|
|
|
class Callback { |
|
|
protected: |
|
|
virtual void on_start() = 0; |
|
|
virtual void on_gradients_ready() = 0; |
|
|
|
|
|
template <typename T> |
|
|
friend class Solver; |
|
|
}; |
|
|
const vector<Callback*>& callbacks() const { return callbacks_; } |
|
|
void add_callback(Callback* value) { |
|
|
callbacks_.push_back(value); |
|
|
} |
|
|
|
|
|
void CheckSnapshotWritePermissions(); |
|
|
|
|
|
|
|
|
|
|
|
virtual inline const char* type() const { return ""; } |
|
|
|
|
|
|
|
|
virtual void ApplyUpdate() = 0; |
|
|
|
|
|
protected: |
|
|
string SnapshotFilename(const string& extension); |
|
|
string SnapshotToBinaryProto(); |
|
|
string SnapshotToHDF5(); |
|
|
|
|
|
void TestAll(); |
|
|
void Test(const int test_net_id = 0); |
|
|
virtual void SnapshotSolverState(const string& model_filename) = 0; |
|
|
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0; |
|
|
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0; |
|
|
void DisplayOutputBlobs(const int net_id); |
|
|
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss); |
|
|
|
|
|
SolverParameter param_; |
|
|
int iter_; |
|
|
int current_step_; |
|
|
shared_ptr<Net<Dtype> > net_; |
|
|
vector<shared_ptr<Net<Dtype> > > test_nets_; |
|
|
vector<Callback*> callbacks_; |
|
|
vector<Dtype> losses_; |
|
|
Dtype smoothed_loss_; |
|
|
|
|
|
|
|
|
|
|
|
ActionCallback action_request_function_; |
|
|
|
|
|
|
|
|
bool requested_early_exit_; |
|
|
|
|
|
|
|
|
Timer iteration_timer_; |
|
|
float iterations_last_; |
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(Solver); |
|
|
}; |
|
|
|
|
|
} |
|
|
|
|
|
#endif |
|
|
|