| | |
| | #pragma once |
| |
|
| | #include "moses/DecodeGraph.h" |
| | #include "moses/ForestInput.h" |
| | #include "moses/StaticData.h" |
| | #include "moses/Syntax/BoundedPriorityContainer.h" |
| | #include "moses/Syntax/CubeQueue.h" |
| | #include "moses/Syntax/PHyperedge.h" |
| | #include "moses/Syntax/RuleTable.h" |
| | #include "moses/Syntax/RuleTableFF.h" |
| | #include "moses/Syntax/SHyperedgeBundle.h" |
| | #include "moses/Syntax/SVertex.h" |
| | #include "moses/Syntax/SVertexRecombinationEqualityPred.h" |
| | #include "moses/Syntax/SVertexRecombinationHasher.h" |
| | #include "moses/Syntax/SymbolEqualityPred.h" |
| | #include "moses/Syntax/SymbolHasher.h" |
| | #include "moses/Syntax/T2S/InputTree.h" |
| | #include "moses/Syntax/T2S/InputTreeBuilder.h" |
| | #include "moses/Syntax/T2S/InputTreeToForest.h" |
| | #include "moses/TreeInput.h" |
| |
|
| | #include "DerivationWriter.h" |
| | #include "GlueRuleSynthesizer.h" |
| | #include "HyperTree.h" |
| | #include "RuleMatcherCallback.h" |
| | #include "TopologicalSorter.h" |
| |
|
| | namespace Moses |
| | { |
| | namespace Syntax |
| | { |
| | namespace F2S |
| | { |
| |
|
| | template<typename RuleMatcher> |
| | Manager<RuleMatcher>::Manager(ttasksptr const& ttask) |
| | : Syntax::Manager(ttask) |
| | { |
| | if (const ForestInput *p = dynamic_cast<const ForestInput*>(&m_source)) { |
| | m_forest = p->GetForest(); |
| | m_rootVertex = p->GetRootVertex(); |
| | m_sentenceLength = p->GetSize(); |
| | } else if (const TreeInput *p = dynamic_cast<const TreeInput*>(&m_source)) { |
| | T2S::InputTreeBuilder builder(options()->output.factor_order); |
| | T2S::InputTree tmpTree; |
| | builder.Build(*p, "Q", tmpTree); |
| | boost::shared_ptr<Forest> forest = boost::make_shared<Forest>(); |
| | m_rootVertex = T2S::InputTreeToForest(tmpTree, *forest); |
| | m_forest = forest; |
| | m_sentenceLength = p->GetSize(); |
| | } else { |
| | UTIL_THROW2("ERROR: F2S::Manager requires input to be a tree or forest"); |
| | } |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | void Manager<RuleMatcher>::Decode() |
| | { |
| | |
| | const std::size_t popLimit = options()->cube.pop_limit; |
| | const std::size_t ruleLimit = options()->syntax.rule_limit; |
| | const std::size_t stackLimit = options()->search.stack_size; |
| |
|
| | |
| | InitializeStacks(); |
| |
|
| | |
| | InitializeRuleMatchers(); |
| |
|
| | |
| | RuleMatcherCallback callback(m_stackMap, ruleLimit); |
| |
|
| | |
| | GlueRuleSynthesizer glueRuleSynthesizer(*options(), *m_glueRuleTrie); |
| |
|
| | |
| | std::vector<const Forest::Vertex *> sortedVertices; |
| | TopologicalSorter sorter; |
| | sorter.Sort(*m_forest, sortedVertices); |
| |
|
| | |
| | for (std::vector<const Forest::Vertex *>::const_iterator |
| | p = sortedVertices.begin(); p != sortedVertices.end(); ++p) { |
| | const Forest::Vertex &vertex = **p; |
| |
|
| | |
| | if (vertex.incoming.empty()) { |
| | if (vertex.pvertex.span.GetStartPos() > 0 && |
| | vertex.pvertex.span.GetEndPos() < m_sentenceLength-1 && |
| | IsUnknownSourceWord(vertex.pvertex.symbol)) { |
| | m_oovs.insert(vertex.pvertex.symbol); |
| | } |
| | continue; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | callback.ClearContainer(); |
| | for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator |
| | q = m_mainRuleMatchers.begin(); q != m_mainRuleMatchers.end(); ++q) { |
| | (*q)->EnumerateHyperedges(vertex, callback); |
| | } |
| |
|
| | |
| | const BoundedPriorityContainer<SHyperedgeBundle> &bundles = |
| | callback.GetContainer(); |
| |
|
| | |
| | |
| | if (bundles.Size() == 0) { |
| | for (std::vector<Forest::Hyperedge *>::const_iterator p = |
| | vertex.incoming.begin(); p != vertex.incoming.end(); ++p) { |
| | glueRuleSynthesizer.SynthesizeRule(**p); |
| | } |
| | m_glueRuleMatcher->EnumerateHyperedges(vertex, callback); |
| | |
| | |
| | } |
| |
|
| | |
| | |
| | CubeQueue cubeQueue(bundles.Begin(), bundles.End()); |
| | std::size_t count = 0; |
| | std::vector<SHyperedge*> buffer; |
| | while (count < popLimit && !cubeQueue.IsEmpty()) { |
| | SHyperedge *hyperedge = cubeQueue.Pop(); |
| | |
| | |
| | hyperedge->head->pvertex = &(vertex.pvertex); |
| | |
| | buffer.push_back(hyperedge); |
| | ++count; |
| | } |
| |
|
| | |
| | SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; |
| | RecombineAndSort(buffer, stack); |
| |
|
| | |
| | if (stackLimit > 0 && stack.size() > stackLimit) { |
| | stack.resize(stackLimit); |
| | } |
| | } |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | void Manager<RuleMatcher>::InitializeRuleMatchers() |
| | { |
| | const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
| | for (std::size_t i = 0; i < ffs.size(); ++i) { |
| | RuleTableFF *ff = ffs[i]; |
| | |
| | |
| | |
| | |
| | const RuleTable *table = ff->GetTable(); |
| | assert(table); |
| | RuleTable *nonConstTable = const_cast<RuleTable*>(table); |
| | HyperTree *trie = dynamic_cast<HyperTree*>(nonConstTable); |
| | assert(trie); |
| | boost::shared_ptr<RuleMatcher> p(new RuleMatcher(*trie)); |
| | m_mainRuleMatchers.push_back(p); |
| | } |
| |
|
| | |
| | |
| | |
| | m_glueRuleTrie.reset(new HyperTree(ffs[0])); |
| | m_glueRuleMatcher = boost::shared_ptr<RuleMatcher>( |
| | new RuleMatcher(*m_glueRuleTrie)); |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | void Manager<RuleMatcher>::InitializeStacks() |
| | { |
| | |
| | assert(!m_forest->vertices.empty()); |
| |
|
| | for (std::vector<Forest::Vertex *>::const_iterator |
| | p = m_forest->vertices.begin(); p != m_forest->vertices.end(); ++p) { |
| | const Forest::Vertex &vertex = **p; |
| |
|
| | |
| | SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; |
| |
|
| | |
| | if (vertex.incoming.empty()) { |
| | boost::shared_ptr<SVertex> v(new SVertex()); |
| | v->best = 0; |
| | v->pvertex = &(vertex.pvertex); |
| | stack.push_back(v); |
| | } |
| | } |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | bool Manager<RuleMatcher>::IsUnknownSourceWord(const Word &w) const |
| | { |
| | const std::size_t factorId = w[0]->GetId(); |
| | const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
| | for (std::size_t i = 0; i < ffs.size(); ++i) { |
| | RuleTableFF *ff = ffs[i]; |
| | const boost::unordered_set<std::size_t> &sourceTerms = |
| | ff->GetSourceTerminalSet(); |
| | if (sourceTerms.find(factorId) != sourceTerms.end()) { |
| | return false; |
| | } |
| | } |
| | return true; |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const |
| | { |
| | PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); |
| | assert(p != m_stackMap.end()); |
| | const SVertexStack &stack = p->second; |
| | assert(!stack.empty()); |
| | return stack[0]->best; |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | void Manager<RuleMatcher>::ExtractKBest( |
| | std::size_t k, |
| | std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, |
| | bool onlyDistinct) const |
| | { |
| | kBestList.clear(); |
| | if (k == 0 || m_source.GetSize() == 0) { |
| | return; |
| | } |
| |
|
| | |
| | PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); |
| | assert(p != m_stackMap.end()); |
| | const SVertexStack &stack = p->second; |
| | assert(!stack.empty()); |
| |
|
| | KBestExtractor extractor; |
| |
|
| | if (!onlyDistinct) { |
| | |
| | extractor.Extract(stack, k, kBestList); |
| | return; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | const StaticData &staticData = StaticData::Instance(); |
| | const std::size_t nBestFactor = staticData.options()->nbest.factor; |
| | std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; |
| |
|
| | |
| | KBestExtractor::KBestVec bigList; |
| | bigList.reserve(numDerivations); |
| | extractor.Extract(stack, numDerivations, bigList); |
| |
|
| | |
| | std::set<Phrase> distinct; |
| | for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); |
| | kBestList.size() < k && p != bigList.end(); ++p) { |
| | boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; |
| | Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); |
| | if (distinct.insert(translation).second) { |
| | kBestList.push_back(derivation); |
| | } |
| | } |
| | } |
| |
|
| | |
| | |
| | template<typename RuleMatcher> |
| | void Manager<RuleMatcher>::RecombineAndSort( |
| | const std::vector<SHyperedge*> &buffer, SVertexStack &stack) |
| | { |
| | |
| | |
| | |
| | |
| | |
| | typedef boost::unordered_map<SVertex *, SVertex *, |
| | SVertexRecombinationHasher, |
| | SVertexRecombinationEqualityPred> Map; |
| | Map map; |
| | for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); |
| | p != buffer.end(); ++p) { |
| | SHyperedge *h = *p; |
| | SVertex *v = h->head; |
| | assert(v->best == h); |
| | assert(v->recombined.empty()); |
| | std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); |
| | if (result.second) { |
| | continue; |
| | } |
| | |
| | |
| | |
| | SVertex *storedVertex = result.first->second; |
| | if (h->label.futureScore > storedVertex->best->label.futureScore) { |
| | |
| | storedVertex->recombined.push_back(storedVertex->best); |
| | storedVertex->best = h; |
| | } else { |
| | storedVertex->recombined.push_back(h); |
| | } |
| | h->head->best = 0; |
| | delete h->head; |
| | h->head = storedVertex; |
| | } |
| |
|
| | |
| | stack.clear(); |
| | stack.reserve(map.size()); |
| | for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { |
| | stack.push_back(boost::shared_ptr<SVertex>(p->first)); |
| | } |
| |
|
| | |
| | std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); |
| | } |
| |
|
| | template<typename RuleMatcher> |
| | void Manager<RuleMatcher>::OutputDetailedTranslationReport( |
| | OutputCollector *collector) const |
| | { |
| | const SHyperedge *best = GetBestSHyperedge(); |
| | if (best == NULL || collector == NULL) { |
| | return; |
| | } |
| | long translationId = m_source.GetTranslationId(); |
| | std::ostringstream out; |
| | DerivationWriter::Write(*best, translationId, out); |
| | collector->Write(translationId, out.str()); |
| | } |
| |
|
| | } |
| | } |
| | } |
| |
|