# Source: wojood-api / Nested / nn / BertSeqTagger.py
# Uploaded by naghamghanim ("Upload 37 files"), commit f316449 (verified), 487 bytes.
import torch.nn as nn
from transformers import BertModel
class BertSeqTagger(nn.Module):
    """Token-level tagger: a pretrained BERT encoder, dropout, and a linear head.

    The linear head is applied independently to every token's final hidden
    state, producing one score per label for each token.

    Args:
        bert_model: Name or path passed to ``BertModel.from_pretrained``.
        num_labels: Number of output labels (tags) per token.
        dropout: Dropout probability applied to the encoder output before
            the linear head.
    """

    def __init__(self, bert_model, num_labels=2, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded encoder's config rather than
        # hard-coding 768: checkpoints with a different hidden size
        # (e.g. BERT-large, 1024) would otherwise fail with a shape mismatch.
        # For BERT-base checkpoints this is still 768.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return per-token label logits.

        Args:
            x: Token ids, passed to the encoder as its first positional
                argument (``input_ids``); presumably shape
                ``(batch, seq_len)`` — confirm against callers.
            attention_mask: Optional mask forwarded to the encoder so padding
                positions are ignored. ``None`` keeps the original behavior
                (no mask passed; the encoder attends to every position).

        Returns:
            The linear head applied to the dropout-regularized
            ``last_hidden_state``; its last dimension is ``num_labels``.
        """
        outputs = self.bert(x, attention_mask=attention_mask)
        hidden = self.dropout(outputs["last_hidden_state"])
        return self.linear(hidden)