|
|
""" |
|
|
A utilitiy class for keeping track of intermediate parse states |
|
|
""" |
|
|
|
|
|
from collections import namedtuple |
|
|
|
|
|
class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])): |
|
|
""" |
|
|
A stack which can branch in several directions, as long as you |
|
|
keep track of the branching heads |
|
|
|
|
|
An example usage is when K constituents are removed at once |
|
|
to create a new constituent, and then the LSTM which tracks the |
|
|
values of the constituents is updated starting from the Kth |
|
|
output of the LSTM with the new value. |
|
|
|
|
|
We don't simply keep track of a single stack object using a deque |
|
|
because versions of the parser which use a beam will want to be |
|
|
able to branch in different directions from the same base stack |
|
|
|
|
|
Another possible usage is if an oracle is used for training |
|
|
in a manner where some fraction of steps are non-gold steps, |
|
|
but we also want to take a gold step from the same state. |
|
|
Eg, parser gets to state X, wants to make incorrect transition T |
|
|
instead of gold transition G, and so we continue training both |
|
|
X+G and X+T. If we only represent the state X with standard |
|
|
python stacks, it would not be possible to track both of these |
|
|
states at the same time without copying the entire thing. |
|
|
|
|
|
Value can be as transition, a word, or a partially built constituent |
|
|
|
|
|
Implemented as a namedtuple to make it a bit more efficient |
|
|
""" |
|
|
def pop(self): |
|
|
return self.parent |
|
|
|
|
|
def push(self, value): |
|
|
|
|
|
return TreeStack(value, self, self.length+1) |
|
|
|
|
|
def __iter__(self): |
|
|
stack = self |
|
|
while stack.parent is not None: |
|
|
yield stack.value |
|
|
stack = stack.parent |
|
|
yield stack.value |
|
|
|
|
|
def __reversed__(self): |
|
|
items = list(iter(self)) |
|
|
for item in reversed(items): |
|
|
yield item |
|
|
|
|
|
def __str__(self): |
|
|
return "TreeStack(%s)" % ", ".join([str(x) for x in self]) |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|