| |
| #pragma once |
|
|
| #include <iostream> |
| #include <sstream> |
|
|
| #include "moses/DecodeGraph.h" |
| #include "moses/StaticData.h" |
| #include "moses/Syntax/BoundedPriorityContainer.h" |
| #include "moses/Syntax/CubeQueue.h" |
| #include "moses/Syntax/PHyperedge.h" |
| #include "moses/Syntax/RuleTable.h" |
| #include "moses/Syntax/RuleTableFF.h" |
| #include "moses/Syntax/SHyperedgeBundle.h" |
| #include "moses/Syntax/SVertex.h" |
| #include "moses/Syntax/SVertexRecombinationEqualityPred.h" |
| #include "moses/Syntax/SVertexRecombinationHasher.h" |
| #include "moses/Syntax/SymbolEqualityPred.h" |
| #include "moses/Syntax/SymbolHasher.h" |
|
|
| #include "DerivationWriter.h" |
| #include "OovHandler.h" |
| #include "PChart.h" |
| #include "RuleTrie.h" |
| #include "SChart.h" |
|
|
| namespace Moses |
| { |
| namespace Syntax |
| { |
| namespace S2T |
| { |
|
|
| template<typename Parser> |
| Manager<Parser>::Manager(ttasksptr const& ttask) |
| : Syntax::Manager(ttask) |
| , m_pchart(m_source.GetSize(), Parser::RequiresCompressedChart()) |
| , m_schart(m_source.GetSize()) |
| { } |
|
|
| template<typename Parser> |
| void Manager<Parser>::InitializeCharts() |
| { |
| |
| for (std::size_t i = 0; i < m_source.GetSize(); ++i) { |
| const Word &terminal = m_source.GetWord(i); |
|
|
| |
| PVertex tmp(Range(i,i), terminal); |
| PVertex &pvertex = m_pchart.AddVertex(tmp); |
|
|
| |
| boost::shared_ptr<SVertex> v(new SVertex()); |
| v->best = 0; |
| v->pvertex = &pvertex; |
| SChart::Cell &scell = m_schart.GetCell(i,i); |
| SVertexStack stack(1, v); |
| SChart::Cell::TMap::value_type x(terminal, stack); |
| scell.terminalStacks.insert(x); |
| } |
| } |
|
|
| template<typename Parser> |
| void Manager<Parser>::InitializeParsers(PChart &pchart, |
| std::size_t ruleLimit) |
| { |
| const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
|
|
| const std::vector<DecodeGraph*> &graphs = |
| StaticData::Instance().GetDecodeGraphs(); |
|
|
| UTIL_THROW_IF2(ffs.size() != graphs.size(), |
| "number of RuleTables does not match number of decode graphs"); |
|
|
| for (std::size_t i = 0; i < ffs.size(); ++i) { |
| RuleTableFF *ff = ffs[i]; |
| std::size_t maxChartSpan = graphs[i]->GetMaxChartSpan(); |
| |
| |
| |
| |
| const RuleTable *table = ff->GetTable(); |
| assert(table); |
| RuleTable *nonConstTable = const_cast<RuleTable*>(table); |
| boost::shared_ptr<Parser> parser; |
| typename Parser::RuleTrie *trie = |
| dynamic_cast<typename Parser::RuleTrie*>(nonConstTable); |
| assert(trie); |
| parser.reset(new Parser(pchart, *trie, maxChartSpan)); |
| m_parsers.push_back(parser); |
| } |
|
|
| |
| |
| m_oovs.clear(); |
| std::size_t maxOovWidth = 0; |
| FindOovs(pchart, m_oovs, maxOovWidth); |
| if (!m_oovs.empty()) { |
| |
| OovHandler<typename Parser::RuleTrie> oovHandler(*ffs[0]); |
| m_oovRuleTrie = oovHandler.SynthesizeRuleTrie(m_oovs.begin(), m_oovs.end()); |
| |
| boost::shared_ptr<Parser> parser( |
| new Parser(pchart, *m_oovRuleTrie, maxOovWidth)); |
| m_parsers.push_back(parser); |
| } |
| } |
|
|
| |
| |
| template<typename Parser> |
| void Manager<Parser>::FindOovs(const PChart &pchart, boost::unordered_set<Word> &oovs, |
| std::size_t maxOovWidth) |
| { |
| |
| std::vector<const RuleTrie *> tries; |
| const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
| for (std::size_t i = 0; i < ffs.size(); ++i) { |
| const RuleTableFF *ff = ffs[i]; |
| if (ff->GetTable()) { |
| const RuleTrie *trie = dynamic_cast<const RuleTrie*>(ff->GetTable()); |
| assert(trie); |
| tries.push_back(trie); |
| } |
| } |
|
|
| |
| |
| |
| oovs.clear(); |
| maxOovWidth = 0; |
| |
| |
| for (std::size_t i = 1; i < pchart.GetWidth()-1; ++i) { |
| for (std::size_t j = i; j < pchart.GetWidth()-1; ++j) { |
| std::size_t width = j-i+1; |
| const PChart::Cell::TMap &map = pchart.GetCell(i,j).terminalVertices; |
| for (PChart::Cell::TMap::const_iterator p = map.begin(); |
| p != map.end(); ++p) { |
| const Word &word = p->first; |
| assert(!word.IsNonTerminal()); |
| bool found = false; |
| for (std::vector<const RuleTrie *>::const_iterator q = tries.begin(); |
| q != tries.end(); ++q) { |
| const RuleTrie *trie = *q; |
| if (trie->HasPreterminalRule(word)) { |
| found = true; |
| break; |
| } |
| } |
| if (!found) { |
| oovs.insert(word); |
| maxOovWidth = std::max(maxOovWidth, width); |
| } |
| } |
| } |
| } |
| } |
|
|
| template<typename Parser> |
| void Manager<Parser>::Decode() |
| { |
| |
| const std::size_t popLimit = options()->cube.pop_limit; |
| const std::size_t ruleLimit = options()->syntax.rule_limit; |
| const std::size_t stackLimit = options()->search.stack_size; |
|
|
| |
| InitializeCharts(); |
|
|
| |
| InitializeParsers(m_pchart, ruleLimit); |
|
|
| |
| typename Parser::CallbackType callback(m_schart, ruleLimit); |
|
|
| |
| std::size_t size = m_source.GetSize(); |
| for (int start = size-1; start >= 0; --start) { |
| for (std::size_t width = 1; width <= size-start; ++width) { |
| std::size_t end = start + width - 1; |
|
|
| |
| SChart::Cell &scell = m_schart.GetCell(start, end); |
|
|
| Range range(start, end); |
|
|
| |
| |
| |
| callback.InitForRange(range); |
| for (typename std::vector<boost::shared_ptr<Parser> >::iterator |
| p = m_parsers.begin(); p != m_parsers.end(); ++p) { |
| (*p)->EnumerateHyperedges(range, callback); |
| } |
|
|
| |
| const BoundedPriorityContainer<SHyperedgeBundle> &bundles = |
| callback.GetContainer(); |
|
|
| |
| |
| CubeQueue cubeQueue(bundles.Begin(), bundles.End()); |
| std::size_t count = 0; |
| typedef boost::unordered_map<Word, std::vector<SHyperedge*>, |
| SymbolHasher, SymbolEqualityPred > BufferMap; |
| BufferMap buffers; |
| while (count < popLimit && !cubeQueue.IsEmpty()) { |
| SHyperedge *hyperedge = cubeQueue.Pop(); |
| |
| |
| |
| |
| |
| |
| |
| const Word &lhs = hyperedge->label.translation->GetTargetLHS(); |
| hyperedge->head->pvertex = &m_pchart.AddVertex(PVertex(range, lhs)); |
| |
| buffers[lhs].push_back(hyperedge); |
| ++count; |
| } |
|
|
| |
| for (BufferMap::const_iterator p = buffers.begin(); p != buffers.end(); |
| ++p) { |
| const Word &category = p->first; |
| const std::vector<SHyperedge*> &buffer = p->second; |
| std::pair<SChart::Cell::NMap::Iterator, bool> ret = |
| scell.nonTerminalStacks.Insert(category, SVertexStack()); |
| assert(ret.second); |
| SVertexStack &stack = ret.first->second; |
| RecombineAndSort(buffer, stack); |
| } |
|
|
| |
| if (stackLimit > 0) { |
| for (SChart::Cell::NMap::Iterator p = scell.nonTerminalStacks.Begin(); |
| p != scell.nonTerminalStacks.End(); ++p) { |
| SVertexStack &stack = p->second; |
| if (stack.size() > stackLimit) { |
| stack.resize(stackLimit); |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| } |
| } |
| } |
|
|
| template<typename Parser> |
| const SHyperedge *Manager<Parser>::GetBestSHyperedge() const |
| { |
| const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1); |
| const SChart::Cell::NMap &stacks = cell.nonTerminalStacks; |
| if (stacks.Size() == 0) { |
| return 0; |
| } |
| assert(stacks.Size() == 1); |
| const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second; |
| |
| return stack[0]->best; |
| } |
|
|
| template<typename Parser> |
| void Manager<Parser>::ExtractKBest( |
| std::size_t k, |
| std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, |
| bool onlyDistinct) const |
| { |
| kBestList.clear(); |
| if (k == 0 || m_source.GetSize() == 0) { |
| return; |
| } |
|
|
| |
| const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1); |
| const SChart::Cell::NMap &stacks = cell.nonTerminalStacks; |
| if (stacks.Size() == 0) { |
| return; |
| } |
| assert(stacks.Size() == 1); |
| const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second; |
| |
|
|
| KBestExtractor extractor; |
|
|
| if (!onlyDistinct) { |
| |
| extractor.Extract(stack, k, kBestList); |
| return; |
| } |
|
|
| |
| |
| |
| |
| |
| const StaticData &staticData = StaticData::Instance(); |
| const std::size_t nBestFactor = staticData.options()->nbest.factor; |
| std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; |
|
|
| |
| KBestExtractor::KBestVec bigList; |
| bigList.reserve(numDerivations); |
| extractor.Extract(stack, numDerivations, bigList); |
|
|
| |
| std::set<Phrase> distinct; |
| for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); |
| kBestList.size() < k && p != bigList.end(); ++p) { |
| boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; |
| Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); |
| if (distinct.insert(translation).second) { |
| kBestList.push_back(derivation); |
| } |
| } |
| } |
|
|
| template<typename Parser> |
| void Manager<Parser>::PrunePChart(const SChart::Cell &scell, |
| PChart::Cell &pcell) |
| { |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| } |
|
|
| template<typename Parser> |
| void Manager<Parser>::RecombineAndSort(const std::vector<SHyperedge*> &buffer, |
| SVertexStack &stack) |
| { |
| |
| |
| |
| |
| |
| typedef boost::unordered_map<SVertex *, SVertex *, |
| SVertexRecombinationHasher, |
| SVertexRecombinationEqualityPred> Map; |
| Map map; |
| for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); |
| p != buffer.end(); ++p) { |
| SHyperedge *h = *p; |
| SVertex *v = h->head; |
| assert(v->best == h); |
| assert(v->recombined.empty()); |
| std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); |
| if (result.second) { |
| continue; |
| } |
| |
| |
| |
| SVertex *storedVertex = result.first->second; |
| if (h->label.futureScore > storedVertex->best->label.futureScore) { |
| |
| storedVertex->recombined.push_back(storedVertex->best); |
| storedVertex->best = h; |
| } else { |
| storedVertex->recombined.push_back(h); |
| } |
| h->head->best = 0; |
| delete h->head; |
| h->head = storedVertex; |
| } |
|
|
| |
| stack.clear(); |
| stack.reserve(map.size()); |
| for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { |
| stack.push_back(boost::shared_ptr<SVertex>(p->first)); |
| } |
|
|
| |
| std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); |
| } |
|
|
| template<typename Parser> |
| void Manager<Parser>::OutputDetailedTranslationReport( |
| OutputCollector *collector) const |
| { |
| const SHyperedge *best = GetBestSHyperedge(); |
| if (best == NULL || collector == NULL) { |
| return; |
| } |
| long translationId = m_source.GetTranslationId(); |
| std::ostringstream out; |
| DerivationWriter::Write(*best, translationId, out); |
| collector->Write(translationId, out.str()); |
| } |
|
|
| } |
| } |
| } |
|
|