| #ifndef moses_Classifier_h |
| #define moses_Classifier_h |
|
|
| #include <iostream> |
| #include <string> |
| #include <fstream> |
| #include <sstream> |
| #include <deque> |
| #include <vector> |
| #include <boost/shared_ptr.hpp> |
|
|
| #include <boost/noncopyable.hpp> |
| #include <boost/thread/condition_variable.hpp> |
| #include <boost/thread/locks.hpp> |
| #include <boost/thread/mutex.hpp> |
| #include <boost/iostreams/filtering_stream.hpp> |
| #include <boost/iostreams/filter/gzip.hpp> |
| #include "../util/string_piece.hh" |
| #include "../moses/Util.h" |
|
|
| |
// Forward declarations at global scope. This header only uses these two
// types through pointers (see the ::vw* and ::ezexample* members of
// VWPredictor and ClassifierFactory), so their definitions are not needed here.
// NOTE(review): presumably the Vowpal Wabbit library types -- confirm against
// the implementation file that includes the real definitions.
struct vw;
class ezexample;
|
|
| namespace Discriminative |
| { |
// One feature: an unsigned 32-bit key paired with a float value.
// NOTE(review): the key is presumably the feature's hash/ID -- confirm against
// the AddLabel*Feature implementations, which return this type.
typedef std::pair<uint32_t, float> FeatureType;

// Sequence of features, as taken by the Add*FeatureVector methods below.
typedef std::vector<FeatureType> FeatureVector;
|
|
| |
| |
| |
| class Classifier |
| { |
| public: |
| |
| |
| |
| virtual FeatureType AddLabelIndependentFeature(const StringPiece &name, float value) = 0; |
|
|
| |
| |
| |
| virtual FeatureType AddLabelDependentFeature(const StringPiece &name, float value) = 0; |
|
|
| |
| |
| |
| virtual void AddLabelIndependentFeatureVector(const FeatureVector &features) = 0; |
|
|
| |
| |
| |
| virtual void AddLabelDependentFeatureVector(const FeatureVector &features) = 0; |
|
|
| |
| |
| |
| |
| virtual void Train(const StringPiece &label, float loss) = 0; |
|
|
| |
| |
| |
| |
| virtual float Predict(const StringPiece &label) = 0; |
|
|
| |
| FeatureType AddLabelIndependentFeature(const StringPiece &name) { |
| return AddLabelIndependentFeature(name, 1.0); |
| } |
|
|
| FeatureType AddLabelDependentFeature(const StringPiece &name) { |
| return AddLabelDependentFeature(name, 1.0); |
| } |
|
|
| virtual ~Classifier() {} |
|
|
| protected: |
| |
| |
| |
| static std::string EscapeSpecialChars(const std::string &str) { |
| std::string out; |
| out = Moses::Replace(str, "\\", "_/_"); |
| out = Moses::Replace(out, "|", "\\/"); |
| out = Moses::Replace(out, ":", "\\;"); |
| out = Moses::Replace(out, " ", "\\_"); |
| return out; |
| } |
|
|
| const static bool DEBUG = false; |
|
|
| }; |
|
|
| |
| |
| const std::string VW_DEFAULT_OPTIONS = " --hash all --noconstant -q st -t --ldf_override sc "; |
| const std::string VW_DEFAULT_PARSER_OPTIONS = " --quiet --hash all --noconstant -q st -t --csoaa_ldf sc "; |
|
|
| |
| |
| |
| class VWTrainer : public Classifier |
| { |
| public: |
| VWTrainer(const std::string &outputFile); |
| virtual ~VWTrainer(); |
|
|
| virtual FeatureType AddLabelIndependentFeature(const StringPiece &name, float value); |
| virtual FeatureType AddLabelDependentFeature(const StringPiece &name, float value); |
| virtual void AddLabelIndependentFeatureVector(const FeatureVector &features); |
| virtual void AddLabelDependentFeatureVector(const FeatureVector &features); |
| virtual void Train(const StringPiece &label, float loss); |
| virtual float Predict(const StringPiece &label); |
|
|
| protected: |
| void AddFeature(const StringPiece &name, float value); |
|
|
| bool m_isFirstSource, m_isFirstTarget, m_isFirstExample; |
|
|
| private: |
| boost::iostreams::filtering_ostream m_bfos; |
| std::deque<std::string> m_outputBuffer; |
|
|
| void WriteBuffer(); |
| }; |
|
|
| |
| |
| |
| class VWPredictor : public Classifier, private boost::noncopyable |
| { |
| public: |
| VWPredictor(const std::string &modelFile, const std::string &vwOptions); |
| virtual ~VWPredictor(); |
|
|
| virtual FeatureType AddLabelIndependentFeature(const StringPiece &name, float value); |
| virtual FeatureType AddLabelDependentFeature(const StringPiece &name, float value); |
| virtual void AddLabelIndependentFeatureVector(const FeatureVector &features); |
| virtual void AddLabelDependentFeatureVector(const FeatureVector &features); |
| virtual void Train(const StringPiece &label, float loss); |
| virtual float Predict(const StringPiece &label); |
|
|
| friend class ClassifierFactory; |
|
|
| protected: |
| FeatureType AddFeature(const StringPiece &name, float values); |
|
|
| ::vw *m_VWInstance, *m_VWParser; |
| ::ezexample *m_ex; |
| |
| |
| bool m_sharedVwInstance; |
| bool m_isFirstSource, m_isFirstTarget; |
|
|
| private: |
| |
| VWPredictor(vw * instance, const std::string &vwOption); |
| }; |
|
|
| |
| |
| |
| class ClassifierFactory : private boost::noncopyable |
| { |
| public: |
| typedef boost::shared_ptr<Classifier> ClassifierPtr; |
|
|
| |
| |
| |
| ClassifierFactory(const std::string &modelFile, const std::string &vwOptions); |
|
|
| |
| |
| |
| ClassifierFactory(const std::string &modelFilePrefix); |
|
|
| |
| ClassifierPtr operator()(); |
|
|
| ~ClassifierFactory(); |
|
|
| private: |
| std::string m_vwOptions; |
| ::vw *m_VWInstance; |
| int m_lastId; |
| std::string m_modelFilePrefix; |
| bool m_gzip; |
| boost::mutex m_mutex; |
| const bool m_train; |
| }; |
|
|
| } |
|
|
| #endif |
|
|