#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
from .modeling_bert import BertEncoder, BertPooler
class GlobalMaskMaxPooling1D(nn.Module):
def __init__(self, ):
super(GlobalMaskMaxPooling1D, self).__init__()
    def forward(self, x, mask=None):
        # x: (B, Seq_len, Embed); mask: (B, Seq_len) with 1 for real tokens, 0 for padding
        if mask is not None:
            # push padded positions to a large negative value so max ignores them
            mask = 1.0 - mask
            mask = mask * (-2**10 + 1)
            # (B, Seq_len) -> (B, Seq_len, 1)
            mask = torch.unsqueeze(mask, dim=-1)
            x += mask
        return torch.max(x, dim=1)[0]
class GlobalMaskMinPooling1D(nn.Module):
def __init__(self, ):
super(GlobalMaskMinPooling1D, self).__init__()
    def forward(self, x, mask=None):
        # x: (B, Seq_len, Embed); mask: (B, Seq_len) with 1 for real tokens, 0 for padding
        if mask is not None:
            # push padded positions to a large positive value so min ignores them
            mask = 1.0 - mask
            mask = mask * (2**10 + 1)
            # (B, Seq_len) -> (B, Seq_len, 1)
            mask = torch.unsqueeze(mask, dim=-1)
            x += mask
        return torch.min(x, dim=1)[0]
class GlobalMaskAvgPooling1D(nn.Module):
def __init__(self):
super(GlobalMaskAvgPooling1D, self).__init__()
    def forward(self, x, mask=None):
        # x: (B, Seq_len, Embed); mask: (B, Seq_len) with 1 for real tokens, 0 for padding
        if mask is not None:
            # (B, Seq_len) -> (B, Seq_len, 1)
            mask = torch.unsqueeze(mask, dim=-1)
            # zero out padded positions, then average over the real tokens only
            x *= mask
            return torch.sum(x, dim=1) / torch.sum(mask, dim=1)
        else:
            return torch.mean(x, dim=1)
class GlobalMaskSumPooling1D(nn.Module):
def __init__(self, axis):
        '''
        sum pooling
        :param axis: axis=0 sums over the rows of the matrix; axis=1 sums over the columns
                     (for (B, Seq_len, Embed) inputs, axis=1 pools over the sequence dimension)
        '''
super(GlobalMaskSumPooling1D, self).__init__()
self.axis = axis
def forward(self, x, mask=None):
if mask is not None:
# (B, Seq_len) -> (B, Seq_len, 1)
mask = torch.unsqueeze(mask, dim=-1)
x *= mask
return torch.sum(x, dim=self.axis)
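
# Illustrative usage sketch (added for clarity; not part of the original module).
# It shows the expected call pattern for the masked poolings above: inputs are
# (B, Seq_len, Embed) embeddings and a (B, Seq_len) mask with 1 for real tokens
# and 0 for padding. Note that the forwards above modify `x` in place, hence the
# clones. All shapes and values here are assumptions for demonstration only.
def _demo_masked_pooling():
    x = torch.randn(2, 5, 8)
    mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 1., 1., 1.]])
    max_pooled = GlobalMaskMaxPooling1D()(x.clone(), mask)        # (2, 8)
    min_pooled = GlobalMaskMinPooling1D()(x.clone(), mask)        # (2, 8)
    avg_pooled = GlobalMaskAvgPooling1D()(x.clone(), mask)        # (2, 8)
    sum_pooled = GlobalMaskSumPooling1D(axis=1)(x.clone(), mask)  # (2, 8)
    return max_pooled, min_pooled, avg_pooled, sum_pooled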
class GlobalMaskWeightedAttentionPooling1D(nn.Module):
def __init__(self, embed_size, use_bias=False):
super(GlobalMaskWeightedAttentionPooling1D, self).__init__()
self.embed_size = embed_size
self.use_bias = use_bias
self.W = nn.Parameter(torch.Tensor(self.embed_size))
nn.init.trunc_normal_(self.W, std=0.01)
if self.use_bias:
self.b = nn.Parameter(torch.Tensor(1))
nn.init.trunc_normal_(self.b, std=0.01)
def forward(self, x, mask=None):
# (B, Len, Embed) x (Embed,) = (B, Len)
logits = torch.matmul(x, self.W)
if self.use_bias:
logits += self.b
if mask is not None:
attention_probs = nn.Softmax(dim=-1)(logits + (1.0 - mask) * -10000)
else:
attention_probs = nn.Softmax(dim=-1)(logits)
x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
return x
class GlobalMaskContextAttentionPooling1D(nn.Module):
def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
super(GlobalMaskContextAttentionPooling1D, self).__init__()
self.embed_size = embed_size
self.use_additive_bias = use_additive_bias
self.use_attention_bias = use_attention_bias
self.units = units if units else embed_size
self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
if self.use_additive_bias:
self.b1 = nn.Parameter(torch.Tensor(self.units))
nn.init.trunc_normal_(self.b1, std=0.01)
if self.use_attention_bias:
self.b2 = nn.Parameter(torch.Tensor(1))
nn.init.trunc_normal_(self.b2, std=0.01)
self.c = nn.Parameter(torch.Tensor(self.units))
nn.init.trunc_normal_(self.U, std=0.01)
nn.init.trunc_normal_(self.V, std=0.01)
nn.init.trunc_normal_(self.c, std=0.01)
def forward(self, x, mask=None):
# (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
q = torch.matmul(x, self.U)
k = torch.matmul(x, self.V)
if self.use_additive_bias:
h = torch.tanh(q + k + self.b1)
else:
h = torch.tanh(q + k)
if self.use_attention_bias:
e = torch.matmul(h, self.c) + self.b2
else:
e = torch.matmul(h, self.c)
if mask is not None:
attention_probs = nn.Softmax(dim=-1)(e + (1.0 - mask) * -10000)
else:
attention_probs = nn.Softmax(dim=-1)(e)
x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
return x
class GlobalMaskValueAttentionPooling1D(nn.Module):
def __init__(self, embed_size, units=None, use_additive_bias=False, use_attention_bias=False):
super(GlobalMaskValueAttentionPooling1D, self).__init__()
self.embed_size = embed_size
self.use_additive_bias = use_additive_bias
self.use_attention_bias = use_attention_bias
self.units = units if units else embed_size
self.U = nn.Parameter(torch.Tensor(self.embed_size, self.units))
self.V = nn.Parameter(torch.Tensor(self.embed_size, self.units))
if self.use_additive_bias:
self.b1 = nn.Parameter(torch.Tensor(self.units))
nn.init.trunc_normal_(self.b1, std=0.01)
if self.use_attention_bias:
self.b2 = nn.Parameter(torch.Tensor(self.embed_size))
nn.init.trunc_normal_(self.b2, std=0.01)
self.W = nn.Parameter(torch.Tensor(self.units, self.embed_size))
nn.init.trunc_normal_(self.U, std=0.01)
nn.init.trunc_normal_(self.V, std=0.01)
nn.init.trunc_normal_(self.W, std=0.01)
def forward(self, x, mask=None):
# (B, Len, Embed) x (Embed, Units) = (B, Len, Units)
q = torch.matmul(x, self.U)
k = torch.matmul(x, self.V)
if self.use_additive_bias:
h = torch.tanh(q + k + self.b1)
else:
h = torch.tanh(q + k)
# (B, Len, Units) x (Units, Embed) = (B, Len, Embed)
if self.use_attention_bias:
e = torch.matmul(h, self.W) + self.b2
else:
e = torch.matmul(h, self.W)
if mask is not None:
attention_probs = nn.Softmax(dim=1)(e + torch.unsqueeze((1.0 - mask) * -10000, dim=-1))
else:
attention_probs = nn.Softmax(dim=1)(e)
x = torch.sum(attention_probs * x, dim=1)
return x
def __repr__(self):
return self.__class__.__name__ + ' (' + str(self.embed_size) + ' -> ' + str(self.embed_size) + ')'
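
# Illustrative usage sketch (added for clarity; not part of the original module).
# The three attention poolings above each learn a scoring of the sequence
# positions, softmax the scores (with padded positions pushed to -10000), and
# return the weighted sum of the embeddings: (B, Seq_len, Embed) -> (B, Embed).
# Shapes and sizes here are assumptions for demonstration only.
def _demo_attention_pooling(embed_size=8):
    x = torch.randn(2, 5, embed_size)
    mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 1., 1., 1.]])
    weighted = GlobalMaskWeightedAttentionPooling1D(embed_size)(x, mask)  # (2, embed_size)
    context = GlobalMaskContextAttentionPooling1D(embed_size)(x, mask)    # (2, embed_size)
    value = GlobalMaskValueAttentionPooling1D(embed_size)(x, mask)        # (2, embed_size)
    return weighted, context, value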
class GlobalMaskTransformerPooling1D(nn.Module):
def __init__(self, config):
super(GlobalMaskTransformerPooling1D, self).__init__()
self.embeddings = nn.Parameter(torch.Tensor(1, 1, config.hidden_size))
nn.init.trunc_normal_(self.embeddings, std=0.02)
        # the pooling head runs its own small transformer: force a 2-layer encoder
        config.num_hidden_layers = 2
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
    def forward(self, x, mask=None):
        # x: (B, Seq_len, Embed); mask: (B, Seq_len) with 1 for real tokens, 0 for padding
        B, Seq_len, Embed = x.size()
        cls_emb_batch = self.embeddings.expand(B, 1, Embed)
        merged_output = torch.cat((cls_emb_batch, x), dim=1)  # (B, Seq_len + 1, Embed)
if mask is not None:
device = x.device
cls_mask = torch.ones(B, 1).to(device)
mask = torch.cat([cls_mask, mask], dim=1)
mask = mask[:, None, None, :]
sequence_output = self.encoder(merged_output,
attention_mask=mask,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False)[0]
pooled_output = self.pooler(sequence_output)
return pooled_output
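
# Illustrative sketch (assumption, not from the original file): constructing the
# transformer-based pooler above requires a BERT-style config object exposing at
# least `hidden_size` (plus whatever else the local BertEncoder/BertPooler read
# from it, e.g. `num_attention_heads`). This only demonstrates the call shape.
def _demo_transformer_pooling(config):
    pooler = GlobalMaskTransformerPooling1D(config)
    x = torch.randn(2, 5, config.hidden_size)  # (B, Seq_len, hidden_size)
    mask = torch.ones(2, 5)                    # 1 = real token, 0 = padding
    return pooler(x, mask)                     # (2, hidden_size)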
class GlobalMaxPool1d(nn.Module):
def __init__(self):
super(GlobalMaxPool1d,self).__init__()
self.fc = nn.AdaptiveMaxPool1d(1)
def forward(self, x):
x = x.permute(0, 2, 1)
x = self.fc(x)
x = torch.squeeze(x, dim=-1)
return x
class GlobalAvgPool1d(nn.Module):
def __init__(self, ):
super(GlobalAvgPool1d, self).__init__()
self.fc = nn.AdaptiveAvgPool1d(1)
def forward(self, x):
x = x.permute(0, 2, 1)
x = self.fc(x)
x = torch.squeeze(x, dim=-1)
return x
class AttentionPool1d(nn.Module):
def __init__(self, embed_size, device="cuda"):
super(AttentionPool1d, self).__init__()
self.embed_size = embed_size
self.W = nn.Parameter(torch.Tensor(self.embed_size, self.embed_size))
self.b = nn.Parameter(torch.Tensor(self.embed_size))
self.c = nn.Parameter(torch.Tensor(self.embed_size))
nn.init.trunc_normal_(self.W, std=0.02)
nn.init.trunc_normal_(self.b, std=0.02)
nn.init.trunc_normal_(self.c, std=0.02)
    def forward(self, x):
        # x: (B, Seq_len, Embed)
        # additive attention: score each position, softmax over the sequence,
        # then take the probability-weighted sum of the token embeddings
        # (B, Seq_len, Embed) x (Embed, Embed) = (B, Seq_len, Embed)
        mul = torch.tanh(torch.matmul(x, self.W) + self.b)
        # (B, Seq_len, Embed) x (Embed,) = (B, Seq_len)
        mul = torch.matmul(mul, self.c)
        attention_probs = nn.Softmax(dim=-1)(mul)
        # (B, Seq_len, 1) * (B, Seq_len, Embed) summed over the sequence -> (B, Embed)
        x = torch.sum(torch.unsqueeze(attention_probs, dim=-1) * x, dim=1)
        return x
class TransformerPool1d(nn.Module):
def __init__(self, config, embeddings, embed_size, num_transformer_layers=2, CLS_ID=102, device="cuda"):
super(TransformerPool1d, self).__init__()
        if embeddings is not None:
self.embeddings = embeddings
else:
self.embeddings = nn.Parameter(torch.Tensor(1, 1, embed_size))
nn.init.trunc_normal_(self.embeddings, std=0.02)
# self.embeddings = BertEmbeddings(config)
self.CLS_ID = CLS_ID
self.device = device
config.num_hidden_layers = num_transformer_layers
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
    def forward(self, x):
        # x: (B, Seq_len, Embed)
        B, Seq_len, Embed = x.size()
        # cls_emb_batch = self.embeddings(torch.tensor([[self.CLS_ID]] * x.size()[0], dtype=torch.long).to(self.device))  # B, 1
        cls_emb_batch = self.embeddings.expand(B, 1, Embed)
        merged_output = torch.cat((cls_emb_batch, x), dim=1)  # (B, Seq_len + 1, Embed)
sequence_output = self.encoder(merged_output,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False)[0]
pooled_output = self.pooler(sequence_output)
return pooled_output
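
# Minimal smoke test (illustrative sketch added for clarity; not part of the
# original file). It exercises the poolers that depend only on torch; the
# transformer-based poolers are skipped here because they need an external
# BERT-style config. Shapes are assumptions for demonstration only.
if __name__ == "__main__":
    x = torch.randn(2, 5, 8)                        # (B, Seq_len, Embed)
    print(GlobalMaxPool1d()(x).shape)               # torch.Size([2, 8])
    print(GlobalAvgPool1d()(x).shape)               # torch.Size([2, 8])
    print(AttentionPool1d(embed_size=8)(x).shape)   # torch.Size([2, 8])
    for out in _demo_masked_pooling() + _demo_attention_pooling():
        print(out.shape)                            # each torch.Size([2, 8])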