import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from adaptive_span import AdaptiveSpan


def _skew(X, pad_value):
    """shift every row 1 step to right"""
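    # The "skew" trick: pad each row, flatten and reshape so that row i ends up
    # shifted i columns to the right. This realigns per-offset attention
    # weights with absolute key positions without an explicit loop.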
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)
    X = X.view(B, -1)
    X = X[:, :-M]
    X = X.view(B, M, M + L)
    return X


def _unskew(X):
    """reverse _skew operation"""
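    # Inverse of _skew: row i is shifted i columns back to the left and the
    # result truncated to L columns, so each query row keeps only the scores
    # for its L preceding positions.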
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)
    X = F.pad(X, (0, M))
    X = X.view(B, M, M + L + 1)
    X = X[:, :, :L]
    return X


class SeqAttention(nn.Module):
    """Sequential self-attention layer.

    Each token will attend to its previous fixed number of steps.
    Note that attention doesn't include the current step itself.
    """

    def __init__(self, hidden_size, nb_heads, attn_span,
                 dropout, adapt_span_params, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span, nb_heads=nb_heads,
                                              **adapt_span_params, **kargs)

        self.persistent_memory = None

    def forward(self, query, key, value, key_pe, output_attentions=False):
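        # Expected shapes (each head folded into the batch dimension,
        # H = head size, L = attn_span):
        #   query:  B x M x H        -- the current block of M tokens
        #   key:    B x (L+M) x H    -- L cached steps followed by the block
        #   value:  B x (L+M) x H
        #   key_pe: 1 x H x L        -- learned relative-position keys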
        attn_cont = torch.matmul(query, key.transpose(-1, -2))
        attn_cont = _unskew(attn_cont)

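        # position-based scores: each query is compared against the learned
        # relative-position keys, one score per offset within the span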
        attn_pos = torch.matmul(query, key_pe)
        attn = attn_cont + attn_pos

        if self.persistent_memory is not None:
            attn, pers_mem_out = self.persistent_memory(query, attn)
        else:
            attn = attn / math.sqrt(self.hidden_size)
            attn = F.softmax(attn, dim=-1)

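        # if enabled, AdaptiveSpan soft-masks the weights that fall outside
        # each head's learned span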
        if self.adapt_span_enabled:
            attn = self.adaptive_span(attn)

        attn = self.dropout(attn)

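        # skew the weights back so their columns line up with the absolute key
        # positions in `value`, then aggregate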
        attn_cont = _skew(attn, 0)
        out = torch.matmul(attn_cont, value)

        if self.persistent_memory is not None:
            out = out + pers_mem_out

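        # optionally also return the (skewed) weights over the last M key
        # positions, i.e. the tokens of the current block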
        if output_attentions:
            M = attn_cont.size(1)
            return out, attn_cont[:, :, -M:]
        else:
            return out

    def get_cache_size(self):
        if self.adapt_span_enabled:
            return self.adaptive_span.get_cache_size()
        else:
            return self.attn_span


class MultiHeadSeqAttention(nn.Module):

    def __init__(self, hidden_size, nb_heads, **kargs):
        nn.Module.__init__(self)
        assert hidden_size % nb_heads == 0
        self.nb_heads = nb_heads
        self.head_dim = hidden_size // nb_heads
        self.attn = SeqAttention(
            hidden_size=self.head_dim, nb_heads=nb_heads, **kargs)
        self.proj_query = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj_key = nn.Linear(hidden_size, hidden_size, bias=False)

    def head_reshape(self, x):
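        # x: B x M x (K*D)  ->  (B*K) x M x D, folding the K heads into the
        # batch dimension so SeqAttention can process each head independently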
        K = self.nb_heads
        D = self.head_dim
        x = x.view(x.size()[:-1] + (K, D))
        x = x.transpose(1, 2).contiguous()
        x = x.view(-1, x.size(-2), x.size(-1))
        return x

    def forward(self, query, key, value, key_pe, output_attentions=False):
        B = query.size(0)
        K = self.nb_heads
        D = self.head_dim
        M = query.size(1)

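        # project inputs, then split each projection into K heads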
        query = self.proj_query(query)
        query = self.head_reshape(query)
        value = self.proj_val(value)
        value = self.head_reshape(value)
        key = self.proj_key(key)
        key = self.head_reshape(key)

        if output_attentions:
            out, attentions = self.attn(query, key, value, key_pe, output_attentions)
        else:
            out = self.attn(query, key, value, key_pe, output_attentions)

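        # merge the heads back: (B*K) x M x D  ->  B x M x (K*D)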
        out = out.view(B, K, M, D)
        out = out.transpose(1, 2).contiguous()
        out = out.view(B, M, -1)
        out = self.proj_out(out)

        if output_attentions:
            return out, attentions
        else:
            return out


class FeedForwardLayer(nn.Module):

    def __init__(self, hidden_size, inner_hidden_size, dropout, **kargs):
        nn.Module.__init__(self)
        self.fc1 = nn.Linear(hidden_size, inner_hidden_size)
        self.fc2 = nn.Linear(inner_hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h):
        h1 = F.relu(self.fc1(h))
        h1 = self.dropout(h1)
        h2 = self.fc2(h1)
        return h2


class FSA_layer(nn.Module):

    def __init__(self, hidden_size, **kargs):
        nn.Module.__init__(self)
        self.attn_span = kargs['attn_span']
        self.hidden_size = hidden_size
        self.attn = MultiHeadSeqAttention(hidden_size=hidden_size, **kargs)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ff = FeedForwardLayer(hidden_size=hidden_size, **kargs)
        self.norm2 = nn.LayerNorm(hidden_size)
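        # learned relative-position keys: one head_dim-sized vector per offset
        # within the span, shared across heads and query positions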
        self.key_pe = nn.Parameter(
            torch.randn(1, hidden_size // kargs['nb_heads'], kargs['attn_span']))

    def forward(self, h, output_attentions=False):
        B = h.shape[0]

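        # no recurrence across segments: the "cache" of previous hidden states
        # is all zeros, so the earliest positions in the block simply attend to
        # zero padding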
        self.h_cache = torch.zeros(
            B, self.attn_span, self.hidden_size, dtype=h.dtype, device=h.device)
        h_all = torch.cat([self.h_cache, h], dim=1)

        if output_attentions:
            attn_out, attentions = self.attn(h, h_all, h_all, self.key_pe, output_attentions)
        else:
            attn_out = self.attn(h, h_all, h_all, self.key_pe, output_attentions)

        h = self.norm1(h + attn_out)
        ff_out = self.ff(h)
        out = self.norm2(h + ff_out)

        if output_attentions:
            return out, attentions
        else:
            return out
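

# Minimal usage sketch (not part of the original module): constructs a single
# FSA_layer and runs a random batch through it. The hyperparameter values below
# are illustrative assumptions; adapt_span_params only carries the flag that
# SeqAttention requires, so AdaptiveSpan is never instantiated here (the
# `adaptive_span` module still has to be importable for the import above).
if __name__ == '__main__':
    layer = FSA_layer(
        hidden_size=64,
        nb_heads=4,
        attn_span=32,
        dropout=0.1,
        inner_hidden_size=128,
        adapt_span_params={'adapt_span_enabled': False},
    )
    x = torch.randn(2, 16, 64)  # B x M x hidden_size
    y = layer(x)                # B x M x hidden_size
    y2, attn = layer(x, output_attentions=True)
    print(y.shape, attn.shape)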