|
|
|
|
|
|
|
|
|
|
|
import copy

import torch
import torch.nn as nn

from .modeling_bert import BertEncoder, BertPooler
|
|
|
|
|
|
class GlobalMaskMaxPooling1D(nn.Module):
    """Max-pool over dim 1, pushing masked-out positions far down first.

    Positions where ``mask`` is 0 get ``-(2**10 - 1)`` added before the max,
    so they only win when every kept value is even lower.
    """

    def __init__(self, ):
        super(GlobalMaskMaxPooling1D, self).__init__()

    def forward(self, x, mask=None):
        """Return ``max(x, dim=1)`` with masked positions penalised.

        :param x: input tensor; assumed (batch, seq_len, features) — TODO confirm
        :param mask: optional tensor; assumed (batch, seq_len) with 1 = keep and
            0 = masked. It is unsqueezed on the last axis to broadcast over ``x``.
        :return: ``x`` reduced with ``max`` over dim 1 (values only)
        """
        if mask is not None:
            # Additive penalty: 0 for kept positions, -1023 for masked ones.
            penalty = (1.0 - mask) * (-2**10 + 1)
            penalty = torch.unsqueeze(penalty, dim=-1).to(x.dtype)
            # Out-of-place so the caller's tensor is left untouched and autograd
            # keeps working for leaf inputs that require grad.
            x = x + penalty
        return torch.max(x, dim=1)[0]
|
|
|
|
|
|
|
|
|
class GlobalMaskMinPooling1D(nn.Module):
    """Min-pool over dim 1, pushing masked-out positions far up first.

    Positions where ``mask`` is 0 get ``2**10 + 1`` added before the min,
    so they only win when every kept value is even higher.
    """

    def __init__(self, ):
        super(GlobalMaskMinPooling1D, self).__init__()

    def forward(self, x, mask=None):
        """Return ``min(x, dim=1)`` with masked positions penalised.

        :param x: input tensor; assumed (batch, seq_len, features) — TODO confirm
        :param mask: optional tensor; assumed (batch, seq_len) with 1 = keep and
            0 = masked. It is unsqueezed on the last axis to broadcast over ``x``.
        :return: ``x`` reduced with ``min`` over dim 1 (values only)
        """
        if mask is not None:
            # Additive penalty: 0 for kept positions, +1025 for masked ones.
            penalty = (1.0 - mask) * (2**10 + 1)
            penalty = torch.unsqueeze(penalty, dim=-1).to(x.dtype)
            # Out-of-place so the caller's tensor is left untouched and autograd
            # keeps working for leaf inputs that require grad.
            x = x + penalty
        return torch.min(x, dim=1)[0]
|
|
|
|
|
|
|
|
|
class GlobalMaskAvgPooling1D(nn.Module):
    """Mean-pool over dim 1, averaging only the positions the mask keeps."""

    def __init__(self):
        super(GlobalMaskAvgPooling1D, self).__init__()

    def forward(self, x, mask=None):
        """Return the (masked) mean of ``x`` over dim 1.

        :param x: input tensor; assumed (batch, seq_len, features) — TODO confirm
        :param mask: optional tensor; assumed (batch, seq_len) with 1 = keep and
            0 = masked. It is unsqueezed on the last axis to broadcast over ``x``.
        :return: ``sum(x * mask, dim=1) / sum(mask, dim=1)``, or the plain mean
            over dim 1 when ``mask`` is None. Rows whose mask is all zeros give 0
            rather than NaN.
        """
        if mask is None:
            return torch.mean(x, dim=1)
        weights = torch.unsqueeze(mask, dim=-1).to(x.dtype)
        # Out-of-place multiply: the caller's tensor must not be modified.
        total = torch.sum(x * weights, dim=1)
        count = torch.sum(weights, dim=1)
        # An all-zero mask row has total == 0 and count == 0; clamping the count
        # turns that 0/0 = NaN into 0 without affecting non-empty rows.
        return total / torch.clamp(count, min=1e-6)
