/* * KENLM.cpp * * Created on: 4 Nov 2015 * Author: hieu */ #include #include #include "KENLM.h" #include "../Phrase.h" #include "../Scores.h" #include "../System.h" #include "../PhraseBased/Hypothesis.h" #include "../PhraseBased/Manager.h" #include "../PhraseBased/TargetPhraseImpl.h" #include "lm/state.hh" #include "lm/left.hh" #include "util/exception.hh" #include "util/tokenize_piece.hh" #include "util/string_stream.hh" #include "../legacy/FactorCollection.h" #include "../SCFG/TargetPhraseImpl.h" #include "../SCFG/Hypothesis.h" #include "../SCFG/Manager.h" using namespace std; namespace Moses2 { struct KenLMState: public FFState { lm::ngram::State state; virtual size_t hash() const { size_t ret = hash_value(state); return ret; } virtual bool operator==(const FFState& o) const { const KenLMState &other = static_cast(o); bool ret = state == other.state; return ret; } virtual std::string ToString() const { stringstream ss; for (size_t i = 0; i < state.Length(); ++i) { ss << state.words[i] << " "; } return ss.str(); } }; ///////////////////////////////////////////////////////////////// class LanguageModelChartStateKenLM : public FFState { public: LanguageModelChartStateKenLM() {} const lm::ngram::ChartState &GetChartState() const { return m_state; } lm::ngram::ChartState &GetChartState() { return m_state; } size_t hash() const { size_t ret = hash_value(m_state); return ret; } virtual bool operator==(const FFState& o) const { const LanguageModelChartStateKenLM &other = static_cast(o); bool ret = m_state == other.m_state; return ret; } virtual std::string ToString() const { return "LanguageModelChartStateKenLM"; } private: lm::ngram::ChartState m_state; }; ///////////////////////////////////////////////////////////////// class MappingBuilder: public lm::EnumerateVocab { public: MappingBuilder(FactorCollection &factorCollection, System &system, std::vector &mapping) : m_factorCollection(factorCollection), m_system(system), m_mapping(mapping) { } void Add(lm::WordIndex index, const StringPiece &str) { std::size_t factorId = m_factorCollection.AddFactor(str, m_system, false)->GetId(); if (m_mapping.size() <= factorId) { // 0 is :-) m_mapping.resize(factorId + 1); } m_mapping[factorId] = index; } private: FactorCollection &m_factorCollection; std::vector &m_mapping; System &m_system; }; ///////////////////////////////////////////////////////////////// template KENLM::KENLM(size_t startInd, const std::string &line, const std::string &file, FactorType factorType, util::LoadMethod load_method) : StatefulFeatureFunction(startInd, line), m_path(file), m_factorType( factorType), m_load_method(load_method) { ReadParameters(); } template KENLM::~KENLM() { // TODO Auto-generated destructor stub } template void KENLM::Load(System &system) { FactorCollection &fc = system.GetVocab(); m_bos = fc.AddFactor(BOS_, system, false); m_eos = fc.AddFactor(EOS_, system, false); lm::ngram::Config config; config.messages = NULL; FactorCollection &collection = system.GetVocab(); MappingBuilder builder(collection, system, m_lmIdLookup); config.enumerate_vocab = &builder; config.load_method = m_load_method; m_ngram.reset(new Model(m_path.c_str(), config)); } template FFState* KENLM::BlankState(MemPool &pool, const System &sys) const { FFState *ret; if (sys.isPb) { ret = new (pool.Allocate()) KenLMState(); } else { ret = new (pool.Allocate()) LanguageModelChartStateKenLM(); } return ret; } //! return the state associated with the empty hypothesis for a given sentence template void KENLM::EmptyHypothesisState(FFState &state, const ManagerBase &mgr, const InputType &input, const Hypothesis &hypo) const { KenLMState &stateCast = static_cast(state); stateCast.state = m_ngram->BeginSentenceState(); } template void KENLM::EvaluateInIsolation(MemPool &pool, const System &system, const Phrase &source, const TargetPhraseImpl &targetPhrase, Scores &scores, SCORE &estimatedScore) const { // contains factors used by this LM float fullScore, nGramScore; size_t oovCount; CalcScore(targetPhrase, fullScore, nGramScore, oovCount); float estimateScore = fullScore - nGramScore; bool GetLMEnableOOVFeature = false; if (GetLMEnableOOVFeature) { float scoresVec[2], estimateScoresVec[2]; scoresVec[0] = nGramScore; scoresVec[1] = oovCount; scores.PlusEquals(system, *this, scoresVec); estimateScoresVec[0] = estimateScore; estimateScoresVec[1] = 0; SCORE weightedScore = Scores::CalcWeightedScore(system, *this, estimateScoresVec); estimatedScore += weightedScore; } else { scores.PlusEquals(system, *this, nGramScore); SCORE weightedScore = Scores::CalcWeightedScore(system, *this, estimateScore); estimatedScore += weightedScore; } } template void KENLM::EvaluateInIsolation(MemPool &pool, const System &system, const Phrase &source, const TargetPhrase &targetPhrase, Scores &scores, SCORE &estimatedScore) const { // contains factors used by this LM float fullScore, nGramScore; size_t oovCount; CalcScore(targetPhrase, fullScore, nGramScore, oovCount); //float estimateScore = fullScore - nGramScore; // all LM scores are estimated float estimateScore = fullScore; nGramScore = 0; bool GetLMEnableOOVFeature = false; if (GetLMEnableOOVFeature) { float scoresVec[2], estimateScoresVec[2]; scoresVec[0] = nGramScore; scoresVec[1] = oovCount; scores.PlusEquals(system, *this, scoresVec); estimateScoresVec[0] = estimateScore; estimateScoresVec[1] = 0; SCORE weightedScore = Scores::CalcWeightedScore(system, *this, estimateScoresVec); estimatedScore += weightedScore; } else { scores.PlusEquals(system, *this, nGramScore); SCORE weightedScore = Scores::CalcWeightedScore(system, *this, estimateScore); estimatedScore += weightedScore; } } template void KENLM::EvaluateWhenApplied(const ManagerBase &mgr, const Hypothesis &hypo, const FFState &prevState, Scores &scores, FFState &state) const { KenLMState &stateCast = static_cast(state); const System &system = mgr.system; const lm::ngram::State &in_state = static_cast(prevState).state; if (!hypo.GetTargetPhrase().GetSize()) { stateCast.state = in_state; return; } const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos(); //[begin, end) in STL-like fashion. const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1; const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1); std::size_t position = begin; typename Model::State aux_state; typename Model::State *state0 = &stateCast.state, *state1 = &aux_state; float score = m_ngram->Score(in_state, TranslateID(hypo.GetWord(position)), *state0); ++position; for (; position < adjust_end; ++position) { score += m_ngram->Score(*state0, TranslateID(hypo.GetWord(position)), *state1); std::swap(state0, state1); } if (hypo.GetBitmap().IsComplete()) { // Score end of sentence. std::vector indices(m_ngram->Order() - 1); const lm::WordIndex *last = LastIDs(hypo, &indices.front()); score += m_ngram->FullScoreForgotState(&indices.front(), last, m_ngram->GetVocabulary().EndSentence(), stateCast.state).prob; } else if (adjust_end < end) { // Get state after adding a long phrase. std::vector indices(m_ngram->Order() - 1); const lm::WordIndex *last = LastIDs(hypo, &indices.front()); m_ngram->GetState(&indices.front(), last, stateCast.state); } else if (state0 != &stateCast.state) { // Short enough phrase that we can just reuse the state. stateCast.state = *state0; } score = TransformLMScore(score); bool OOVFeatureEnabled = false; if (OOVFeatureEnabled) { std::vector scoresVec(2); scoresVec[0] = score; scoresVec[1] = 0.0; scores.PlusEquals(system, *this, scoresVec); } else { scores.PlusEquals(system, *this, score); } } template void KENLM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, std::size_t &oovCount) const { fullScore = 0; ngramScore = 0; oovCount = 0; if (!phrase.GetSize()) return; lm::ngram::ChartState discarded_sadly; lm::ngram::RuleScore scorer(*m_ngram, discarded_sadly); size_t position; if (m_bos == phrase[0][m_factorType]) { scorer.BeginSentence(); position = 1; } else { position = 0; } size_t ngramBoundary = m_ngram->Order() - 1; size_t end_loop = std::min(ngramBoundary, phrase.GetSize()); for (; position < end_loop; ++position) { const Word &word = phrase[position]; lm::WordIndex index = TranslateID(word); scorer.Terminal(index); if (!index) ++oovCount; } float before_boundary = fullScore + scorer.Finish(); for (; position < phrase.GetSize(); ++position) { const Word &word = phrase[position]; lm::WordIndex index = TranslateID(word); scorer.Terminal(index); if (!index) ++oovCount; } fullScore += scorer.Finish(); ngramScore = TransformLMScore(fullScore - before_boundary); fullScore = TransformLMScore(fullScore); } template void KENLM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, std::size_t &oovCount) const { fullScore = 0; ngramScore = 0; oovCount = 0; if (!phrase.GetSize()) return; lm::ngram::ChartState discarded_sadly; lm::ngram::RuleScore scorer(*m_ngram, discarded_sadly); size_t position; if (m_bos == phrase[0][m_factorType]) { scorer.BeginSentence(); position = 1; } else { position = 0; } size_t ngramBoundary = m_ngram->Order() - 1; size_t end_loop = std::min(ngramBoundary, phrase.GetSize()); for (; position < end_loop; ++position) { const SCFG::Word &word = phrase[position]; if (word.isNonTerminal) { fullScore += scorer.Finish(); scorer.Reset(); } else { lm::WordIndex index = TranslateID(word); scorer.Terminal(index); if (!index) ++oovCount; } } float before_boundary = fullScore + scorer.Finish(); for (; position < phrase.GetSize(); ++position) { const SCFG::Word &word = phrase[position]; if (word.isNonTerminal) { fullScore += scorer.Finish(); scorer.Reset(); } else { lm::WordIndex index = TranslateID(word); scorer.Terminal(index); if (!index) ++oovCount; } } fullScore += scorer.Finish(); ngramScore = TransformLMScore(fullScore - before_boundary); fullScore = TransformLMScore(fullScore); } // Convert last words of hypothesis into vocab ids, returning an end pointer. template lm::WordIndex *KENLM::LastIDs(const Hypothesis &hypo, lm::WordIndex *indices) const { lm::WordIndex *index = indices; lm::WordIndex *end = indices + m_ngram->Order() - 1; int position = hypo.GetCurrTargetWordsRange().GetEndPos(); for (;; ++index, --position) { if (index == end) return index; if (position == -1) { *index = m_ngram->GetVocabulary().BeginSentence(); return index + 1; } *index = TranslateID(hypo.GetWord(position)); } } template void KENLM::EvaluateWhenApplied(const SCFG::Manager &mgr, const SCFG::Hypothesis &hypo, int featureID, Scores &scores, FFState &state) const { LanguageModelChartStateKenLM &newState = static_cast(state); lm::ngram::RuleScore ruleScore(*m_ngram, newState.GetChartState()); const SCFG::TargetPhraseImpl &target = hypo.GetTargetPhrase(); const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = target.GetAlignNonTerm().GetNonTermIndexMap(); const size_t size = target.GetSize(); size_t phrasePos = 0; // Special cases for first word. if (size) { const SCFG::Word &word = target[0]; if (word[m_factorType] == m_bos) { // Begin of sentence ruleScore.BeginSentence(); phrasePos++; } else if (word.isNonTerminal) { // Non-terminal is first so we can copy instead of rescoring. const SCFG::Hypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]); const lm::ngram::ChartState &prevState = static_cast(prevHypo->GetState(featureID))->GetChartState(); ruleScore.BeginNonTerminal(prevState); phrasePos++; } } for (; phrasePos < size; phrasePos++) { const SCFG::Word &word = target[phrasePos]; if (word.isNonTerminal) { const SCFG::Hypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]); const lm::ngram::ChartState &prevState = static_cast(prevHypo->GetState(featureID))->GetChartState(); ruleScore.NonTerminal(prevState); } else { ruleScore.Terminal(TranslateID(word)); } } float score = ruleScore.Finish(); score = TransformLMScore(score); // take out score from loading. This needs reworking //score -= target.GetScores().GetScores(*this)[0]; bool OOVFeatureEnabled = false; if (OOVFeatureEnabled) { std::vector scoresVec(2); scoresVec[0] = score; scoresVec[1] = 0.0; scores.PlusEquals(mgr.system, *this, scoresVec); } else { scores.PlusEquals(mgr.system, *this, score); } } /////////////////////////////////////////////////////////////////////////// /* Instantiate LanguageModelKen here. Tells the compiler to generate code * for the instantiations' non-inline member functions in this file. * Otherwise, depending on the compiler, those functions may not be present * at link time. */ template class KENLM ; template class KENLM ; template class KENLM ; template class KENLM ; template class KENLM ; template class KENLM ; FeatureFunction *ConstructKenLM(size_t startInd, const std::string &lineOrig) { FactorType factorType = 0; string filePath; util::LoadMethod load_method = util::POPULATE_OR_READ; util::TokenIter argument(lineOrig, ' '); ++argument; // KENLM util::StringStream line; line << "KENLM"; for (; argument; ++argument) { const char *equals = std::find(argument->data(), argument->data() + argument->size(), '='); UTIL_THROW_IF2(equals == argument->data() + argument->size(), "Expected = in KenLM argument " << *argument); StringPiece name(argument->data(), equals - argument->data()); StringPiece value(equals + 1, argument->data() + argument->size() - equals - 1); if (name == "factor") { factorType = boost::lexical_cast(value); } else if (name == "order") { // Ignored } else if (name == "path") { filePath.assign(value.data(), value.size()); } else if (name == "lazyken") { // deprecated: use load instead. load_method = boost::lexical_cast(value) ? util::LAZY : util::POPULATE_OR_READ; } else if (name == "load") { if (value == "lazy") { load_method = util::LAZY; } else if (value == "populate_or_lazy") { load_method = util::POPULATE_OR_LAZY; } else if (value == "populate_or_read" || value == "populate") { load_method = util::POPULATE_OR_READ; } else if (value == "read") { load_method = util::READ; } else if (value == "parallel_read") { load_method = util::PARALLEL_READ; } else { UTIL_THROW2("Unknown KenLM load method " << value); } } else { // pass to base class to interpret line << " " << name << "=" << value; } } return ConstructKenLM(startInd, line.str(), filePath, factorType, load_method); } FeatureFunction *ConstructKenLM(size_t startInd, const std::string &line, const std::string &file, FactorType factorType, util::LoadMethod load_method) { lm::ngram::ModelType model_type; if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) { switch (model_type) { case lm::ngram::PROBING: return new KENLM(startInd, line, file, factorType, load_method); case lm::ngram::REST_PROBING: return new KENLM(startInd, line, file, factorType, load_method); case lm::ngram::TRIE: return new KENLM(startInd, line, file, factorType, load_method); case lm::ngram::QUANT_TRIE: return new KENLM(startInd, line, file, factorType, load_method); case lm::ngram::ARRAY_TRIE: return new KENLM(startInd, line, file, factorType, load_method); case lm::ngram::QUANT_ARRAY_TRIE: return new KENLM(startInd, line, file, factorType, load_method); default: UTIL_THROW2("Unrecognized kenlm model type " << model_type) ; } } else { return new KENLM(startInd, line, file, factorType, load_method); } } }