| #pragma once |
|
|
| #include "moses/DecodeGraph.h" |
| #include "moses/StaticData.h" |
| #include "moses/Syntax/BoundedPriorityContainer.h" |
| #include "moses/Syntax/CubeQueue.h" |
| #include "moses/Syntax/F2S/DerivationWriter.h" |
| #include "moses/Syntax/F2S/RuleMatcherCallback.h" |
| #include "moses/Syntax/PHyperedge.h" |
| #include "moses/Syntax/RuleTable.h" |
| #include "moses/Syntax/RuleTableFF.h" |
| #include "moses/Syntax/SHyperedgeBundle.h" |
| #include "moses/Syntax/SVertex.h" |
| #include "moses/Syntax/SVertexRecombinationEqualityPred.h" |
| #include "moses/Syntax/SVertexRecombinationHasher.h" |
| #include "moses/Syntax/SymbolEqualityPred.h" |
| #include "moses/Syntax/SymbolHasher.h" |
|
|
| #include "GlueRuleSynthesizer.h" |
| #include "InputTreeBuilder.h" |
| #include "RuleTrie.h" |
|
|
| namespace Moses |
| { |
| namespace Syntax |
| { |
| namespace T2S |
| { |
|
|
| template<typename RuleMatcher> |
| Manager<RuleMatcher>::Manager(ttasksptr const& ttask) |
| : Syntax::Manager(ttask) |
| { |
| if (const TreeInput *p = dynamic_cast<const TreeInput*>(&m_source)) { |
| |
| InputTreeBuilder builder(options()->output.factor_order); |
| builder.Build(*p, "Q", m_inputTree); |
| } else { |
| UTIL_THROW2("ERROR: T2S::Manager requires input to be a tree"); |
| } |
| } |
|
|
| template<typename RuleMatcher> |
| void Manager<RuleMatcher>::InitializeRuleMatchers() |
| { |
| const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
| for (std::size_t i = 0; i < ffs.size(); ++i) { |
| RuleTableFF *ff = ffs[i]; |
| |
| |
| |
| |
| const RuleTable *table = ff->GetTable(); |
| assert(table); |
| RuleTable *nonConstTable = const_cast<RuleTable*>(table); |
| RuleTrie *trie = dynamic_cast<RuleTrie*>(nonConstTable); |
| assert(trie); |
| boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *trie)); |
| m_ruleMatchers.push_back(p); |
| } |
|
|
| |
| |
| |
| m_glueRuleTrie.reset(new RuleTrie(ffs[0])); |
| boost::shared_ptr<RuleMatcher> p(new RuleMatcher(m_inputTree, *m_glueRuleTrie)); |
| m_ruleMatchers.push_back(p); |
| m_glueRuleMatcher = p.get(); |
| } |
|
|
| template<typename RuleMatcher> |
| void Manager<RuleMatcher>::InitializeStacks() |
| { |
| |
| assert(!m_inputTree.nodes.empty()); |
|
|
| for (std::vector<InputTree::Node>::const_iterator p = |
| m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) { |
| const InputTree::Node &node = *p; |
|
|
| |
| SVertexStack &stack = m_stackMap[&(node.pvertex)]; |
|
|
| |
| if (node.children.empty()) { |
| boost::shared_ptr<SVertex> v(new SVertex()); |
| v->best = 0; |
| v->pvertex = &(node.pvertex); |
| stack.push_back(v); |
| } |
| } |
| } |
|
|
| template<typename RuleMatcher> |
| void Manager<RuleMatcher>::Decode() |
| { |
| |
|
|
| |
| const std::size_t popLimit = this->options()->cube.pop_limit; |
| const std::size_t ruleLimit = this->options()->syntax.rule_limit; |
| const std::size_t stackLimit = this->options()->search.stack_size; |
|
|
| |
| InitializeStacks(); |
|
|
| |
| InitializeRuleMatchers(); |
|
|
| |
| F2S::RuleMatcherCallback callback(m_stackMap, ruleLimit); |
|
|
| |
| Word dflt_nonterm = options()->syntax.output_default_non_terminal; |
| GlueRuleSynthesizer glueRuleSynthesizer(*m_glueRuleTrie, dflt_nonterm); |
|
|
| |
| for (std::vector<InputTree::Node>::const_iterator p = |
| m_inputTree.nodes.begin(); p != m_inputTree.nodes.end(); ++p) { |
|
|
| const InputTree::Node &node = *p; |
|
|
| |
| if (node.children.empty()) { |
| continue; |
| } |
|
|
| |
| |
| |
| |
| callback.ClearContainer(); |
| for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator |
| q = m_ruleMatchers.begin(); q != m_ruleMatchers.end(); ++q) { |
| (*q)->EnumerateHyperedges(node, callback); |
| } |
|
|
| |
| const BoundedPriorityContainer<SHyperedgeBundle> &bundles = |
| callback.GetContainer(); |
|
|
| |
| |
| if (bundles.Size() == 0) { |
| glueRuleSynthesizer.SynthesizeRule(node); |
| m_glueRuleMatcher->EnumerateHyperedges(node, callback); |
| assert(bundles.Size() == 1); |
| } |
|
|
| |
| |
| CubeQueue cubeQueue(bundles.Begin(), bundles.End()); |
| std::size_t count = 0; |
| std::vector<SHyperedge*> buffer; |
| while (count < popLimit && !cubeQueue.IsEmpty()) { |
| SHyperedge *hyperedge = cubeQueue.Pop(); |
| |
| |
| hyperedge->head->pvertex = &(node.pvertex); |
| |
| buffer.push_back(hyperedge); |
| ++count; |
| } |
|
|
| |
| SVertexStack &stack = m_stackMap[&(node.pvertex)]; |
| RecombineAndSort(buffer, stack); |
|
|
| |
| if (stackLimit > 0 && stack.size() > stackLimit) { |
| stack.resize(stackLimit); |
| } |
| } |
| } |
|
|
| template<typename RuleMatcher> |
| const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const |
| { |
| const InputTree::Node &rootNode = m_inputTree.nodes.back(); |
| F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex); |
| assert(p != m_stackMap.end()); |
| const SVertexStack &stack = p->second; |
| assert(!stack.empty()); |
| return stack[0]->best; |
| } |
|
|
| template<typename RuleMatcher> |
| void Manager<RuleMatcher>::ExtractKBest( |
| std::size_t k, |
| std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, |
| bool onlyDistinct) const |
| { |
| kBestList.clear(); |
| if (k == 0 || m_source.GetSize() == 0) { |
| return; |
| } |
|
|
| |
| const InputTree::Node &rootNode = m_inputTree.nodes.back(); |
| F2S::PVertexToStackMap::const_iterator p = m_stackMap.find(&rootNode.pvertex); |
| assert(p != m_stackMap.end()); |
| const SVertexStack &stack = p->second; |
| assert(!stack.empty()); |
|
|
| KBestExtractor extractor; |
|
|
| if (!onlyDistinct) { |
| |
| extractor.Extract(stack, k, kBestList); |
| return; |
| } |
|
|
| |
| |
| |
| |
| |
| |
| const std::size_t nBestFactor = this->options()->nbest.factor; |
| std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; |
|
|
| |
| KBestExtractor::KBestVec bigList; |
| bigList.reserve(numDerivations); |
| extractor.Extract(stack, numDerivations, bigList); |
|
|
| |
| std::set<Phrase> distinct; |
| for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); |
| kBestList.size() < k && p != bigList.end(); ++p) { |
| boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; |
| Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); |
| if (distinct.insert(translation).second) { |
| kBestList.push_back(derivation); |
| } |
| } |
| } |
|
|
| |
| |
| template<typename RuleMatcher> |
| void Manager<RuleMatcher>::RecombineAndSort( |
| const std::vector<SHyperedge*> &buffer, SVertexStack &stack) |
| { |
| |
| |
| |
| |
| |
| typedef boost::unordered_map<SVertex *, SVertex *, |
| SVertexRecombinationHasher, |
| SVertexRecombinationEqualityPred> Map; |
| Map map; |
| for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); |
| p != buffer.end(); ++p) { |
| SHyperedge *h = *p; |
| SVertex *v = h->head; |
| assert(v->best == h); |
| assert(v->recombined.empty()); |
| std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); |
| if (result.second) { |
| continue; |
| } |
| |
| |
| |
| SVertex *storedVertex = result.first->second; |
| if (h->label.futureScore > storedVertex->best->label.futureScore) { |
| |
| storedVertex->recombined.push_back(storedVertex->best); |
| storedVertex->best = h; |
| } else { |
| storedVertex->recombined.push_back(h); |
| } |
| h->head->best = 0; |
| delete h->head; |
| h->head = storedVertex; |
| } |
|
|
| |
| stack.clear(); |
| stack.reserve(map.size()); |
| for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { |
| stack.push_back(boost::shared_ptr<SVertex>(p->first)); |
| } |
|
|
| |
| std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); |
| } |
|
|
| template<typename RuleMatcher> |
| void Manager<RuleMatcher>::OutputDetailedTranslationReport( |
| OutputCollector *collector) const |
| { |
| const SHyperedge *best = GetBestSHyperedge(); |
| if (best == NULL || collector == NULL) { |
| return; |
| } |
| long translationId = m_source.GetTranslationId(); |
| std::ostringstream out; |
| F2S::DerivationWriter::Write(*best, translationId, out); |
| collector->Write(translationId, out.str()); |
| } |
|
|
| } |
| } |
| } |
|
|