import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel, PretrainedConfig
from typing import Optional

from fsa import FSA_layer


class Text_encoder(BertPreTrainedModel):
    """BERT encoder followed by a fuzzy sparse attention (FSA) layer."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.encoder = BertModel(config)
        self.config = config
        hidden_size = self.config.hidden_size

        # Adaptive-span settings for the fuzzy sparse attention layer.
        adapt_span_params = {
            "adapt_span_enabled": True, "adapt_span_loss": 0.0,
            "adapt_span_ramp": 32, "adapt_span_init": 0.0, "adapt_span_cache": False,
        }
        self.sparse_attn_layer = FSA_layer(
            hidden_size=hidden_size, nb_heads=8, attn_span=30, dropout=0.1,
            inner_hidden_size=hidden_size, adapt_span_params=adapt_span_params,
        )
        self.sigmoid = nn.Sigmoid()
        self.BCE_loss = nn.BCEWithLogitsLoss(reduction="sum")
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode the input with BERT.
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs.last_hidden_state
        attentions = outputs.attentions
        pooled_output = outputs.pooler_output

        # Refine the token representations with the fuzzy sparse attention layer.
        if output_attentions:
            sequence_output, fuzzy_span_attentions = self.sparse_attn_layer(
                sequence_output, output_attentions=output_attentions
            )
        else:
            sequence_output = self.sparse_attn_layer(
                sequence_output, output_attentions=output_attentions
            )
            fuzzy_span_attentions = None

        # Note: fuzzy_span_attentions is computed above but not returned here.
        return sequence_output, pooled_output, attentions
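

# Minimal usage sketch: builds the model from a default BertConfig with randomly
# initialised weights and runs a dummy batch through it. Assumes the local `fsa`
# module providing FSA_layer is importable; whether FSA_layer accepts arbitrary
# sequence lengths depends on its implementation.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig()                # default BERT-base sized config
    model = Text_encoder(config).eval()  # random init, no pretrained weights

    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # dummy batch: 2 sequences of 16 tokens
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        sequence_output, pooled_output, attentions = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

    print(sequence_output.shape)  # token-level representations after the FSA layer
    print(pooled_output.shape)    # (batch, hidden_size) pooled [CLS] representation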