|
|
|
|
|
|
|
|
|
class GlobalMaskSumPooling1D(nn.Module):
    """Sum-pool along a configurable axis, zeroing masked positions first."""

    def __init__(self, axis):
        '''
        sum pooling

        :param axis: dimension passed to ``torch.sum`` in ``forward``. For a
            2-D matrix, axis=0 adds up the rows and axis=1 adds up the columns.
            For an assumed (batch, seq_len, features) input, axis=1 sums over
            the sequence — TODO confirm typical usage.
        '''
        super(GlobalMaskSumPooling1D, self).__init__()
        self.axis = axis

    def forward(self, x, mask=None):
        """Return ``sum(x * mask, dim=self.axis)``; no masking when ``mask`` is None.

        :param x: input tensor
        :param mask: optional tensor. It is unsqueezed on the last axis and
            multiplied into ``x``; assumed (batch, seq_len) with 1 = keep,
            0 = drop — TODO confirm.
        """
        if mask is not None:
            # Out-of-place multiply: the caller's tensor must not be modified.
            x = x * torch.unsqueeze(mask, dim=-1).to(x.dtype)
        return torch.sum(x, dim=self.axis)
|
|
|
|
|
|
|
|
|
class GlobalMaskWeightedAttentionPooling1D(nn.Module):
    """Attention pooling whose scores come from a single learned vector.

    Each position is scored as ``x @ W`` (plus an optional learned scalar
    bias). The scores are softmax-normalised over the last score axis, and
    the result is the attention-weighted sum of ``x`` over dim 1.
    """

    def __init__(self, embed_size, use_bias=False):
        super(GlobalMaskWeightedAttentionPooling1D, self).__init__()
        self.embed_size = embed_size
        self.use_bias = use_bias

        self.W = nn.Parameter(torch.Tensor(self.embed_size))
        nn.init.trunc_normal_(self.W, std=0.01)
        if self.use_bias:
            self.b = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b, std=0.01)

    def forward(self, x, mask=None):
        """Pool ``x`` (assumed (batch, seq_len, embed_size) — TODO confirm).

        :param mask: optional; positions where it is 0 get -10000 added to
            their score before the softmax.
        """
        scores = torch.matmul(x, self.W)
        if self.use_bias:
            scores = scores + self.b
        if mask is not None:
            scores = scores + (1.0 - mask) * -10000
        weights = torch.softmax(scores, dim=-1)
        return (weights.unsqueeze(-1) * x).sum(dim=1)
