import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from adaptive_span import AdaptiveSpan


def _skew(X, pad_value):
    """Shift row i of each M x L matrix i steps to the right (fill with pad_value)."""
    # X = B x M x L
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)  # B x ML+MM+M
    X = X[:, :-M]  # B x ML+MM
    X = X.view(B, M, M + L)  # B x M x (L+M)
    return X


def _unskew(X):
    """Reverse the _skew operation."""
    # X = B x M x (L+M)
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)  # B x ML+MM
    X = F.pad(X, (0, M))  # B x ML+MM+M
    X = X.view(B, M, M + L + 1)  # B x M x (L+M+1)
    X = X[:, :, :L]  # B x M x L
    return X
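

# Illustration: _skew shifts row i of each M x L matrix i steps to the right,
# and _unskew undoes that shift, so _unskew(_skew(X, pad)) == X for any pad
# value (the padded entries land in positions that _unskew discards).
# For a single 2 x 3 example:
#     X           = [[a, b, c],
#                    [d, e, f]]
#     _skew(X, 0) = [[a, b, c, 0, 0],
#                    [0, d, e, f, 0]]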


class SeqAttention(nn.Module):
    """Sequential self-attention layer.

    Each token attends to a fixed number of previous steps. Note that the
    attention span does not include the current step itself.
    """
    def __init__(self, hidden_size, nb_heads, attn_span,
                 dropout, adapt_span_params, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size  # size of a single head
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span, nb_heads=nb_heads,
                                              **adapt_span_params, **kargs)
        self.persistent_memory = None

    def forward(self, query, key, value, key_pe, output_attentions=False):
        # query size = B x M x H
        # key, value sizes = B x (M+L) x H

        # compute attention from context
        # B x M (dest) x (M+L) (src)
        attn_cont = torch.matmul(query, key.transpose(-1, -2))
        attn_cont = _unskew(attn_cont)  # B x M x L

        # compute the effect of position embedding
        attn_pos = torch.matmul(query, key_pe)  # B x M x L_pos
        attn = attn_cont + attn_pos

        if self.persistent_memory is not None:
            attn, pers_mem_out = self.persistent_memory(query, attn)
        else:
            attn = attn / math.sqrt(self.hidden_size)  # B x M x L_pos
            attn = F.softmax(attn, dim=-1)

        if self.adapt_span_enabled:
            # trim attention lengths according to the learned span
            attn = self.adaptive_span(attn)
        attn = self.dropout(attn)  # B x M x L_pos

        attn_cont = _skew(attn, 0)  # B x M x (L+M)
        out = torch.matmul(attn_cont, value)  # B x M x H
        if self.persistent_memory is not None:
            out = out + pers_mem_out
        if output_attentions:
            # last M columns of the skewed weights are the attention over the
            # current segment's tokens
            L = attn_cont.size(1)
            return out, attn_cont[:, :, -L:]
        else:
            return out

    def get_cache_size(self):
        if self.adapt_span_enabled:
            return self.adaptive_span.get_cache_size()
        else:
            return self.attn_span


class MultiHeadSeqAttention(nn.Module):
    def __init__(self, hidden_size, nb_heads, **kargs):
        nn.Module.__init__(self)
        assert hidden_size % nb_heads == 0
        self.nb_heads = nb_heads
        self.head_dim = hidden_size // nb_heads
        self.attn = SeqAttention(
            hidden_size=self.head_dim, nb_heads=nb_heads, **kargs)
        self.proj_query = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_key = nn.Linear(hidden_size, hidden_size, bias=False)

    def head_reshape(self, x):
        K = self.nb_heads
        D = self.head_dim
        x = x.view(x.size()[:-1] + (K, D))  # B x (M+L) x K x D
        x = x.transpose(1, 2).contiguous()  # B x K x (M+L) x D
        x = x.view(-1, x.size(-2), x.size(-1))  # B_K x (M+L) x D
        return x

    def forward(self, query, key, value, key_pe, output_attentions=False):
        B = query.size(0)
        K = self.nb_heads
        D = self.head_dim
        M = query.size(1)

        query = self.proj_query(query)
        query = self.head_reshape(query)
        value = self.proj_val(value)
        value = self.head_reshape(value)
        key = self.proj_key(key)
        key = self.head_reshape(key)

        if output_attentions:
            out, attentions = self.attn(query, key, value, key_pe, output_attentions)  # B_K x M x D
        else:
            out = self.attn(query, key, value, key_pe, output_attentions)  # B_K x M x D
        out = out.view(B, K, M, D)  # B x K x M x D
        out = out.transpose(1, 2).contiguous()  # B x M x K x D
        out = out.view(B, M, -1)  # B x M x K_D
        out = self.proj_out(out)
        if output_attentions:
            return out, attentions
        else:
            return out


class FeedForwardLayer(nn.Module):
    def __init__(self, hidden_size, inner_hidden_size, dropout, **kargs):
        nn.Module.__init__(self)
        self.fc1 = nn.Linear(hidden_size, inner_hidden_size)
        self.fc2 = nn.Linear(inner_hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h):
        h1 = F.relu(self.fc1(h))
        h1 = self.dropout(h1)
        h2 = self.fc2(h1)
        return h2


class FSA_layer(nn.Module):
    def __init__(self, hidden_size, **kargs):
        nn.Module.__init__(self)
        self.attn_span = kargs['attn_span']
        self.hidden_size = hidden_size
        self.attn = MultiHeadSeqAttention(hidden_size=hidden_size, **kargs)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ff = FeedForwardLayer(hidden_size=hidden_size, **kargs)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.key_pe = nn.Parameter(
            torch.randn(1, hidden_size // kargs['nb_heads'], kargs['attn_span']))

    def forward(self, h, output_attentions=False):
        # h = B x M x H
        # h_cache = B x L x H
        B = h.shape[0]
        # the cache is re-initialized to zeros on every forward pass, so each
        # token attends to an empty history of length attn_span plus the
        # current segment
        self.h_cache = torch.zeros(B, self.attn_span, self.hidden_size).to(h.device)
        h_all = torch.cat([self.h_cache, h], dim=1)  # B x (M+L) x H
        if output_attentions:
            attn_out, attentions = self.attn(h, h_all, h_all, self.key_pe, output_attentions)
        else:
            attn_out = self.attn(h, h_all, h_all, self.key_pe, output_attentions)
        h = self.norm1(h + attn_out)  # B x M x H
        ff_out = self.ff(h)
        out = self.norm2(h + ff_out)  # B x M x H
        if output_attentions:
            return out, attentions
        else:
            return out
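

# Minimal usage sketch: builds one FSA_layer with assumed, illustrative
# hyper-parameters and runs a random batch through it.  Adaptive span is
# disabled here so that adapt_span_params does not need any of the
# AdaptiveSpan-specific keys defined in the adaptive_span module.
if __name__ == "__main__":
    layer = FSA_layer(
        hidden_size=64,
        nb_heads=4,
        attn_span=32,
        dropout=0.1,
        inner_hidden_size=128,
        adapt_span_params={'adapt_span_enabled': False},
    )
    x = torch.randn(2, 16, 64)  # B x M x H
    out = layer(x)  # B x M x H
    out_attn, attn = layer(x, output_attentions=True)
    print(out.shape, out_attn.shape, attn.shape)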