import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from adaptive_span import AdaptiveSpan
def _skew(X, pad_value):
"""shift every row 1 step to right"""
# X = B x M x L
B, M, L = X.size()
X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
X = X.view(B, -1) # B x ML+MM+M
X = X[:, :-M] # B x ML+MM
X = X.view(B, M, M + L) # B x M x L+M
return X
def _unskew(X):
"""reverse _skew operation"""
# X = B x M x L+M
B, M, L = X.size()
L -= M
X = X.view(B, -1) # B x ML+MM
X = F.pad(X, (0, M)) # B x ML+MM+M
X = X.view(B, M, M + L + 1) # B x M x L+M+1
X = X[:, :, :L] # B x M x L
return X
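

# Illustrative note: _skew and _unskew are exact inverses for any pad value, e.g.
#     X = torch.randn(2, 4, 8)                    # B x M x L
#     assert torch.equal(_unskew(_skew(X, 0)), X)
# Skewing turns the per-query sliding windows into one dense B x M x (L+M) matrix,
# so the attention below can use plain batched matmuls instead of per-position gathers.
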
class SeqAttention(nn.Module):
"""Sequential self-attention layer.
    Each token attends to a fixed number of previous steps.
    Note that the attention does not include the current step itself.
"""
    def __init__(self, hidden_size, nb_heads, attn_span,
                 dropout, adapt_span_params, **kargs):
        nn.Module.__init__(self)
self.dropout = nn.Dropout(dropout)
self.hidden_size = hidden_size # size of a single head
self.attn_span = attn_span
self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
if self.adapt_span_enabled:
self.adaptive_span = AdaptiveSpan(attn_span=attn_span, nb_heads=nb_heads,
**adapt_span_params, **kargs)
self.persistent_memory = None
    def forward(self, query, key, value, key_pe, output_attentions=False):
# query size = B x M x H
# key, value sizes = B x (M+L) x H
# compute attention from context
# B x M (dest) x (M+L) (src)
attn_cont = torch.matmul(query, key.transpose(-1, -2))
attn_cont = _unskew(attn_cont) # B x M x L
# compute the effect of position embedding
attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
attn = attn_cont + attn_pos
if self.persistent_memory is not None:
attn, pers_mem_out = self.persistent_memory(query, attn)
else:
attn = attn / math.sqrt(self.hidden_size) # B x M X L_pos
attn = F.softmax(attn, dim=-1)
if self.adapt_span_enabled:
# trim attention lengths according to the learned span
attn = self.adaptive_span(attn)
attn = self.dropout(attn) # B x M X L_pos
attn_cont = _skew(attn, 0) # B x M X (L+M)
out = torch.matmul(attn_cont, value) # B x M x H
if self.persistent_memory is not None:
out = out + pers_mem_out
        if output_attentions:
            # expose the trailing M columns of the skewed attention map as attention weights
            L = attn_cont.size(1)  # equals M
            return out, attn_cont[:, :, -L:]
else:
return out
def get_cache_size(self):
if self.adapt_span_enabled:
return self.adaptive_span.get_cache_size()
else:
return self.attn_span
class MultiHeadSeqAttention(nn.Module):
def __init__(self, hidden_size, nb_heads, **kargs):
nn.Module.__init__(self)
assert hidden_size % nb_heads == 0
self.nb_heads = nb_heads
self.head_dim = hidden_size // nb_heads
self.attn = SeqAttention(
hidden_size=self.head_dim, nb_heads=nb_heads, **kargs)
self.proj_query = nn.Linear(hidden_size, hidden_size, bias=False)
self.proj_out = nn.Linear(hidden_size, hidden_size, bias=False)
self.proj_val = nn.Linear(hidden_size, hidden_size, bias=False)
self.proj_key = nn.Linear(hidden_size, hidden_size, bias=False)
def head_reshape(self, x):
K = self.nb_heads
D = self.head_dim
x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
return x
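    # Expected input shapes for forward() (as wired up by FSA_layer below):
    #   query:  B x M x H
    #   key:    B x (M+L) x H   (L cached steps prepended to the current block)
    #   value:  B x (M+L) x H
    #   key_pe: 1 x (H // nb_heads) x L, with L = attn_span
    # Output: B x M x H; with output_attentions=True, also a per-head attention
    # slice of shape (B * nb_heads) x M x M.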
    def forward(self, query, key, value, key_pe, output_attentions=False):
B = query.size(0)
K = self.nb_heads
D = self.head_dim
M = query.size(1)
query = self.proj_query(query)
query = self.head_reshape(query)
value = self.proj_val(value)
value = self.head_reshape(value)
key = self.proj_key(key)
key = self.head_reshape(key)
        if output_attentions:
            out, attentions = self.attn(query, key, value, key_pe, output_attentions)  # B_K x M x D
        else:
            out = self.attn(query, key, value, key_pe, output_attentions)  # B_K x M x D
out = out.view(B, K, M, D) # B x K x M x D
out = out.transpose(1, 2).contiguous() # B x M x K x D
out = out.view(B, M, -1) # B x M x K_D
out = self.proj_out(out)
if output_attentions:
return out, attentions
else:
return out
class FeedForwardLayer(nn.Module):
def __init__(self, hidden_size, inner_hidden_size, dropout, **kargs):
nn.Module.__init__(self)
self.fc1 = nn.Linear(hidden_size, inner_hidden_size)
self.fc2 = nn.Linear(inner_hidden_size, hidden_size)
self.dropout = nn.Dropout(dropout)
def forward(self, h):
h1 = F.relu(self.fc1(h))
h1 = self.dropout(h1)
h2 = self.fc2(h1)
return h2
class FSA_layer(nn.Module):
    def __init__(self, hidden_size, **kargs):
        nn.Module.__init__(self)
        self.attn_span = kargs['attn_span']
        self.hidden_size = hidden_size
self.attn = MultiHeadSeqAttention(hidden_size=hidden_size, **kargs)
self.norm1 = nn.LayerNorm(hidden_size)
self.ff = FeedForwardLayer(hidden_size=hidden_size, **kargs)
self.norm2 = nn.LayerNorm(hidden_size)
self.key_pe = nn.Parameter(
torch.randn(1, hidden_size // kargs['nb_heads'], kargs['attn_span']))
    def forward(self, h, output_attentions=False):
# h = B x M x H
# h_cache = B x L x H
        B = h.shape[0]
        # no state is carried between calls: prepend a zero "cache" of length attn_span
        self.h_cache = torch.zeros(B, self.attn_span, self.hidden_size).to(h.device)
        h_all = torch.cat([self.h_cache, h], dim=1)  # B x (M+L) x H
        if output_attentions:
            attn_out, attentions = self.attn(h, h_all, h_all, self.key_pe, output_attentions)
        else:
            attn_out = self.attn(h, h_all, h_all, self.key_pe, output_attentions)
h = self.norm1(h + attn_out) # B x M x H
ff_out = self.ff(h)
out = self.norm2(h + ff_out) # B x M x H
if output_attentions:
return out,attentions
else:
return out
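

if __name__ == "__main__":
    # Minimal smoke test with illustrative hyperparameters (not taken from any
    # particular training config). The adaptive span is disabled, so the external
    # AdaptiveSpan module is not exercised; the `from adaptive_span import ...`
    # line at the top must still resolve for this file to be importable.
    layer = FSA_layer(
        hidden_size=64,
        nb_heads=4,
        attn_span=32,
        dropout=0.1,
        inner_hidden_size=128,
        adapt_span_params={'adapt_span_enabled': False},
    )
    x = torch.randn(2, 16, 64)  # B x M x H
    out, attn = layer(x, output_attentions=True)
    print(out.shape)   # expected: torch.Size([2, 16, 64])
    print(attn.shape)  # expected: (B * nb_heads) x M x M -> torch.Size([8, 16, 16])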