import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig

class BiLSTMConfig(PretrainedConfig):
    """Configuration for BiLSTMAttentionBERT: LSTM width/depth, class count, dropout."""

    # Assumed registry name; only needed if the config is registered with AutoConfig.
    model_type = "bilstm_attention_bert"

    def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.dropout = dropout

class BiLSTMAttentionBERT(PreTrainedModel):
    """BioBERT encoder followed by a bidirectional LSTM and a linear classifier."""

    config_class = BiLSTMConfig  # lets save_pretrained/from_pretrained round-trip the custom config

    def __init__(self, config):
        super().__init__(config)  # also stores config as self.config
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        # BioBERT's hidden size is 768; the bidirectional LSTM doubles
        # config.hidden_dim in its output, hence hidden_dim * 2 in the classifier.
        self.lstm = nn.LSTM(768, config.hidden_dim, config.num_layers,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = outputs[0]  # (batch, seq_len, 768) token embeddings
        # Pool with the final hidden states of both LSTM directions rather than
        # lstm_output[:, -1, :]: in a right-padded batch the last time step is a
        # pad token, and the backward direction's state at that position has
        # only seen a single token.
        _, (h_n, _) = self.lstm(bert_output)
        # h_n has shape (num_layers * 2, batch, hidden_dim); take the top layer's
        # forward and backward states and concatenate them.
        sentence_repr = torch.cat((h_n[-2], h_n[-1]), dim=1)
        logits = self.fc(self.dropout(sentence_repr))
        return logits
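

# Minimal usage sketch, assuming the tokenizer that matches the BioBERT
# backbone above; the example sentence is a placeholder, and the LSTM/classifier
# weights are randomly initialized on top of the pretrained encoder.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
    model = BiLSTMAttentionBERT(BiLSTMConfig(hidden_dim=128, num_classes=22))
    model.eval()  # disable dropout for a deterministic forward pass

    batch = tokenizer(["Aspirin inhibits platelet aggregation."],
                      padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # torch.Size([1, 22])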