| from torch import nn | |
| class BERT_Arch(nn.Module): | |
| def __init__(self, bert): | |
| super(BERT_Arch, self).__init__() | |
| self.bert = bert | |
| self.dropout = nn.Dropout(0.1) | |
| self.relu = nn.ReLU() | |
| self.fc1 = nn.Linear(768,512) | |
| self.fc2 = nn.Linear(512,2) | |
| self.softmax = nn.LogSoftmax(dim=1) | |
| def forward(self, sent_id, mask): | |
| _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False) | |
| x = self.fc1(cls_hs) | |
| x = self.relu(x) | |
| x = self.dropout(x) | |
| x = self.fc2(x) | |
| x = self.softmax(x) | |
| return x |