|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from collections import Sequence |
|
|
from packaging import version |
|
|
from .ops import * |
|
|
from .disentangled_attention import * |
|
|
from .da_utils import * |
|
|
|
|
|
__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead'] |
|
|
|
|
|
class BertSelfOutput(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) |
|
|
self.dropout = StableDropout(config.hidden_dropout_prob) |
|
|
self.config = config |
|
|
|
|
|
def forward(self, hidden_states, input_states, mask=None): |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.dropout(hidden_states) |
|
|
hidden_states += input_states |
|
|
hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) |
|
|
return hidden_states |
|
|
|
|
|
class BertAttention(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.self = DisentangledSelfAttention(config) |
|
|
self.output = BertSelfOutput(config) |
|
|
self.config = config |
|
|
|
|
|
def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None): |
|
|
output = self.self(hidden_states, attention_mask, return_att, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings) |
|
|
self_output, att_matrix, att_logits_=output['hidden_states'], output['attention_probs'], output['attention_logits'] |
|
|
if query_states is None: |
|
|
query_states = hidden_states |
|
|
attention_output = self.output(self_output, query_states, attention_mask) |
|
|
|
|
|
if return_att: |
|
|
return (attention_output, att_matrix) |
|
|
else: |
|
|
return attention_output |
|
|
|
|
|
class BertIntermediate(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) |
|
|
self.intermediate_act_fn = ACT2FN[config.hidden_act] \ |
|
|
if isinstance(config.hidden_act, str) else config.hidden_act |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.intermediate_act_fn(hidden_states) |
|
|
return hidden_states |
|
|
|
|
|
class BertOutput(nn.Module): |
|
|
def __init__(self, config): |
|
|
super(BertOutput, self).__init__() |
|
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) |
|
|
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) |
|
|
self.dropout = StableDropout(config.hidden_dropout_prob) |
|
|
self.config = config |
|
|
|
|
|
def forward(self, hidden_states, input_states, mask=None): |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.dropout(hidden_states) |
|
|
hidden_states += input_states |
|
|
hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) |
|
|
return hidden_states |
|
|
|
|
|
class BertLayer(nn.Module): |
|
|
def __init__(self, config): |
|
|
super(BertLayer, self).__init__() |
|
|
self.attention = BertAttention(config) |
|
|
self.intermediate = BertIntermediate(config) |
|
|
self.output = BertOutput(config) |
|
|
|
|
|
def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None): |
|
|
attention_output = self.attention(hidden_states, attention_mask, return_att=return_att, \ |
|
|
query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings) |
|
|
if return_att: |
|
|
attention_output, att_matrix = attention_output |
|
|
intermediate_output = self.intermediate(attention_output) |
|
|
layer_output = self.output(intermediate_output, attention_output, attention_mask) |
|
|
if return_att: |
|
|
return (layer_output, att_matrix) |
|
|
else: |
|
|
return layer_output |
|
|
|
|
|
class ConvLayer(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
kernel_size = getattr(config, 'conv_kernel_size', 3) |
|
|
groups = getattr(config, 'conv_groups', 1) |
|
|
self.conv_act = getattr(config, 'conv_act', 'tanh') |
|
|
self.conv = torch.nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size, padding = (kernel_size-1)//2, groups = groups) |
|
|
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) |
|
|
self.dropout = StableDropout(config.hidden_dropout_prob) |
|
|
self.config = config |
|
|
|
|
|
def forward(self, hidden_states, residual_states, input_mask): |
|
|
out = self.conv(hidden_states.permute(0,2,1).contiguous()).permute(0,2,1).contiguous() |
|
|
if version.Version(torch.__version__) >= version.Version('1.2.0a'): |
|
|
rmask = (1-input_mask).bool() |
|
|
else: |
|
|
rmask = (1-input_mask).byte() |
|
|
out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) |
|
|
out = ACT2FN[self.conv_act](self.dropout(out)) |
|
|
output_states = MaskedLayerNorm(self.LayerNorm, residual_states + out, input_mask) |
|
|
|
|
|
return output_states |
|
|
|
|
|
class BertEncoder(nn.Module): |
|
|
""" Modified BertEncoder with relative position bias support |
|
|
""" |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
|
|
|
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) |
|
|
self.relative_attention = getattr(config, 'relative_attention', False) |
|
|
if self.relative_attention: |
|
|
self.max_relative_positions = getattr(config, 'max_relative_positions', -1) |
|
|
if self.max_relative_positions <1: |
|
|
self.max_relative_positions = config.max_position_embeddings |
|
|
self.position_buckets = getattr(config, 'position_buckets', -1) |
|
|
pos_ebd_size = self.max_relative_positions*2 |
|
|
if self.position_buckets>0: |
|
|
pos_ebd_size = self.position_buckets*2 |
|
|
self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) |
|
|
|
|
|
self.norm_rel_ebd = [x.strip() for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|')] |
|
|
if 'layer_norm' in self.norm_rel_ebd: |
|
|
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine = True) |
|
|
kernel_size = getattr(config, 'conv_kernel_size', 0) |
|
|
self.with_conv = False |
|
|
if kernel_size > 0: |
|
|
self.with_conv = True |
|
|
self.conv = ConvLayer(config) |
|
|
|
|
|
def get_rel_embedding(self): |
|
|
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None |
|
|
if rel_embeddings is not None and ('layer_norm' in self.norm_rel_ebd): |
|
|
rel_embeddings = self.LayerNorm(rel_embeddings) |
|
|
return rel_embeddings |
|
|
|
|
|
def get_attention_mask(self, attention_mask): |
|
|
if attention_mask.dim()<=2: |
|
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
|
|
attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1) |
|
|
attention_mask = attention_mask.byte() |
|
|
elif attention_mask.dim()==3: |
|
|
attention_mask = attention_mask.unsqueeze(1) |
|
|
|
|
|
return attention_mask |
|
|
|
|
|
def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): |
|
|
if self.relative_attention and relative_pos is None: |
|
|
q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) |
|
|
relative_pos = build_relative_position(q, hidden_states.size(-2), bucket_size = self.position_buckets, max_position=self.max_relative_positions) |
|
|
return relative_pos |
|
|
|
|
|
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None): |
|
|
if attention_mask.dim()<=2: |
|
|
input_mask = attention_mask |
|
|
else: |
|
|
input_mask = (attention_mask.sum(-2)>0).byte() |
|
|
attention_mask = self.get_attention_mask(attention_mask) |
|
|
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) |
|
|
|
|
|
all_encoder_layers = [] |
|
|
att_matrices = [] |
|
|
if isinstance(hidden_states, Sequence): |
|
|
next_kv = hidden_states[0] |
|
|
else: |
|
|
next_kv = hidden_states |
|
|
rel_embeddings = self.get_rel_embedding() |
|
|
for i, layer_module in enumerate(self.layer): |
|
|
output_states = layer_module(next_kv, attention_mask, return_att, query_states = query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings) |
|
|
if return_att: |
|
|
output_states, att_m = output_states |
|
|
|
|
|
if i == 0 and self.with_conv: |
|
|
prenorm = output_states |
|
|
output_states = self.conv(hidden_states, prenorm, input_mask) |
|
|
|
|
|
if query_states is not None: |
|
|
query_states = output_states |
|
|
if isinstance(hidden_states, Sequence): |
|
|
next_kv = hidden_states[i+1] if i+1 < len(self.layer) else None |
|
|
else: |
|
|
next_kv = output_states |
|
|
|
|
|
if output_all_encoded_layers: |
|
|
all_encoder_layers.append(output_states) |
|
|
if return_att: |
|
|
att_matrices.append(att_m) |
|
|
if not output_all_encoded_layers: |
|
|
all_encoder_layers.append(output_states) |
|
|
if return_att: |
|
|
att_matrices.append(att_m) |
|
|
return { |
|
|
'hidden_states': all_encoder_layers, |
|
|
'attention_matrices': att_matrices |
|
|
} |
|
|
|
|
|
class BertEmbeddings(nn.Module): |
|
|
"""Construct the embeddings from word, position and token_type embeddings. |
|
|
""" |
|
|
def __init__(self, config): |
|
|
super(BertEmbeddings, self).__init__() |
|
|
padding_idx = getattr(config, 'padding_idx', 0) |
|
|
self.embedding_size = getattr(config, 'embedding_size', config.hidden_size) |
|
|
self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx = padding_idx) |
|
|
self.position_biased_input = getattr(config, 'position_biased_input', True) |
|
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) |
|
|
|
|
|
if config.type_vocab_size>0: |
|
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) |
|
|
|
|
|
if self.embedding_size != config.hidden_size: |
|
|
self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) |
|
|
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) |
|
|
self.dropout = StableDropout(config.hidden_dropout_prob) |
|
|
self.output_to_half = False |
|
|
self.config = config |
|
|
|
|
|
def forward(self, input_ids, token_type_ids=None, position_ids=None, mask = None): |
|
|
seq_length = input_ids.size(1) |
|
|
if position_ids is None: |
|
|
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device) |
|
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) |
|
|
if token_type_ids is None: |
|
|
token_type_ids = torch.zeros_like(input_ids) |
|
|
|
|
|
words_embeddings = self.word_embeddings(input_ids) |
|
|
position_embeddings = self.position_embeddings(position_ids.long()) |
|
|
|
|
|
embeddings = words_embeddings |
|
|
if self.config.type_vocab_size>0: |
|
|
token_type_embeddings = self.token_type_embeddings(token_type_ids) |
|
|
embeddings += token_type_embeddings |
|
|
|
|
|
if self.position_biased_input: |
|
|
embeddings += position_embeddings |
|
|
|
|
|
if self.embedding_size != self.config.hidden_size: |
|
|
embeddings = self.embed_proj(embeddings) |
|
|
embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask) |
|
|
embeddings = self.dropout(embeddings) |
|
|
return { |
|
|
'embeddings': embeddings, |
|
|
'position_embeddings': position_embeddings} |
|
|
|
|
|
class BertLMPredictionHead(nn.Module): |
|
|
def __init__(self, config, vocab_size): |
|
|
super().__init__() |
|
|
self.embedding_size = getattr(config, 'embedding_size', config.hidden_size) |
|
|
self.dense = nn.Linear(config.hidden_size, self.embedding_size) |
|
|
self.transform_act_fn = ACT2FN[config.hidden_act] \ |
|
|
if isinstance(config.hidden_act, str) else config.hidden_act |
|
|
|
|
|
self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True) |
|
|
|
|
|
self.bias = nn.Parameter(torch.zeros(vocab_size)) |
|
|
|
|
|
def forward(self, hidden_states, embeding_weight): |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.transform_act_fn(hidden_states) |
|
|
|
|
|
hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) |
|
|
|
|
|
|
|
|
logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias |
|
|
return logits |
|
|
|
|
|
|
|
|
class AR_MASK(object): |
|
|
def get_attention_mask(self, input_ids=None, token_type_ids=None ): |
|
|
seq_len = input_ids.size(1) |
|
|
|
|
|
|
|
|
mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.uint8)).to(input_ids.device) |
|
|
mask = mask.unsqueeze(0).expand(input_ids.size(0), seq_len, seq_len) |
|
|
return mask |
|
|
|
|
|
|
|
|
class Prefix_MASK(object): |
|
|
def get_attention_mask(self, input_ids=None, token_type_ids=None): |
|
|
idxs = torch.cumsum(token_type_ids, axis=1) |
|
|
mask = idxs[:, None, :] <= idxs[:, :, None] |
|
|
return mask.byte().to(input_ids.device) |