/** * @brief A solver factory that allows one to register solvers, similar to * layer factory. During runtime, registered solvers could be called by passing * a SolverParameter protobuffer to the CreateSolver function: * * SolverRegistry::CreateSolver(param); * * There are two ways to register a solver. Assuming that we have a solver like: * * template * class MyAwesomeSolver : public Solver { * // your implementations * }; * * and its type is its C++ class name, but without the "Solver" at the end * ("MyAwesomeSolver" -> "MyAwesome"). * * If the solver is going to be created simply by its constructor, in your C++ * file, add the following line: * * REGISTER_SOLVER_CLASS(MyAwesome); * * Or, if the solver is going to be created by another creator function, in the * format of: * * template * Solver GetMyAwesomeSolver(const SolverParameter& param) { * // your implementation * } * * then you can register the creator function instead, like * * REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver) * * Note that each solver type should only be registered once. */ #ifndef CAFFE_SOLVER_FACTORY_H_ #define CAFFE_SOLVER_FACTORY_H_ #include #include #include #include "caffe/common.hpp" #include "caffe/proto/caffe.pb.h" namespace caffe { template class Solver; template class SolverRegistry { public: typedef Solver* (*Creator)(const SolverParameter&); typedef std::map CreatorRegistry; static CreatorRegistry& Registry() { static CreatorRegistry* g_registry_ = new CreatorRegistry(); return *g_registry_; } // Adds a creator. static void AddCreator(const string& type, Creator creator) { CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 0) << "Solver type " << type << " already registered."; registry[type] = creator; } // Get a solver using a SolverParameter. static Solver* CreateSolver(const SolverParameter& param) { const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type << " (known types: " << SolverTypeListString() << ")"; return registry[type](param); } static vector SolverTypeList() { CreatorRegistry& registry = Registry(); vector solver_types; for (typename CreatorRegistry::iterator iter = registry.begin(); iter != registry.end(); ++iter) { solver_types.push_back(iter->first); } return solver_types; } private: // Solver registry should never be instantiated - everything is done with its // static variables. SolverRegistry() {} static string SolverTypeListString() { vector solver_types = SolverTypeList(); string solver_types_str; for (vector::iterator iter = solver_types.begin(); iter != solver_types.end(); ++iter) { if (iter != solver_types.begin()) { solver_types_str += ", "; } solver_types_str += *iter; } return solver_types_str; } }; template class SolverRegisterer { public: SolverRegisterer(const string& type, Solver* (*creator)(const SolverParameter&)) { // LOG(INFO) << "Registering solver type: " << type; SolverRegistry::AddCreator(type, creator); } }; #define REGISTER_SOLVER_CREATOR(type, creator) \ static SolverRegisterer g_creator_f_##type(#type, creator); \ static SolverRegisterer g_creator_d_##type(#type, creator) \ #define REGISTER_SOLVER_CLASS(type) \ template \ Solver* Creator_##type##Solver( \ const SolverParameter& param) \ { \ return new type##Solver(param); \ } \ REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) } // namespace caffe #endif // CAFFE_SOLVER_FACTORY_H_