import torch
import torch.nn as nn
from transformers import AutoModel


class AttentionPooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [batch, seq_len, hidden]
        # attention_mask: [batch, seq_len]
        scores = self.attention(hidden_states).squeeze(-1)  # [batch, seq_len]
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)  # [batch, seq_len]
        weighted_sum = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
        return weighted_sum


class BERT_FFNN(nn.Module):
    """
    BERT_FFNN: BERT + feed-forward network for text classification tasks.
    """

    def __init__(
        self,
        bert_model_name="microsoft/deberta-v3-base",
        hidden_dims=[192, 96],
        output_dim=5,
        dropout=0.2,
        pooling='attention',
        freeze_bert=False,
        freeze_layers=0,
        use_layer_norm=True
    ):
        super().__init__()

        # Load pretrained BERT
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.use_layer_norm = use_layer_norm
        self.pooling = pooling

        if pooling == 'attention':
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        elif freeze_layers > 0:
            for layer in self.bert.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Build FFNN layers
        fc_input_dim = self.bert.config.hidden_size
        layers = []
        in_dim = fc_input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask):
        # BERT forward
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == 'mean':
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == 'max':
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float('-inf'))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == 'attention':
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:
            # CLS pooling
            features = (
                outputs.pooler_output
                if getattr(outputs, 'pooler_output', None) is not None
                else outputs.last_hidden_state[:, 0]
            )

        logits = self.classifier(features)
        return logits
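
# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of how the classifier might be driven, assuming the
# matching tokenizer for the default "microsoft/deberta-v3-base" checkpoint
# can be loaded (it requires the sentencepiece package). The texts, batch
# size, and label count below are made-up placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
    model = BERT_FFNN(output_dim=5, pooling='attention')
    model.eval()

    texts = [
        "The movie was surprisingly good.",
        "Terrible service, would not return.",
    ]
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        # logits: [batch, output_dim]
        logits = model(enc["input_ids"], enc["attention_mask"])
    print(logits.argmax(dim=-1).tolist())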