Albin Thörn Cleland
Clean initial commit with LFS
19b8775
"""
A utility class for keeping track of intermediate parse states
"""
from collections import namedtuple
class TreeStack(namedtuple('TreeStack', ['value', 'parent', 'length'])):
    """
    A stack which can branch in several directions, as long as you
    keep track of the branching heads

    An example usage is when K constituents are removed at once
    to create a new constituent, and then the LSTM which tracks the
    values of the constituents is updated starting from the Kth
    output of the LSTM with the new value.

    We don't simply keep track of a single stack object using a deque
    because versions of the parser which use a beam will want to be
    able to branch in different directions from the same base stack

    Another possible usage is if an oracle is used for training
    in a manner where some fraction of steps are non-gold steps,
    but we also want to take a gold step from the same state.
    Eg, parser gets to state X, wants to make incorrect transition T
    instead of gold transition G, and so we continue training both
    X+G and X+T.  If we only represent the state X with standard
    python stacks, it would not be possible to track both of these
    states at the same time without copying the entire thing.

    Value can be a transition, a word, or a partially built constituent

    Implemented as a namedtuple to make it a bit more efficient

    Fields:
      value:  the item stored at this node (the top of this stack)
      parent: the node underneath this one, or None at the bottom
      length: the stack size reported by len() for this node

    Note that __iter__ and __len__ are overridden with stack semantics,
    so iterating over a node walks the stack values rather than the
    three namedtuple fields.
    """

    # Without this, a namedtuple subclass gets a per-instance __dict__,
    # which wastes memory on every node and defeats the point of
    # building the stack out of namedtuples.
    __slots__ = ()

    def pop(self):
        """
        Return the node underneath this one (the parent field)

        Nothing is modified: this node remains valid, which is what
        lets several heads share the same base stack.
        """
        return self.parent

    def push(self, value):
        """
        Return a new node holding value on top of this stack

        The new node points back to this one and records a length one
        greater than this node's.  This node itself is left unchanged.
        """
        # returns a new stack node which points to this
        return TreeStack(value, self, self.length + 1)

    def __iter__(self):
        """
        Yield the stored values from this node down to the bottom node
        """
        node = self
        while node is not None:
            yield node.value
            node = node.parent

    def __reversed__(self):
        """
        Yield the stored values from the bottom node up to this node

        The parent links only go downward, so the values are collected
        into a list first and then replayed backwards.
        """
        items = list(self)
        for item in reversed(items):
            yield item

    def __str__(self):
        """
        Render as TreeStack(v1, v2, ...) with values listed top first
        """
        return "TreeStack(%s)" % ", ".join([str(x) for x in self])

    def __len__(self):
        """
        Return the recorded length field, not the number of tuple fields
        """
        return self.length