File size: 487 Bytes
79eb00d
 
 
f316449
79eb00d
 
 
f316449
79eb00d
 
 
f316449
79eb00d
 
 
 
f316449
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
from transformers import BertModel


class BertSeqTagger(nn.Module):
    """Token-level sequence tagger: a pretrained BERT encoder plus a linear head.

    Each token's final hidden state from BERT goes through dropout and is then
    projected to ``num_labels`` logits.

    Args:
        bert_model: Model name or path passed to ``BertModel.from_pretrained``.
        num_labels: Number of output tags per token.
        dropout: Dropout probability applied to the encoder's last hidden states.
    """

    def __init__(self, bert_model, num_labels=2, dropout=0.1):
        super().__init__()

        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
        # Take the encoder width from the loaded config instead of hard-coding
        # 768, so checkpoints with a different hidden size (e.g. bert-large,
        # 1024) produce a correctly shaped head.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return per-token tag logits.

        Args:
            x: Token ids fed to BERT as ``input_ids``; presumably a
                (batch, seq_len) integer tensor — TODO confirm with callers.
            attention_mask: Optional mask forwarded to BERT (1 = attend,
                0 = padding) so padded positions are ignored. ``None``
                (the default) reproduces the previous unmasked behavior.

        Returns:
            The linear head applied to BERT's ``last_hidden_state``. The last
            dimension is ``num_labels``; the leading dimensions follow
            BERT's output layout.
        """
        outputs = self.bert(x, attention_mask=attention_mask)
        hidden = self.dropout(outputs["last_hidden_state"])
        return self.linear(hidden)