|
|
|
|
|
|
|
|
|
class GlobalMaskContextAttentionPooling1D(nn.Module):
    """Additive (tanh) attention pooling with a learned context vector ``c``.

    ``hidden = tanh(x @ U + x @ V [+ b1])``; each position is scored as
    ``hidden @ c [+ b2]``. The scores are softmax-normalised over the last
    score axis, and the result is the attention-weighted sum of ``x`` over
    dim 1.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super(GlobalMaskContextAttentionPooling1D, self).__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units if units else embed_size

        # Registration order (U, V, b1, b2, c) and init order (b1, b2, U, V, c)
        # are kept as-is so state_dicts and seeded initialisation stay identical.
        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(1))
            nn.init.trunc_normal_(self.b2, std=0.01)
        self.c = nn.Parameter(torch.Tensor(self.units))

        for weight in (self.U, self.V, self.c):
            nn.init.trunc_normal_(weight, std=0.01)

    def forward(self, x, mask=None):
        """Pool ``x`` (assumed (batch, seq_len, embed_size) — TODO confirm).

        :param mask: optional; positions where it is 0 get -10000 added to
            their score before the softmax.
        """
        pre_activation = torch.matmul(x, self.U) + torch.matmul(x, self.V)
        if self.use_additive_bias:
            pre_activation = pre_activation + self.b1
        hidden = torch.tanh(pre_activation)

        scores = torch.matmul(hidden, self.c)
        if self.use_attention_bias:
            scores = scores + self.b2
        if mask is not None:
            scores = scores + (1.0 - mask) * -10000

        weights = torch.softmax(scores, dim=-1)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)
|
|
|
|
|
|
|
|
|
class GlobalMaskValueAttentionPooling1D(nn.Module):
    """Additive attention pooling with a separate weight per feature.

    ``hidden = tanh(x @ U + x @ V [+ b1])``. The scores are
    ``hidden @ W [+ b2]``, where ``W`` is (units, embed_size), so every
    feature gets its own score. The softmax runs over dim 1, and the output
    is the element-wise weighted sum of ``x`` over dim 1.
    """

    def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
        super(GlobalMaskValueAttentionPooling1D, self).__init__()
        self.embed_size = embed_size
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.units = units if units else embed_size

        # Registration order (U, V, b1, b2, W) and init order (b1, b2, U, V, W)
        # are kept as-is so state_dicts and seeded initialisation stay identical.
        self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
        if self.use_additive_bias:
            self.b1 = nn.Parameter(torch.Tensor(self.units))
            nn.init.trunc_normal_(self.b1, std=0.01)
        if self.use_attention_bias:
            self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
            nn.init.trunc_normal_(self.b2, std=0.01)
        self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))

        for weight in (self.U, self.V, self.W):
            nn.init.trunc_normal_(weight, std=0.01)

    def forward(self, x, mask=None):
        """Pool ``x`` (assumed (batch, seq_len, embed_size) — TODO confirm).

        :param mask: optional; it is unsqueezed on the last axis, and positions
            where it is 0 get -10000 added to all of their feature scores.
        """
        pre_activation = torch.matmul(x, self.U) + torch.matmul(x, self.V)
        if self.use_additive_bias:
            pre_activation = pre_activation + self.b1
        hidden = torch.tanh(pre_activation)

        scores = torch.matmul(hidden, self.W)
        if self.use_attention_bias:
            scores = scores + self.b2
        if mask is not None:
            penalty = (1.0 - mask) * -10000
            scores = scores + penalty.unsqueeze(-1)

        weights = torch.softmax(scores, dim=1)
        return torch.sum(weights * x, dim=1)

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.embed_size} -> {self.embed_size})"
|
|
|
|
|
|
|
|
|
class GlobalMaskTransformerPooling1D(nn.Module):
    """Pool a sequence with a learned [CLS]-style token and a 2-layer BertEncoder.

    A trainable (1, 1, hidden_size) vector is prepended to every sequence.
    The result goes through a ``BertEncoder`` built with 2 layers, and then
    through ``BertPooler``.
    """

    def __init__(self, config):
        super(GlobalMaskTransformerPooling1D, self).__init__()
        self.embeddings = nn.Parameter(torch.Tensor(1, 1, config.hidden_size))
        nn.init.trunc_normal_(self.embeddings, std=0.02)
        # Work on a private copy: setting num_hidden_layers on the caller's
        # config object would silently change every other model built from it.
        config = copy.deepcopy(config)
        config.num_hidden_layers = 2
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def forward(self, x, mask=None):
        """Return the pooled representation of ``x``.

        :param x: tensor of shape (batch, seq_len, hidden_size)
        :param mask: optional 0/1 keep-mask, assumed (batch, seq_len) — TODO confirm
        :return: output of ``BertPooler`` on the encoded [token] + x sequence
        """
        batch_size, _, embed_dim = x.size()
        cls_emb_batch = self.embeddings.expand(batch_size, 1, embed_dim)
        merged_output = torch.cat((cls_emb_batch, x), dim=1)

        attention_mask = None
        if mask is not None:
            mask = mask.to(dtype=merged_output.dtype, device=merged_output.device)
            # The prepended token is always attended to.
            cls_mask = torch.ones(batch_size, 1, dtype=mask.dtype, device=mask.device)
            mask = torch.cat([cls_mask, mask], dim=1)
            # BertEncoder adds attention_mask to the attention scores, so the
            # 0/1 keep-mask must become additive (0 = keep, -10000 = drop),
            # shaped (batch, 1, 1, seq_len + 1) to broadcast over heads and queries.
            attention_mask = (1.0 - mask[:, None, None, :]) * -10000.0

        sequence_output = self.encoder(merged_output,
                                       attention_mask=attention_mask,
                                       head_mask=None,
                                       encoder_hidden_states=None,
                                       encoder_attention_mask=None,
                                       output_attentions=False,
                                       output_hidden_states=False,
                                       return_dict=False)[0]
        return self.pooler(sequence_output)
|
|
|
|
|
|
|
|
|
class GlobalMaxPool1d(nn.Module):
    """Global max over dim 1 of a 3-D tensor, using ``nn.AdaptiveMaxPool1d(1)``.

    Input is assumed (batch, seq_len, features) — TODO confirm. The output
    has shape (batch, features).
    """

    def __init__(self):
        super(GlobalMaxPool1d, self).__init__()
        self.fc = nn.AdaptiveMaxPool1d(1)

    def forward(self, x):
        # AdaptiveMaxPool1d reduces the last axis, so swap dims 1 and 2 first,
        # then drop the resulting length-1 axis.
        return self.fc(x.permute(0, 2, 1)).squeeze(dim=-1)
|
|
|
|
|
|
|
|
|
class GlobalAvgPool1d(nn.Module):
    """Global mean over dim 1 of a 3-D tensor, using ``nn.AdaptiveAvgPool1d(1)``.

    Input is assumed (batch, seq_len, features) — TODO confirm. The output
    has shape (batch, features).
    """

    def __init__(self, ):
        super(GlobalAvgPool1d, self).__init__()
        self.fc = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # AdaptiveAvgPool1d reduces the last axis, so swap dims 1 and 2 first,
        # then drop the resulting length-1 axis.
        return self.fc(x.permute(0, 2, 1)).squeeze(dim=-1)
|
|
|
|
|
|
|
|
|
class AttentionPool1d(nn.Module):
    """Unmasked additive attention pooling.

    Each position is scored as ``tanh(x @ W + b) @ c``. The scores are
    softmax-normalised over the last score axis, and the result is the
    weighted sum of ``x`` over dim 1.

    :param embed_size: feature size of ``x``; ``W`` is (embed_size, embed_size)
    :param device: accepted but not used by this class
    """

    def __init__(self, embed_size, device="cuda"):
        super(AttentionPool1d, self).__init__()
        self.embed_size = embed_size
        self.W = nn.Parameter(torch.Tensor(self.embed_size, self.embed_size))
        self.b = nn.Parameter(torch.Tensor(self.embed_size))
        self.c = nn.Parameter(torch.Tensor(self.embed_size))
        for weight in (self.W, self.b, self.c):
            nn.init.trunc_normal_(weight, std=0.02)

    def forward(self, x):
        """Pool ``x`` (assumed (batch, seq_len, embed_size) — TODO confirm)."""
        hidden = torch.tanh(torch.matmul(x, self.W) + self.b)
        scores = torch.matmul(hidden, self.c)
        weights = torch.softmax(scores, dim=-1)
        return torch.sum(weights.unsqueeze(-1) * x, dim=1)
|
|
|
|
|
|
|
|
|
class TransformerPool1d(nn.Module):
    """Pool a sequence with a prepended token, a small BertEncoder and BertPooler (no mask).

    :param config: BERT-style config. A copy is made, and the copy's
        ``num_hidden_layers`` is set to ``num_transformer_layers``.
    :param embeddings: optional tensor to prepend as the pooling token.
        It is assumed to expand to (batch, 1, embed) — TODO confirm. When it is
        None or another non-tensor falsy value, a trainable (1, 1, embed_size)
        parameter is created instead.
    :param embed_size: feature size of the created token; unused when
        ``embeddings`` is provided
    :param num_transformer_layers: number of encoder layers
    :param CLS_ID: stored as ``self.CLS_ID``; not used in this class
    :param device: stored as ``self.device``; not used in this class
    """

    def __init__(self, config, embeddings, embed_size, num_transformer_layers=2, CLS_ID=102, device="cuda"):
        super(TransformerPool1d, self).__init__()
        # Never truth-test a tensor: `if embeddings:` raises "Boolean value of
        # Tensor with more than one value is ambiguous" for multi-element tensors.
        if isinstance(embeddings, torch.Tensor) or embeddings:
            self.embeddings = embeddings
        else:
            self.embeddings = nn.Parameter(torch.Tensor(1, 1, embed_size))
            nn.init.trunc_normal_(self.embeddings, std=0.02)

        self.CLS_ID = CLS_ID
        self.device = device
        # Work on a private copy: setting num_hidden_layers on the caller's
        # config object would silently change every other model built from it.
        config = copy.deepcopy(config)
        config.num_hidden_layers = num_transformer_layers
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def forward(self, x):
        """Return the pooled representation of ``x`` (batch, seq_len, embed); no attention mask."""
        batch_size, _, embed_dim = x.size()
        cls_emb_batch = self.embeddings.expand(batch_size, 1, embed_dim)
        merged_output = torch.cat((cls_emb_batch, x), dim=1)
        sequence_output = self.encoder(merged_output,
                                       attention_mask=None,
                                       head_mask=None,
                                       encoder_hidden_states=None,
                                       encoder_attention_mask=None,
                                       output_attentions=False,
                                       output_hidden_states=False,
                                       return_dict=False)[0]
        return self.pooler(sequence_output)
|
|
|
|
|
|
|
|
|
|