Spaces:
Running
Running
File size: 487 Bytes
79eb00d f316449 79eb00d f316449 79eb00d f316449 79eb00d f316449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch.nn as nn
from transformers import BertModel
class BertSeqTagger(nn.Module):
    """Token-level tagger: a pretrained BERT encoder followed by a linear head.

    Each token's final hidden state is passed through dropout and projected
    to ``num_labels`` logits.

    Args:
        bert_model: Name or path passed to ``BertModel.from_pretrained``.
        num_labels: Number of output labels per token.
        dropout: Dropout probability applied to the encoder output.
    """

    def __init__(self, bert_model, num_labels=2, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded model's config instead of hard-coding
        # 768, so checkpoints with a different hidden width also work.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return per-token label logits.

        Args:
            x: Token ids for the encoder; presumably shaped (batch, seq_len)
                -- TODO confirm against callers.
            attention_mask: Optional mask forwarded to BERT so padding
                positions can be excluded from attention. ``None`` reproduces
                the previous call, which passed no mask at all.

        Returns:
            Logits from ``self.linear`` applied to the (dropped-out) last
            hidden state; the last dimension is ``num_labels``.
        """
        outputs = self.bert(x, attention_mask=attention_mask)
        hidden = self.dropout(outputs["last_hidden_state"])
        return self.linear(hidden)
|