"""
Based on
Transition-based Parsing with Stack-Transformers
Ramon Fernandez Astudillo, Miguel Ballesteros, Tahira Naseem,
Austin Blodget, and Radu Florian
https://aclanthology.org/2020.findings-emnlp.89.pdf
"""

from collections import namedtuple

import torch
import torch.nn as nn

from stanza.models.constituency.positional_encoding import SinusoidalEncoding
from stanza.models.constituency.tree_stack import TreeStack
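
# Each stack entry bundles together:
#   value:       the user-supplied item for this position (e.g. a parse tree node)
#   key_stack:   NxD tensor of attention keys for every entry up to and including this one
#   value_stack: NxD tensor of attention values, parallel to key_stack
#   output:      the attention output vector computed when this entry was pushed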
Node = namedtuple("Node", ['value', 'key_stack', 'value_stack', 'output'])

class TransformerTreeStack(nn.Module):
def __init__(self, input_size, output_size, input_dropout, length_limit=None, use_position=False, num_heads=1):
"""
Builds the internal matrices and start parameter
        Note: num_heads > 1 is handled by the reshapes in attention();
        unlike standard multi-head attention there is no output projection
"""
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.inv_sqrt_output_size = 1 / output_size ** 0.5
self.num_heads = num_heads
self.w_query = nn.Linear(input_size, output_size)
self.w_key = nn.Linear(input_size, output_size)
self.w_value = nn.Linear(input_size, output_size)
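        # learned embedding representing the empty stack; every key_stack and
        # value_stack starts from this sentinel, so attention is never empty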
self.register_parameter('start_embedding', torch.nn.Parameter(0.2 * torch.randn(input_size, requires_grad=True)))
if isinstance(input_dropout, nn.Module):
self.input_dropout = input_dropout
else:
self.input_dropout = nn.Dropout(input_dropout)
if length_limit is not None and length_limit < 1:
raise ValueError("length_limit < 1 makes no sense")
self.length_limit = length_limit
self.use_position = use_position
if use_position:
self.position_encoding = SinusoidalEncoding(model_dim=self.input_size, max_len=512)

    def attention(self, key, query, value, mask=None):
"""
        Calculate scaled dot-product attention for the given key, query, and value
Where B is the number of items stacked together, N is the length:
The key should be BxNxD
The query is BxD
The value is BxNxD
If mask is specified, it should be BxN of True/False values,
where True means that location is masked out
Reshapes and reorders are used to handle num_heads
        Return will be softmax(query @ key^T / sqrt(D)) @ value
        of size BxD
"""
B = key.shape[0]
N = key.shape[1]
D = key.shape[2]
H = self.num_heads
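        # D is assumed to be divisible by H for the per-head reshapes below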
# query is now BxDx1
query = query.unsqueeze(2)
# BxHxD/Hx1
query = query.reshape((B, H, -1, 1))
# BxNxHxD/H
key = key.reshape((B, N, H, -1))
# BxHxNxD/H
key = key.transpose(1, 2)
# BxNxHxD/H
value = value.reshape((B, N, H, -1))
# BxHxNxD/H
value = value.transpose(1, 2)
# BxHxNxD/H x BxHxD/Hx1
# result shape: BxHxN
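        # scores are scaled by 1/sqrt(output_size) (the full size, not the per-head size)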
attn = torch.matmul(key, query).squeeze(3) * self.inv_sqrt_output_size
if mask is not None:
# mask goes from BxN -> Bx1xN
mask = mask.unsqueeze(1)
mask = mask.expand(-1, H, -1)
attn.masked_fill_(mask, float('-inf'))
# attn shape will now be BxHx1xN
attn = torch.softmax(attn, dim=2).unsqueeze(2)
# BxHx1xN x BxHxNxD/H -> BxHxD/H
output = torch.matmul(attn, value).squeeze(2)
output = output.reshape(B, -1)
return output

    def initial_state(self, initial_value=None):
"""
Return an initial state based on a single layer of attention
Running attention might be overkill, but it is the simplest
way to put the Linears and start_embedding in the computation graph
"""
start = self.start_embedding
if self.use_position:
position = self.position_encoding([0]).squeeze(0)
start = start + position
# N=1
# shape: 1xD
key = self.w_key(start).unsqueeze(0)
# shape: D
query = self.w_query(start)
# shape: 1xD
value = self.w_value(start).unsqueeze(0)
# unsqueeze to make it look like we are part of a batch of size 1
output = self.attention(key.unsqueeze(0), query.unsqueeze(0), value.unsqueeze(0)).squeeze(0)
return TreeStack(value=Node(initial_value, key, value, output), parent=None, length=1)

    def push_states(self, stacks, values, inputs):
"""
Push new inputs to the stacks and rerun attention on them
Where B is the number of items stacked together, I is input_size
stacks: B TreeStacks such as produced by initial_state and/or push_states
        values: the new items to push onto the stacks, such as tree nodes
            (arbitrary Python objects; stored on the stack, not used in the computation)
inputs: BxI for the new input items
Runs attention starting from the existing keys & values
"""
device = self.w_key.weight.device
batch_len = len(stacks) # B
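        # depth of each stack's key matrix, including the start sentinel row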
positions = [x.value.key_stack.shape[0] for x in stacks]
max_len = max(positions) # N
if self.use_position:
position_encodings = self.position_encoding(positions)
inputs = inputs + position_encodings
inputs = self.input_dropout(inputs)
if len(inputs.shape) == 3:
if inputs.shape[0] == 1:
inputs = inputs.squeeze(0)
else:
raise ValueError("Expected the inputs to be of shape 1xBxI, got {}".format(inputs.shape))
new_keys = self.w_key(inputs)
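        # stacks of different depths are right-aligned: row -1 is the new entry,
        # earlier entries fill the rows before it, and any leftover rows at the
        # front are zero padding (masked out below)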
key_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
key_stack[:, -1, :] = new_keys
for stack_idx, stack in enumerate(stacks):
key_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.key_stack
new_values = self.w_value(inputs)
value_stack = torch.zeros(batch_len, max_len+1, self.output_size, device=device)
value_stack[:, -1, :] = new_values
for stack_idx, stack in enumerate(stacks):
value_stack[stack_idx, -positions[stack_idx]-1:-1, :] = stack.value.value_stack
query = self.w_query(inputs)
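        # True marks the zero padding at the front of stacks shallower than
        # the deepest stack in this batch; those positions are excluded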
mask = torch.zeros(batch_len, max_len+1, device=device, dtype=torch.bool)
for stack_idx, stack in enumerate(stacks):
if len(stack) < max_len:
masked = max_len - positions[stack_idx]
mask[stack_idx, :masked] = True
batched_output = self.attention(key_stack, query, value_stack, mask)
new_stacks = []
for stack_idx, (stack, node_value, new_key, new_value, output) in enumerate(zip(stacks, values, key_stack, value_stack, batched_output)):
            # max_len-positions[stack_idx] so that we ignore the padding at the start of shorter stacks
new_key_stack = new_key[max_len-positions[stack_idx]:, :]
new_value_stack = new_value[max_len-positions[stack_idx]:, :]
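            # if over the length limit, drop the oldest real entry (index 1)
            # while always keeping the start embedding at index 0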
if self.length_limit is not None and new_key_stack.shape[0] > self.length_limit + 1:
new_key_stack = torch.cat([new_key_stack[:1, :], new_key_stack[2:, :]], axis=0)
new_value_stack = torch.cat([new_value_stack[:1, :], new_value_stack[2:, :]], axis=0)
new_stacks.append(stack.push(value=Node(node_value, new_key_stack, new_value_stack, output)))
return new_stacks

    def output(self, stack):
"""
        Return the attention output vector as the output from a stack
Refactored so that alternate structures have an easy way of getting the output
"""
return stack.value.output
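
# A minimal usage sketch, not part of the original module: it builds a small
# TransformerTreeStack, creates a batch of empty stacks, and pushes two rounds
# of random inputs.  All sizes and values below are invented for illustration.
if __name__ == '__main__':
    torch.manual_seed(42)
    module = TransformerTreeStack(input_size=8, output_size=8, input_dropout=0.0,
                                  length_limit=4, num_heads=2)
    # three independent stacks, batched together for the attention computation
    stacks = [module.initial_state() for _ in range(3)]
    for step in range(2):
        inputs = torch.randn(3, 8)
        values = ["item-%d-%d" % (step, idx) for idx in range(3)]
        stacks = module.push_states(stacks, values, inputs)
    # each stack's output is a single vector of output_size
    print(module.output(stacks[0]).shape)  # torch.Size([8])
    print(stacks[0].value.value)           # item-1-0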