aaljabari commited on
Commit
8ebfc57
·
verified ·
1 Parent(s): c09e8d5

Create BertSeqTagger.py

Browse files
Files changed (1) hide show
  1. Nested/nn/BertSeqTagger.py +17 -0
Nested/nn/BertSeqTagger.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import BertModel
3
+
4
+
5
+ class BertSeqTagger(nn.Module):
6
+ def __init__(self, bert_model, num_labels=2, dropout=0.1):
7
+ super().__init__()
8
+
9
+ self.bert = BertModel.from_pretrained(bert_model)
10
+ self.dropout = nn.Dropout(dropout)
11
+ self.linear = nn.Linear(768, num_labels)
12
+
13
+ def forward(self, x):
14
+ y = self.bert(x)
15
+ y = self.dropout(y["last_hidden_state"])
16
+ logits = self.linear(y)
17
+ return logits