File size: 2,450 Bytes
038426a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel, PretrainedConfig
from typing import Optional
from fsa import FSA_layer

class Text_encoder(BertPreTrainedModel):
    """BERT encoder whose token states are refined by a fuzzy-span attention layer.

    ``forward`` runs ``BertModel`` and then passes its last hidden state
    through an ``FSA_layer`` before returning it.
    """

    def __init__(self, config: PretrainedConfig):
        """Build the BERT backbone and the FSA layer.

        Args:
            config: Model configuration. ``config.hidden_size`` sets both the
                FSA layer width and its inner hidden size.
        """
        # PreTrainedModel.__init__ stores ``config`` on ``self.config``.
        super().__init__(config)
        self.encoder = BertModel(config)
        hidden_size = config.hidden_size
        # Adaptive-span settings passed straight through to FSA_layer; their
        # meaning is defined by the `fsa` module, not here.
        adapt_span_params = {
            'adapt_span_enabled': True,
            'adapt_span_loss': 0.0,
            'adapt_span_ramp': 32,
            'adapt_span_init': 0.0,
            'adapt_span_cache': False,
        }
        self.sparse_attn_layer = FSA_layer(
            hidden_size=hidden_size,
            nb_heads=8,
            attn_span=30,
            dropout=0.1,
            inner_hidden_size=hidden_size,
            adapt_span_params=adapt_span_params,
        )
        # Not used by forward(); kept as attributes in case external code
        # (e.g. a training loop) accesses them.
        self.sigmoid = nn.Sigmoid()
        self.BCE_loss = nn.BCEWithLogitsLoss(reduction="sum")
        self.post_init()

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                token_type_ids: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                ):
        """Encode inputs with BERT, then apply the FSA layer to the token states.

        Args:
            input_ids, token_type_ids, position_ids, attention_mask, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                Forwarded unchanged to ``BertModel``. ``output_attentions`` is
                also passed to the FSA layer.
            return_dict: Accepted for API compatibility only. The backbone is
                always called with ``return_dict=True``, and this method always
                returns a tuple.

        Returns:
            ``(sequence_output, pooled_output, attentions)``:
            ``sequence_output`` is the FSA layer output; ``pooled_output`` is
            the backbone's ``pooler_output``; ``attentions`` is the backbone's
            ``attentions`` field, which may be ``None`` if the backbone was
            not asked to return attentions.
        """
        # Always request a ModelOutput. With return_dict=False, BertModel
        # returns a plain tuple and the attribute reads below would raise
        # AttributeError.
        outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = outputs.pooler_output
        attentions = outputs.attentions

        # FSA_layer returns (states, attentions) when asked for attentions,
        # and the states alone otherwise.
        if output_attentions:
            sequence_output, _fuzzy_span_attentions = self.sparse_attn_layer(
                sequence_output, output_attentions=output_attentions
            )
        else:
            sequence_output = self.sparse_attn_layer(
                sequence_output, output_attentions=output_attentions
            )
        # NOTE(review): the fuzzy-span attentions are computed but not
        # returned; the 3-tuple return is kept for caller compatibility.
        return sequence_output, pooled_output, attentions