// -*- c++ -*- #pragma once #include "moses/DecodeGraph.h" #include "moses/ForestInput.h" #include "moses/StaticData.h" #include "moses/Syntax/BoundedPriorityContainer.h" #include "moses/Syntax/CubeQueue.h" #include "moses/Syntax/PHyperedge.h" #include "moses/Syntax/RuleTable.h" #include "moses/Syntax/RuleTableFF.h" #include "moses/Syntax/SHyperedgeBundle.h" #include "moses/Syntax/SVertex.h" #include "moses/Syntax/SVertexRecombinationEqualityPred.h" #include "moses/Syntax/SVertexRecombinationHasher.h" #include "moses/Syntax/SymbolEqualityPred.h" #include "moses/Syntax/SymbolHasher.h" #include "moses/Syntax/T2S/InputTree.h" #include "moses/Syntax/T2S/InputTreeBuilder.h" #include "moses/Syntax/T2S/InputTreeToForest.h" #include "moses/TreeInput.h" #include "DerivationWriter.h" #include "GlueRuleSynthesizer.h" #include "HyperTree.h" #include "RuleMatcherCallback.h" #include "TopologicalSorter.h" namespace Moses { namespace Syntax { namespace F2S { template Manager::Manager(ttasksptr const& ttask) : Syntax::Manager(ttask) { if (const ForestInput *p = dynamic_cast(&m_source)) { m_forest = p->GetForest(); m_rootVertex = p->GetRootVertex(); m_sentenceLength = p->GetSize(); } else if (const TreeInput *p = dynamic_cast(&m_source)) { T2S::InputTreeBuilder builder(options()->output.factor_order); T2S::InputTree tmpTree; builder.Build(*p, "Q", tmpTree); boost::shared_ptr forest = boost::make_shared(); m_rootVertex = T2S::InputTreeToForest(tmpTree, *forest); m_forest = forest; m_sentenceLength = p->GetSize(); } else { UTIL_THROW2("ERROR: F2S::Manager requires input to be a tree or forest"); } } template void Manager::Decode() { // Get various pruning-related constants. const std::size_t popLimit = options()->cube.pop_limit; const std::size_t ruleLimit = options()->syntax.rule_limit; const std::size_t stackLimit = options()->search.stack_size; // Initialize the stacks. InitializeStacks(); // Initialize the rule matchers. InitializeRuleMatchers(); // Create a callback to process the PHyperedges produced by the rule matchers. RuleMatcherCallback callback(m_stackMap, ruleLimit); // Create a glue rule synthesizer. GlueRuleSynthesizer glueRuleSynthesizer(*options(), *m_glueRuleTrie); // Sort the input forest's vertices into bottom-up topological order. std::vector sortedVertices; TopologicalSorter sorter; sorter.Sort(*m_forest, sortedVertices); // Visit each vertex of the input forest in topological order. for (std::vector::const_iterator p = sortedVertices.begin(); p != sortedVertices.end(); ++p) { const Forest::Vertex &vertex = **p; // Skip terminal vertices (after checking if they are OOVs). if (vertex.incoming.empty()) { if (vertex.pvertex.span.GetStartPos() > 0 && vertex.pvertex.span.GetEndPos() < m_sentenceLength-1 && IsUnknownSourceWord(vertex.pvertex.symbol)) { m_oovs.insert(vertex.pvertex.symbol); } continue; } // Call the rule matchers to generate PHyperedges for this vertex and // convert each one to a SHyperedgeBundle (via the callback). The // callback prunes the SHyperedgeBundles and keeps the best ones (up // to ruleLimit). callback.ClearContainer(); for (typename std::vector >::iterator q = m_mainRuleMatchers.begin(); q != m_mainRuleMatchers.end(); ++q) { (*q)->EnumerateHyperedges(vertex, callback); } // Retrieve the (pruned) set of SHyperedgeBundles from the callback. const BoundedPriorityContainer &bundles = callback.GetContainer(); // Check if any rules were matched. If not then for each incoming // hyperedge, synthesize a glue rule that is guaranteed to match. if (bundles.Size() == 0) { for (std::vector::const_iterator p = vertex.incoming.begin(); p != vertex.incoming.end(); ++p) { glueRuleSynthesizer.SynthesizeRule(**p); } m_glueRuleMatcher->EnumerateHyperedges(vertex, callback); // FIXME This assertion occasionally fails -- why? // assert(bundles.Size() == vertex.incoming.size()); } // Use cube pruning to extract SHyperedges from SHyperedgeBundles and // collect the SHyperedges in a buffer. CubeQueue cubeQueue(bundles.Begin(), bundles.End()); std::size_t count = 0; std::vector buffer; while (count < popLimit && !cubeQueue.IsEmpty()) { SHyperedge *hyperedge = cubeQueue.Pop(); // FIXME See corresponding code in S2T::Manager // BEGIN{HACK} hyperedge->head->pvertex = &(vertex.pvertex); // END{HACK} buffer.push_back(hyperedge); ++count; } // Recombine SVertices and sort into a stack. SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; RecombineAndSort(buffer, stack); // Prune stack. if (stackLimit > 0 && stack.size() > stackLimit) { stack.resize(stackLimit); } } } template void Manager::InitializeRuleMatchers() { const std::vector &ffs = RuleTableFF::Instances(); for (std::size_t i = 0; i < ffs.size(); ++i) { RuleTableFF *ff = ffs[i]; // This may change in the future, but currently we assume that every // RuleTableFF is associated with a static, file-based rule table of // some sort and that the table should have been loaded into a RuleTable // by this point. const RuleTable *table = ff->GetTable(); assert(table); RuleTable *nonConstTable = const_cast(table); HyperTree *trie = dynamic_cast(nonConstTable); assert(trie); boost::shared_ptr p(new RuleMatcher(*trie)); m_mainRuleMatchers.push_back(p); } // Create an additional rule trie + matcher for glue rules (which are // synthesized on demand). // FIXME Add a hidden RuleTableFF for the glue rule trie(?) m_glueRuleTrie.reset(new HyperTree(ffs[0])); m_glueRuleMatcher = boost::shared_ptr( new RuleMatcher(*m_glueRuleTrie)); } template void Manager::InitializeStacks() { // Check that m_forest has been initialized. assert(!m_forest->vertices.empty()); for (std::vector::const_iterator p = m_forest->vertices.begin(); p != m_forest->vertices.end(); ++p) { const Forest::Vertex &vertex = **p; // Create an empty stack. SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; // For terminals only, add a single SVertex. if (vertex.incoming.empty()) { boost::shared_ptr v(new SVertex()); v->best = 0; v->pvertex = &(vertex.pvertex); stack.push_back(v); } } } template bool Manager::IsUnknownSourceWord(const Word &w) const { const std::size_t factorId = w[0]->GetId(); const std::vector &ffs = RuleTableFF::Instances(); for (std::size_t i = 0; i < ffs.size(); ++i) { RuleTableFF *ff = ffs[i]; const boost::unordered_set &sourceTerms = ff->GetSourceTerminalSet(); if (sourceTerms.find(factorId) != sourceTerms.end()) { return false; } } return true; } template const SHyperedge *Manager::GetBestSHyperedge() const { PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); assert(p != m_stackMap.end()); const SVertexStack &stack = p->second; assert(!stack.empty()); return stack[0]->best; } template void Manager::ExtractKBest( std::size_t k, std::vector > &kBestList, bool onlyDistinct) const { kBestList.clear(); if (k == 0 || m_source.GetSize() == 0) { return; } // Get the top-level SVertex stack. PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); assert(p != m_stackMap.end()); const SVertexStack &stack = p->second; assert(!stack.empty()); KBestExtractor extractor; if (!onlyDistinct) { // Return the k-best list as is, including duplicate translations. extractor.Extract(stack, k, kBestList); return; } // Determine how many derivations to extract. If the k-best list is // restricted to distinct translations then this limit should be bigger // than k. The k-best factor determines how much bigger the limit should be, // with 0 being 'unlimited.' This actually sets a large-ish limit in case // too many translations are identical. const StaticData &staticData = StaticData::Instance(); const std::size_t nBestFactor = staticData.options()->nbest.factor; std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; // Extract the derivations. KBestExtractor::KBestVec bigList; bigList.reserve(numDerivations); extractor.Extract(stack, numDerivations, bigList); // Copy derivations into kBestList, skipping ones with repeated translations. std::set distinct; for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); kBestList.size() < k && p != bigList.end(); ++p) { boost::shared_ptr derivation = *p; Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); if (distinct.insert(translation).second) { kBestList.push_back(derivation); } } } // TODO Move this function into parent directory (Recombiner class?) and // TODO share with S2T template void Manager::RecombineAndSort( const std::vector &buffer, SVertexStack &stack) { // Step 1: Create a map containing a single instance of each distinct vertex // (where distinctness is defined by the state value). The hyperedges' // head pointers are updated to point to the vertex instances in the map and // any 'duplicate' vertices are deleted. // TODO Set? typedef boost::unordered_map Map; Map map; for (std::vector::const_iterator p = buffer.begin(); p != buffer.end(); ++p) { SHyperedge *h = *p; SVertex *v = h->head; assert(v->best == h); assert(v->recombined.empty()); std::pair result = map.insert(Map::value_type(v, v)); if (result.second) { continue; // v's recombination value hasn't been seen before. } // v is a duplicate (according to the recombination rules). // Compare the score of h against the score of the best incoming hyperedge // for the stored vertex. SVertex *storedVertex = result.first->second; if (h->label.futureScore > storedVertex->best->label.futureScore) { // h's score is better. storedVertex->recombined.push_back(storedVertex->best); storedVertex->best = h; } else { storedVertex->recombined.push_back(h); } h->head->best = 0; delete h->head; h->head = storedVertex; } // Step 2: Copy the vertices from the map to the stack. stack.clear(); stack.reserve(map.size()); for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { stack.push_back(boost::shared_ptr(p->first)); } // Step 3: Sort the vertices in the stack. std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); } template void Manager::OutputDetailedTranslationReport( OutputCollector *collector) const { const SHyperedge *best = GetBestSHyperedge(); if (best == NULL || collector == NULL) { return; } long translationId = m_source.GetTranslationId(); std::ostringstream out; DerivationWriter::Write(*best, translationId, out); collector->Write(translationId, out.str()); } } // F2S } // Syntax } // Moses