import torch
import torch.nn as nn
from transformers import AutoModel

class BERT_FFNN(nn.Module):
    """
    BERT_FFNN: BERT + feed-forward network for text classification tasks.
    """
    def __init__(
        self,
        bert_model_name="microsoft/deberta-v3-base",
        hidden_dims=(192, 96),
        output_dim=5,
        dropout=0.2,
        pooling='attention',
        freeze_bert=False,
        freeze_layers=0,
        use_layer_norm=True
    ):
        super().__init__()
        
        # Load pretrained BERT
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.use_layer_norm = use_layer_norm
        self.pooling = pooling
        
        if pooling == 'attention':
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)
                
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        elif freeze_layers > 0:
            for layer in self.bert.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Build FFNN layers
        fc_input_dim = self.bert.config.hidden_size
        layers = []
        in_dim = fc_input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.classifier = nn.Sequential(*layers)
    
    def forward(self, input_ids, attention_mask):
        # BERT forward
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        if self.pooling == 'mean':
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == 'max':
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float('-inf'))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == 'attention':
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:
            # CLS pooling: use pooler_output when the backbone provides one,
            # otherwise fall back to the [CLS] token embedding
            features = (
                outputs.pooler_output
                if getattr(outputs, 'pooler_output', None) is not None
                else outputs.last_hidden_state[:, 0]
            )
        
        logits = self.classifier(features)
        return logits

class AttentionPooling(nn.Module):
    """Learned attention pooling: scores each token with a linear layer and
    returns the attention-weighted sum of the hidden states."""
    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [batch, seq_len, hidden]
        # attention_mask: [batch, seq_len]

        scores = self.attention(hidden_states).squeeze(-1)  # [batch, seq_len]
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)  # [batch, seq_len]

        weighted_sum = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
        return weighted_sum
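
# Minimal usage sketch. Assumptions: a tokenizer matching `bert_model_name` is
# available via transformers.AutoTokenizer, and the example texts below are
# hypothetical; shown only to illustrate the expected input/output shapes.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = BERT_FFNN(output_dim=5, pooling='attention')
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

    texts = ["an example sentence", "another example"]  # hypothetical inputs
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    model.eval()
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # expected: torch.Size([2, 5])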