|
|
from ...torch_core import * |
|
|
from ...layers import * |
|
|
from .awd_lstm import RNNDropout, LinearDecoder, SequentialRNN |
|
|
|
|
|
__all__ = ['Activation', 'PositionalEncoding', 'GeLU', 'Swish', 'feed_forward', 'MultiHeadAttention', 'MultiHeadRelativeAttention', |
|
|
'DecoderLayer', 'Transformer', 'TransformerXL', 'tfmer_lm_config', 'tfmer_clas_config', 'tfmer_lm_split', 'tfmer_clas_split', |
|
|
'tfmerXL_lm_config', 'tfmerXL_clas_config', 'tfmerXL_lm_split', 'tfmerXL_clas_split'] |
|
|
|
|
|
Activation = Enum('Activation', 'ReLU Swish GeLU') |
|
|
|
|
|
class PositionalEncoding(Module): |
|
|
"Encode the position with a sinusoid." |
|
|
def __init__(self, d:int): self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d))) |
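    # `freq` stores the inverse frequencies 1/10000^(2i/d) of the sinusoidal encoding from "Attention Is All You Need"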
|
|
|
|
|
def forward(self, pos:Tensor): |
|
|
inp = torch.ger(pos, self.freq) |
|
|
enc = torch.cat([inp.sin(), inp.cos()], dim=-1) |
|
|
return enc |
|
|
|
|
|
class GeLU(Module): |
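    "Gaussian Error Linear Unit (tanh approximation)."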
|
|
def forward(self, x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
|
|
|
|
|
class Swish(Module): |
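    "Swish activation: `x * sigmoid(x)`."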
|
|
def forward(self, x): return x * torch.sigmoid(x) |
|
|
|
|
|
_activ_func = {Activation.ReLU:nn.ReLU(inplace=True), Activation.GeLU:GeLU(), Activation.Swish: Swish()} |
|
|
|
|
|
def feed_forward(d_model:int, d_ff:int, ff_p:float=0., act:Activation=Activation.ReLU, double_drop:bool=True): |
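    "Position-wise feed-forward block: two linear layers with dropout, wrapped in a residual connection (`MergeLayer`) and `LayerNorm`."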
|
|
layers = [nn.Linear(d_model, d_ff), _activ_func[act]] |
|
|
if double_drop: layers.append(nn.Dropout(ff_p)) |
|
|
return SequentialEx(*layers, nn.Linear(d_ff, d_model), nn.Dropout(ff_p), MergeLayer(), nn.LayerNorm(d_model)) |
|
|
|
|
|
class MultiHeadAttention(Module): |
|
|
"MutiHeadAttention." |
|
|
def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True, |
|
|
scale:bool=True): |
|
|
d_head = ifnone(d_head, d_model//n_heads) |
|
|
self.n_heads,self.d_head,self.scale = n_heads,d_head,scale |
|
|
self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias) |
|
|
self.out = nn.Linear(n_heads * d_head, d_model, bias=bias) |
|
|
self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p) |
|
|
self.ln = nn.LayerNorm(d_model) |
|
|
|
|
|
def forward(self, x:Tensor, mask:Tensor=None, **kwargs): |
|
|
return self.ln(x + self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs)))) |
|
|
|
|
|
def _apply_attention(self, x:Tensor, mask:Tensor=None): |
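        # Project `x` to queries/keys/values for all heads with one linear layer, then apply scaled dot-product attention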
|
|
bs,x_len = x.size(0),x.size(1) |
|
|
wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1) |
|
|
wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv)) |
|
|
wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3) |
|
|
attn_score = torch.matmul(wq, wk) |
|
|
if self.scale: attn_score.div_(self.d_head ** 0.5) |
|
|
if mask is not None: |
|
|
attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score) |
|
|
attn_prob = self.drop_att(F.softmax(attn_score, dim=-1)) |
|
|
attn_vec = torch.matmul(attn_prob, wv) |
|
|
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
|
|
|
|
|
def _attention_einsum(self, x, mask=None): |
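        # Einsum-based variant of `_apply_attention`, kept mainly for readability; the permute/matmul version above is typically a bit faster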
|
|
|
|
|
bs,x_len = x.size(0),x.size(1) |
|
|
wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1) |
|
|
wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv)) |
|
|
attn_score = torch.einsum('bind,bjnd->bijn', (wq, wk)) |
|
|
if self.scale: attn_score.mul_(1/(self.d_head ** 0.5)) |
|
|
if mask is not None: |
|
|
attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score) |
|
|
attn_prob = self.drop_att(F.softmax(attn_score, dim=2)) |
|
|
attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv)) |
|
|
return attn_vec.contiguous().view(bs, x_len, -1) |
|
|
|
def _line_shift(x:Tensor, mask:bool=False): |
|
|
"Shift the line i of `x` by p-i elements to the left, is `mask` puts 0s on the diagonal." |
|
|
bs,nh,n,p = x.size() |
|
|
x_pad = torch.cat([x.new_zeros(bs,nh,n,1), x], dim=3) |
|
|
x_shift = x_pad.view(bs,nh,p + 1,n)[:,:,1:].view_as(x) |
|
|
if mask: x_shift.mul_(torch.tril(x.new_ones(n,p), p-n)[None,None,]) |
|
|
return x_shift |
|
|
|
|
|
class MultiHeadRelativeAttention(MultiHeadAttention): |
|
|
"MutiHeadAttention with relative positional encoding." |
|
|
|
|
|
def __init__(self, n_heads:int, d_model:int, d_head:int, resid_p:float=0., attn_p:float=0., bias:bool=True, |
|
|
scale:bool=True): |
|
|
super().__init__(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale) |
|
|
self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias) |
|
|
|
|
|
def _apply_attention(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None): |
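        # Notation from the Transformer-XL paper: `x` is the input, `r` the relative positional encodings,
        # `u` and `v` the learnable biases shared across layers, `mask` the attention mask and `mem` the
        # hidden states cached from previous segments.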
|
|
|
|
|
|
|
|
bs,x_len,seq_len = x.size(0),x.size(1),r.size(0) |
|
|
context = x if mem is None else torch.cat([mem, x], dim=1) |
|
|
wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1) |
|
|
wq = wq[:,-x_len:] |
|
|
wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv)) |
|
|
wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3) |
|
|
wkr = self.r_attn(r) |
|
|
wkr = wkr.view(seq_len, self.n_heads, self.d_head) |
|
|
wkr = wkr.permute(1,2,0) |
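        # AC is the content-based attention term ((a)+(c) in the Transformer-XL paper), BD the position-based
        # term ((b)+(d)), computed with the relative-shift trick of `_line_shift`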
|
|
|
|
|
        AC = torch.matmul(wq + u, wk)
        BD = _line_shift(torch.matmul(wq + v, wkr))
        attn_score = AC + BD
        if self.scale: attn_score.mul_(1 / (self.d_head ** 0.5))
|
|
if mask is not None: |
|
|
attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score) |
|
|
attn_prob = self.drop_att(F.softmax(attn_score, dim=-1)) |
|
|
attn_vec = torch.matmul(attn_prob, wv) |
|
|
return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1) |
|
|
|
|
|
def _attention_einsum(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None): |
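        # Einsum-based variant of `_apply_attention` above; same computation in a different tensor layout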
|
|
|
|
|
bs,x_len,seq_len = x.size(0),x.size(1),r.size(0) |
|
|
context = x if mem is None else torch.cat([mem, x], dim=1) |
|
|
wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1) |
|
|
wq = wq[:,-x_len:] |
|
|
wkr = self.r_attn(r) |
|
|
wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv)) |
|
|
wkr = wkr.view(seq_len, self.n_heads, self.d_head) |
|
|
|
|
|
AC = torch.einsum('bind,bjnd->bijn', (wq+u, wk)) |
|
|
        # `_line_shift` expects (bs, n_heads, q_len, k_len), so move in and out of that layout around the shift
        BD = _line_shift(torch.einsum('bind,jnd->bijn', (wq+v, wkr)).permute(0,3,1,2)).permute(0,2,3,1)
|
|
attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5)) |
|
|
if mask is not None: |
|
|
attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score) |
|
|
attn_prob = self.drop_att(F.softmax(attn_score, dim=2)) |
|
|
attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv)) |
|
|
return attn_vec.contiguous().view(bs, x_len, -1) |
|
|
|
|
|
class DecoderLayer(Module): |
|
|
"Basic block of a Transformer model." |
|
|
|
|
|
def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0., |
|
|
bias:bool=True, scale:bool=True, act:Activation=Activation.ReLU, double_drop:bool=True, |
|
|
attn_cls:Callable=MultiHeadAttention): |
|
|
self.mhra = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale) |
|
|
self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, act=act, double_drop=double_drop) |
|
|
|
|
|
def forward(self, x:Tensor, mask:Tensor=None, **kwargs): return self.ff(self.mhra(x, mask=mask, **kwargs)) |
|
|
|
|
|
class Transformer(Module): |
|
|
"Transformer model: https://arxiv.org/abs/1706.03762." |
|
|
def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, |
|
|
resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=True, scale:bool=True, |
|
|
act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadAttention, |
|
|
learned_pos_enc:bool=True, mask:bool=True): |
|
|
self.mask = mask |
|
|
self.encoder = nn.Embedding(vocab_sz, d_model) |
|
|
self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model) |
|
|
self.drop_emb = nn.Dropout(embed_p) |
|
|
self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p, |
|
|
ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, |
|
|
attn_cls=attn_cls) for k in range(n_layers)]) |
|
|
|
|
|
def reset(self): pass |
|
|
|
|
|
def forward(self, x): |
|
|
bs, x_len = x.size() |
|
|
pos = torch.arange(0, x_len, device=x.device, dtype=x.dtype) |
|
|
inp = self.drop_emb(self.encoder(x) + self.pos_enc(pos)[None]) |
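        # Causal mask: 1s strictly above the diagonal, so position i cannot attend to later positions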
|
|
mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).byte()[None,None] if self.mask else None |
|
|
|
|
|
for layer in self.layers: inp = layer(inp, mask=mask) |
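        # Return ([raw outputs], [outputs]) lists to match the interface expected by `LinearDecoder` and the RNN callbacks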
|
|
return ([inp],[inp]) |
|
|
|
|
|
class TransformerXL(Module): |
|
|
"TransformerXL model: https://arxiv.org/abs/1901.02860." |
|
|
def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, |
|
|
resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=False, scale:bool=True, |
|
|
act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention, |
|
|
learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0): |
|
|
self.encoder = nn.Embedding(vocab_sz, d_model) |
|
|
self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model) |
|
|
self.drop_emb = nn.Dropout(embed_p) |
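        # `u` and `v` are the learnable position-independent biases of Transformer-XL, shared across layers and
        # added to the queries for the content- and position-based attention terms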
|
|
self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) |
|
|
self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) |
|
|
self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask |
|
|
self.init = False |
|
|
self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p, |
|
|
ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, |
|
|
attn_cls=attn_cls) for k in range(n_layers)]) |
|
|
|
|
|
def reset(self): |
|
|
"Reset the internal memory." |
|
|
self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)] |
|
|
|
|
|
def _update_mems(self, hids): |
|
|
if not getattr(self, 'hidden', False): return None |
|
|
assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)' |
|
|
with torch.no_grad(): |
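            # Extend each stored hidden state with the new ones and keep only the last `mem_len` positions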
|
|
for i in range(len(hids)): |
|
|
cat = torch.cat([self.hidden[i], hids[i]], dim=1) |
|
|
self.hidden[i] = cat[:,-self.mem_len:].detach() |
|
|
|
|
|
def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden] |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
if self.mem_len > 0 and not self.init: |
|
|
self.reset() |
|
|
self.init = True |
|
|
bs,x_len = x.size() |
|
|
inp = self.drop_emb(self.encoder(x)) |
|
|
m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0 |
|
|
seq_len = m_len + x_len |
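        # Causal mask over memory + current tokens: each query attends to all memory positions and to input positions up to itself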
|
|
mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).byte()[None,None] if self.mask else None |
|
|
|
|
|
hids = [] |
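        # Relative positions run from the most distant position (seq_len-1) down to 0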
|
|
pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype) |
|
|
pos_enc = self.pos_enc(pos) |
|
|
hids.append(inp) |
|
|
for i, layer in enumerate(self.layers): |
|
|
mem = self.hidden[i] if self.mem_len > 0 else None |
|
|
inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem) |
|
|
hids.append(inp) |
|
|
core_out = inp[:,-x_len:] |
|
|
if self.mem_len > 0 : self._update_mems(hids) |
|
|
return (self.hidden if self.mem_len > 0 else [core_out]),[core_out] |
|
|
|
|
|
def init_transformer(m): |
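    "Initialize `m`: normal(0., 0.02) for Linear weights and the `u`/`v` biases of `TransformerXL`, normal(1., 0.02) for LayerNorm weights, zeros for biases."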
|
|
classname = m.__class__.__name__ |
|
|
if classname.find('Linear') != -1: |
|
|
if hasattr(m, 'weight') and m.weight is not None: nn.init.normal_(m.weight, 0., 0.02) |
|
|
if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0.) |
|
|
elif classname.find('LayerNorm') != -1: |
|
|
if hasattr(m, 'weight') and m.weight is not None: nn.init.normal_(m.weight, 1., 0.02) |
|
|
if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0.) |
|
|
elif classname.find('TransformerXL') != -1: |
|
|
if hasattr(m, 'u'): nn.init.normal_(m.u, 0., 0.02) |
|
|
if hasattr(m, 'v'): nn.init.normal_(m.v, 0., 0.02) |
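# Default configurations: the `tfmer_*` dicts roughly match the hyperparameters of the original GPT,
# the `tfmerXL_*` dicts a slightly smaller version of the Transformer-XL base model.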
|
|
|
|
|
tfmer_lm_config = dict(ctx_len=512, n_layers=12, n_heads=12, d_model=768, d_head=64, d_inner=3072, resid_p=0.1, attn_p=0.1, |
|
|
ff_p=0.1, embed_p=0.1, output_p=0., bias=True, scale=True, act=Activation.GeLU, double_drop=False, |
|
|
tie_weights=True, out_bias=False, init=init_transformer, mask=True) |
|
|
|
|
|
tfmer_clas_config = dict(ctx_len=512, n_layers=12, n_heads=12, d_model=768, d_head=64, d_inner=3072, resid_p=0.1, attn_p=0.1, |
|
|
ff_p=0.1, embed_p=0.1, output_p=0., bias=True, scale=True, act=Activation.GeLU, double_drop=False, |
|
|
init=init_transformer, mask=False) |
|
|
|
|
|
def tfmer_lm_split(model:nn.Module) -> List[nn.Module]: |
|
|
"Split a RNN `model` in groups for differential learning rates." |
|
|
encoder = model[0] |
|
|
n = len(encoder.layers)//3 |
|
|
groups = [list(encoder.layers[:n]), list(encoder.layers[n:2*n]), list(encoder.layers[2*n:])] |
|
|
return groups + [[encoder.encoder, model[1]]] |
|
|
|
|
|
def tfmer_clas_split(model:nn.Module) -> List[nn.Module]: |
|
|
"Split a RNN `model` in groups for differential learning rates." |
|
|
encoder = model[0].module |
|
|
n = len(encoder.layers)//3 |
|
|
groups = [[encoder.encoder], list(encoder.layers[:n]), list(encoder.layers[n:2*n]), list(encoder.layers[2*n:])] |
|
|
return groups + [[model[1]]] |
|
|
|
|
|
tfmerXL_lm_config = dict(ctx_len=150, n_layers=12, n_heads=10, d_model=410, d_head=41, d_inner=2100, resid_p=0.1, attn_p=0.1, |
|
|
ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.ReLU, double_drop=True, |
|
|
tie_weights=True, out_bias=True, init=init_transformer, mem_len=150, mask=True) |
|
|
|
|
|
tfmerXL_clas_config = dict(ctx_len=150, n_layers=12, n_heads=10, d_model=410, d_head=41, d_inner=2100, resid_p=0.1, attn_p=0.1, |
|
|
ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.ReLU, double_drop=True, |
|
|
init=init_transformer, mem_len=150, mask=False) |
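# Example (sketch, assuming a fastai v1 `TextLMDataBunch` named `data_lm`): these configs are meant to be
# passed to the learner factories, e.g.
#   learn = language_model_learner(data_lm, TransformerXL, config=tfmerXL_lm_config.copy(), pretrained=False)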
|
|
|
|
|
def tfmerXL_lm_split(model:nn.Module) -> List[nn.Module]: |
|
|
"Split a RNN `model` in groups for differential learning rates." |
|
|
encoder = model[0] |
|
|
n = len(encoder.layers)//3 |
|
|
groups = [list(encoder.layers[:n]) + [ParameterModule(encoder.u), ParameterModule(encoder.v)]] |
|
|
return groups + [list(encoder.layers[n:2*n]), list(encoder.layers[2*n:]), [encoder.encoder, model[1]]] |
|
|
|
|
|
def tfmerXL_clas_split(model:nn.Module) -> List[nn.Module]: |
|
|
"Split a RNN `model` in groups for differential learning rates." |
|
|
encoder = model[0].module |
|
|
n = len(encoder.layers)//3 |
|
|
groups = [[encoder.encoder], list(encoder.layers[:n]) + [ParameterModule(encoder.u), ParameterModule(encoder.v)]] |
|
|
return groups + [list(encoder.layers[n:2*n]), list(encoder.layers[2*n:]), [model[1]]] |
|
|
|