from typing import Optional

import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel, PretrainedConfig

from fsa import FSA_layer
class Text_encoder(BertPreTrainedModel):
    """BERT encoder whose token representations are refined by a fuzzy sparse attention (FSA) layer."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.encoder = BertModel(config)
        self.config = config
        hidden_size = self.config.hidden_size
        # Adaptive-span settings passed through to the FSA layer.
        adapt_span_params = {
            'adapt_span_enabled': True,
            'adapt_span_loss': 0.0,
            'adapt_span_ramp': 32,
            'adapt_span_init': 0.0,
            'adapt_span_cache': False,
        }
        self.sparse_attn_layer = FSA_layer(
            hidden_size=hidden_size,
            nb_heads=8,
            attn_span=30,
            dropout=0.1,
            inner_hidden_size=hidden_size,
            adapt_span_params=adapt_span_params,
        )
        self.sigmoid = nn.Sigmoid()
        self.BCE_loss = nn.BCEWithLogitsLoss(reduction="sum")
        # Initialize weights and apply any final processing (from PreTrainedModel).
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Attribute access below assumes the encoder returned a model-output
        # object (i.e. return_dict resolved to True).
        sequence_output = outputs.last_hidden_state
        attentions = outputs.attentions
        pooled_output = outputs.pooler_output
        # Refine the token representations with the fuzzy sparse attention layer.
        if output_attentions:
            sequence_output, fuzzy_span_attentions = self.sparse_attn_layer(
                sequence_output, output_attentions=output_attentions
            )
        else:
            sequence_output = self.sparse_attn_layer(
                sequence_output, output_attentions=output_attentions
            )
            fuzzy_span_attentions = None
        # Return the refined token representations, the pooled output, the BERT
        # attentions, and the FSA span attentions (None when not requested).
        return sequence_output, pooled_output, attentions, fuzzy_span_attentions
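

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how
# Text_encoder could be instantiated and called. The checkpoint name
# "bert-base-uncased" and the example sentence are assumptions; parameters
# with no counterpart in the checkpoint (e.g. the FSA layer) are freshly
# initialized by from_pretrained, and the sketch assumes FSA_layer accepts a
# (batch, seq_len, hidden_size) tensor as in the forward pass above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = Text_encoder.from_pretrained("bert-base-uncased")
    model.eval()

    # Tokenize a toy batch and run a forward pass without gradient tracking.
    batch = tokenizer(["a short example sentence"], return_tensors="pt", padding=True)
    with torch.no_grad():
        sequence_output, pooled_output, attentions, fuzzy_span_attentions = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            output_attentions=True,
        )
    print(sequence_output.shape, pooled_output.shape)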