Spaces:
Running
Running
File size: 1,178 Bytes
f316449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import torch.nn as nn
from Nested.nn import BaseModel
class BertNestedTagger(BaseModel):
    """Token tagger with one linear classification head per tag type.

    The encoder output is shared by all heads; every head scores each token
    independently for its own tag type. Heads may have different numbers of
    classes, so their logits are zero-padded to a common width before being
    stacked into a single tensor.

    NOTE(review): ``BaseModel`` is not visible here. This class relies on it
    to provide ``self.bert`` (callable, returning a mapping with a
    ``"last_hidden_state"`` entry), ``self.dropout`` and ``self.num_labels``
    (an iterable with one class count per tag type) -- confirm against
    ``Nested.nn``.
    """

    def __init__(self, **kwargs):
        """Create the per-tag-type classification heads.

        Args:
            **kwargs: Forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(**kwargs)

        # Widest head; every head's logits are padded up to this width.
        self.max_num_labels = max(self.num_labels)

        # Use the encoder's own hidden width when it advertises one
        # (presumably a Hugging Face style ``config.hidden_size`` -- confirm),
        # otherwise keep the historical BERT-base width of 768.
        encoder = getattr(self, "bert", None)
        encoder_config = getattr(encoder, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        # The heads are applied independently (never chained), so ModuleList
        # is the appropriate container. Child names stay "0", "1", ... which
        # keeps state_dict keys identical to the previous Sequential layout.
        self.classifiers = nn.ModuleList(
            [nn.Linear(hidden_size, num_labels) for num_labels in self.num_labels]
        )

    def forward(self, x):
        """Compute padded logits for every tag type.

        Args:
            x: Input handed directly to ``self.bert`` (presumably a batch of
                token ids of shape B x T -- confirm against the caller).

        Returns:
            Tensor of shape B x T x L x C, where
                B: batch size
                T: sequence length
                L: number of tag types
                C: ``self.max_num_labels`` (classes per tag type, padded)
        """
        encoded = self.bert(x)
        hidden = self.dropout(encoded["last_hidden_state"])

        padded_logits = []
        for classifier in self.classifiers:
            logits = classifier(hidden)

            # Pad logits to allow Multi-GPU/DataParallel training to work.
            # The padded dimensions are truncated when the loss is computed
            # in the trainer.
            pad_width = self.max_num_labels - logits.shape[-1]
            if pad_width > 0:
                logits = nn.functional.pad(logits, (0, pad_width), value=0.0)

            padded_logits.append(logits)

        # Stacking on dim 2 places the tag-type axis between T and C,
        # yielding B x T x L x C directly.
        return torch.stack(padded_logits, dim=2)
|