File size: 5,131 Bytes
fd49381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
#pragma once
#include "moses/Syntax/S2T/PChart.h"
namespace Moses
{
namespace Syntax
{
namespace S2T
{
// Builds a parser over the given chart and rule trie.
//
//   chart         forwarded to the Parser<Callback> base.
//   trie          rule table whose root node is the starting point of every
//                 enumeration.
//   maxChartSpan  stored as-is; EnumerateHyperedges() uses it to limit how far
//                 to the right a rule may be extended.
//
// No callback is bound until EnumerateHyperedges() is called.
template<typename Callback>
RecursiveCYKPlusParser<Callback>::RecursiveCYKPlusParser(
  PChart &chart, const RuleTrie &trie, std::size_t maxChartSpan)
  : Parser<Callback>(chart),
    m_ruleTable(trie),
    m_maxChartSpan(maxChartSpan),
    m_callback(NULL)
{
  // Zero the head of the scratch hyperedge that AddAndExtend() hands to the
  // callback; only its tail and label are filled in by this parser.
  m_hyperedge.head = 0;
}
// Entry point: enumerates rule applications whose first tail vertex starts at
// range's start position, reporting each match through callback.
//
// Two families of hyperedge are searched, both starting from the trie root:
//   * those whose first tail vertex is a terminal covering [start,end];
//   * those whose first tail vertex is a non-terminal covering [start,end-1]
//     (only attempted when end > start).
// Further symbols are matched recursively via AddAndExtend().
template<typename Callback>
void RecursiveCYKPlusParser<Callback>::EnumerateHyperedges(
  const Range &range,
  Callback &callback)
{
  const std::size_t spanStart = range.GetStartPos();
  const std::size_t spanEnd = range.GetEndPos();

  // Bind the callback and fetch the trie root for this enumeration.
  m_callback = &callback;
  const RuleTrie::Node &root = m_ruleTable.GetRootNode();

  // Right-most position any extension may reach: the smaller of the last chart
  // position and the last position permitted by the maximum chart span.
  // NOTE(review): this is unsigned arithmetic; a zero m_maxChartSpan or a
  // zero-width chart would make the "-1" misbehave (wrap around, or yield a
  // limit below spanStart) -- confirm callers rule those cases out.
  const std::size_t lastChartPos = Base::m_chart.GetWidth() - 1;
  const std::size_t spanLimit = spanStart + m_maxChartSpan - 1;
  m_maxEnd = (spanLimit < lastChartPos) ? spanLimit : lastChartPos;

  // Start from an empty tail; AddAndExtend() pushes and pops as it recurses.
  m_hyperedge.tail.clear();

  // First tail vertex is a terminal over [spanStart,spanEnd].
  GetTerminalExtension(root, spanStart, spanEnd);

  // First tail vertex is a non-terminal over [spanStart,spanEnd-1].
  if (spanEnd > spanStart) {
    const std::size_t nonTermEnd = spanEnd - 1;
    GetNonTerminalExtensions(root, spanStart, nonTermEnd, nonTermEnd);
  }
}
// Extends the partial rule ending at trie node `node` by one non-terminal.
//
// For every non-terminal label on an outgoing edge of `node`, the chart's
// compressed matrix for position `start` is consulted for items carrying that
// label.  Each item whose end position lies in [minEnd,maxEnd] is appended to
// the rule (and extended further) via AddAndExtend().
template<typename Callback>
void RecursiveCYKPlusParser<Callback>::GetNonTerminalExtensions(
  const RuleTrie::Node &node,
  std::size_t start,
  std::size_t minEnd,
  std::size_t maxEnd)
{
  typedef RuleTrie::Node::SymbolMap SymbolMap;
  typedef std::vector<PChart::CompressedItem> ItemVec;

  // Outgoing non-terminal edges of the current trie node.
  const SymbolMap &labelEdges = node.GetNonTerminalMap();

  // Chart items starting at `start`, indexed by the label's factor id.
  const PChart::CompressedMatrix &matrix =
    Base::m_chart.GetCompressedMatrix(start);

  const SymbolMap::const_iterator edgesEnd = labelEdges.end();
  for (SymbolMap::const_iterator edge = labelEdges.begin();
       edge != edgesEnd; ++edge) {
    const ItemVec &candidates = matrix[edge->first[0]->GetId()];

    for (ItemVec::const_iterator item = candidates.begin();
         item != candidates.end(); ++item) {
      // Skip items that end outside the permitted window.
      if (item->end < minEnd || item->end > maxEnd) {
        continue;
      }
      AddAndExtend(edge->second, item->end, *(item->vertex));
    }
  }
}
// Extends the partial rule ending at trie node `node` by one terminal.
//
// Every terminal vertex stored in chart cell (start,end) is looked up among
// the outgoing terminal edges of `node`; on a match the vertex is appended to
// the rule (and extended further) via AddAndExtend().  Returns immediately
// when the cell holds no terminal vertices.
template<typename Callback>
void RecursiveCYKPlusParser<Callback>::GetTerminalExtension(
  const RuleTrie::Node &node,
  std::size_t start,
  std::size_t end)
{
  typedef PChart::Cell::TMap TMap;
  typedef RuleTrie::Node::SymbolMap SymbolMap;

  const TMap &cellTerminals =
    Base::m_chart.GetCell(start, end).terminalVertices;
  if (cellTerminals.empty()) {
    return;
  }

  const SymbolMap &trieTerminals = node.GetTerminalMap();

  // With only a handful of outgoing terminal edges, comparing words one by
  // one is used in place of the node's keyed lookup.
  const bool scanLinearly = trieTerminals.size() < 5;

  for (TMap::const_iterator cur = cellTerminals.begin();
       cur != cellTerminals.end(); ++cur) {
    const Word &terminal = cur->first;
    const PVertex &vertex = cur->second;

    // Resolve the child trie node reached by this terminal, if there is one.
    const RuleTrie::Node *child = NULL;
    if (scanLinearly) {
      for (SymbolMap::const_iterator edge = trieTerminals.begin();
           edge != trieTerminals.end(); ++edge) {
        if (edge->first == terminal) {
          child = &edge->second;
          break;  // the first matching edge is the one that is followed
        }
      }
    } else {
      child = node.GetChild(terminal);
    }

    if (child != NULL) {
      AddAndExtend(*child, end, vertex);
    }
  }
}
// Appends `vertex` to the tail of the hyperedge under construction, reports
// the resulting (partial) rule if it is complete, and then searches for longer
// rules that have it as a prefix.
//
//   node    trie node reached after consuming `vertex`.
//   end     end position covered by `vertex`.
//   vertex  chart vertex matched by the last symbol of the partial rule.
//
// The rule is passed to the callback only when the node's target phrase
// collection is non-empty and the rule is not a unary rule over a single
// non-terminal.  The tail is restored to its previous state before returning.
template<typename Callback>
void RecursiveCYKPlusParser<Callback>::AddAndExtend(
  const RuleTrie::Node &node,
  std::size_t end,
  const PVertex &vertex)
{
  // FIXME Sort out const-ness.
  m_hyperedge.tail.push_back(const_cast<PVertex *>(&vertex));

  // Report the rule matched so far, unless it is empty or non-lexical unary.
  TargetPhraseCollection::shared_ptr tpc = node.GetTargetPhraseCollection();
  const bool suppressed = tpc->IsEmpty() || IsNonLexicalUnary(m_hyperedge);
  if (!suppressed) {
    m_hyperedge.label.translations = tpc;
    (*m_callback)(m_hyperedge, end);
  }

  // Keep extending to the right until the limit computed by
  // EnumerateHyperedges() (end of sentence or max-chart-span) is reached.
  if (end < m_maxEnd) {
    const std::size_t nextStart = end + 1;

    if (!node.GetTerminalMap().empty()) {
      // Try a terminal over every span [nextStart,termEnd] up to the limit.
      std::size_t termEnd = nextStart;
      while (termEnd <= m_maxEnd) {
        GetTerminalExtension(node, nextStart, termEnd);
        ++termEnd;
      }
    }

    if (!node.GetNonTerminalMap().empty()) {
      GetNonTerminalExtensions(node, nextStart, nextStart, m_maxEnd);
    }
  }

  // Undo the push above so the caller's tail is unchanged.
  m_hyperedge.tail.pop_back();
}
// Returns true when the hyperedge's tail consists of exactly one vertex and
// that vertex's symbol is a non-terminal; false otherwise.
template<typename Callback>
bool RecursiveCYKPlusParser<Callback>::IsNonLexicalUnary(
  const PHyperedge &hyperedge) const
{
  // Anything other than a single tail vertex cannot be unary.
  if (hyperedge.tail.size() != 1) {
    return false;
  }
  return hyperedge.tail[0]->symbol.IsNonTerminal();
}
} // namespace S2T
} // namespace Syntax
} // namespace Moses
|