Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
class BiLSTMAttentionBERT(PreTrainedModel):
    """BioBERT encoder -> bidirectional LSTM -> linear classifier.

    NOTE(review): despite the name, no attention layer is present — the
    classifier reads only the final LSTM timestep. Confirm whether an
    attention-pooling layer was intended.
    """

    def __init__(self, hidden_dim, num_classes, num_layers, dropout):
        # PreTrainedModel requires a config object; an empty one suffices
        # here because every hyperparameter is passed in explicitly.
        super().__init__(PretrainedConfig())
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        # 768 = BioBERT-base hidden size; bidirectional=True doubles the
        # LSTM output width, hence hidden_dim * 2 in the classifier head.
        self.lstm = nn.LSTM(768, hidden_dim, num_layers,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    @classmethod
    def from_pretrained(cls, model_path, hidden_dim, num_classes, num_layers, dropout):
        """Construct a model and load fine-tuned weights from ``model_path``.

        Bug fix: the original definition took ``cls`` but lacked
        ``@classmethod``, so ``BiLSTMAttentionBERT.from_pretrained(path, ...)``
        shifted every argument by one (and broke the classmethod contract of
        the ``PreTrainedModel.from_pretrained`` it overrides).
        """
        model = cls(hidden_dim, num_classes, num_layers, dropout)
        # weights_only=True: the checkpoint path is external input, and a
        # full pickle load would execute arbitrary code embedded in it.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
        return model

    def forward(self, input_ids, attention_mask):
        """Classify token sequences; returns (batch, num_classes) logits."""
        # [0] = last_hidden_state of the BERT output tuple.
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]
        lstm_output, _ = self.lstm(bert_output)
        # Pool by taking the last timestep. NOTE(review): for the backward
        # direction this is its *first* step's state; with padded batches the
        # forward half may also read padding — verify against the tokenizer's
        # padding strategy.
        dropped = self.dropout(lstm_output[:, -1, :])
        return self.fc(dropped)