import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_bert_ffnn import BertFFNNConfig
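
# For reference, a minimal sketch of the companion config this model expects
# (illustrative only: the real class lives in configuration_bert_ffnn.py, and
# the defaults shown here are assumptions; only the attribute names are taken
# from the code below):
#
#   class BertFFNNConfig(PretrainedConfig):
#       model_type = "bert_ffnn"
#
#       def __init__(self, bert_model_name="bert-base-uncased", pooling="cls",
#                    hidden_dims=(256,), dropout=0.1, use_layer_norm=False,
#                    freeze_bert=False, freeze_layers=0, output_dim=2, **kwargs):
#           super().__init__(**kwargs)
#           ...
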
class AttentionPooling(nn.Module):
    """Learned weighted average over token embeddings.

    A single linear layer scores each token; padding positions are masked
    out before the softmax so they receive near-zero weight.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # Per-token scores: (batch, seq_len, hidden) -> (batch, seq_len).
        scores = self.attention(hidden_states).squeeze(-1)
        # Push padded positions toward -inf so softmax assigns them ~0 weight.
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        # Weighted sum over the sequence dimension: (batch, hidden).
        return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
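
# Quick shape check for AttentionPooling (illustrative only):
#
#   pool = AttentionPooling(hidden_size=768)
#   h = torch.randn(2, 16, 768)               # (batch, seq_len, hidden)
#   m = torch.ones(2, 16, dtype=torch.long)   # 1 = real token, 0 = padding
#   pool(h, m).shape                          # -> torch.Size([2, 768])
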
class BERT_FFNN(PreTrainedModel):
    """A BERT encoder followed by a configurable feed-forward classifier head."""

    config_class = BertFFNNConfig
    base_model_prefix = "bert_ffnn"

    def __init__(self, config):
        super().__init__(config)
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.pooling = config.pooling
        self.use_layer_norm = config.use_layer_norm

        if self.pooling == "attention":
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        # Optionally freeze the entire encoder, or only its first N layers.
        if config.freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        elif config.freeze_layers > 0:
            # Note: this assumes a BERT-style backbone that exposes
            # `encoder.layer`; other AutoModel architectures may not.
            for layer in self.bert.encoder.layer[:config.freeze_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        # Build the FFNN head: Linear -> ReLU -> [LayerNorm] -> Dropout for
        # each hidden layer, then a final projection to the output dimension.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in config.hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if self.use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(config.dropout))
            in_dim = h_dim

        layers.append(nn.Linear(in_dim, config.output_dim))
        self.classifier = nn.Sequential(*layers)
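        # For intuition, with a 768-dim encoder, hidden_dims=[256], and
        # output_dim=2 (hypothetical values), the head above is:
        #   Sequential(Linear(768, 256), ReLU(), [LayerNorm(256)],
        #              Dropout(p), Linear(256, 2))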

        # Standard Hugging Face post-initialization hook.
        self.post_init()
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == "mean":
            # Average over real tokens only; the clamp guards against
            # division by zero for fully padded rows.
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == "max":
            # Fill padded positions with -inf so they never win the max.
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float("-inf"))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == "attention":
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:
            # Default ("cls"): use the backbone's pooler output when present,
            # otherwise fall back to the raw [CLS] token embedding.
            features = (
                outputs.pooler_output
                if getattr(outputs, "pooler_output", None) is not None
                else outputs.last_hidden_state[:, 0]
            )

        logits = self.classifier(features)
        return logits
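
# Example usage (a sketch; assumes BertFFNNConfig accepts the fields read in
# __init__ above and that the tokenizer matches config.bert_model_name):
#
#   from transformers import AutoTokenizer
#
#   config = BertFFNNConfig(
#       bert_model_name="bert-base-uncased", pooling="mean",
#       hidden_dims=[256], dropout=0.1, use_layer_norm=True,
#       freeze_bert=False, freeze_layers=0, output_dim=2,
#   )
#   model = BERT_FFNN(config)
#   tok = AutoTokenizer.from_pretrained(config.bert_model_name)
#   batch = tok(["an example sentence"], return_tensors="pt", padding=True)
#   logits = model(batch["input_ids"], batch["attention_mask"])  # shape (1, 2)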