""" Keeps an LSTM in TreeStack form. The TreeStack nodes keep the hx and cx for the LSTM, along with a "value" which represents whatever the user needs to store. The TreeStacks can be ppped to get back to the previous LSTM state. The module itself implements three methods: initial_state, push_states, output """ from collections import namedtuple import torch import torch.nn as nn from stanza.models.constituency.tree_stack import TreeStack Node = namedtuple("Node", ['value', 'lstm_hx', 'lstm_cx']) class LSTMTreeStack(nn.Module): def __init__(self, input_size, hidden_size, num_lstm_layers, dropout, uses_boundary_vector, input_dropout): """ Prepare LSTM and parameters input_size: dimension of the inputs to the LSTM hidden_size: LSTM internal & output dimension num_lstm_layers: how many layers of LSTM to use dropout: value of the LSTM dropout uses_boundary_vector: if set, learn a start_embedding parameter. otherwise, use zeros input_dropout: an nn.Module to dropout inputs. TODO: allow a float parameter as well """ super().__init__() self.uses_boundary_vector = uses_boundary_vector # The start embedding needs to be input_size as we put it through the LSTM if uses_boundary_vector: self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True))) else: self.register_buffer('input_zeros', torch.zeros(num_lstm_layers, 1, input_size)) self.register_buffer('hidden_zeros', torch.zeros(num_lstm_layers, 1, hidden_size)) self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_lstm_layers, dropout=dropout) self.input_dropout = input_dropout def initial_state(self, initial_value=None): """ Return an initial state, either based on zeros or based on the initial embedding and LSTM Note that LSTM start operation is already batched, in a sense The subsequent batch built this way will be used for batch_size trees Returns a stack with None value, hx & cx either based on the start_embedding or zeros, and no parent. """ if self.uses_boundary_vector: start = self.start_embedding.unsqueeze(0).unsqueeze(0) output, (hx, cx) = self.lstm(start) start = output[0, 0, :] else: start = self.input_zeros hx = self.hidden_zeros cx = self.hidden_zeros return TreeStack(value=Node(initial_value, hx, cx), parent=None, length=1) def push_states(self, stacks, values, inputs): """ Starting from a list of current stacks, put the inputs through the LSTM and build new stack nodes. B = stacks.len() = values.len() inputs must be of shape 1 x B x input_size """ inputs = self.input_dropout(inputs) hx = torch.cat([t.value.lstm_hx for t in stacks], axis=1) cx = torch.cat([t.value.lstm_cx for t in stacks], axis=1) output, (hx, cx) = self.lstm(inputs, (hx, cx)) new_stacks = [stack.push(Node(transition, hx[:, i:i+1, :], cx[:, i:i+1, :])) for i, (stack, transition) in enumerate(zip(stacks, values))] return new_stacks def output(self, stack): """ Return the last layer of the lstm_hx as the output from a stack Refactored so that alternate structures have an easy way of getting the output """ return stack.value.lstm_hx[-1, 0, :]