aaljabari commited on
Commit
c09e8d5
·
verified ·
1 Parent(s): e14f998

Create BertNestedTagger.py

Browse files
Files changed (1) hide show
  1. Nested/nn/BertNestedTagger.py +34 -0
Nested/nn/BertNestedTagger.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from Nested.nn import BaseModel
4
+
5
+
6
+ class BertNestedTagger(BaseModel):
7
+ def __init__(self, **kwargs):
8
+ super(BertNestedTagger, self).__init__(**kwargs)
9
+
10
+ self.max_num_labels = max(self.num_labels)
11
+ classifiers = [nn.Linear(768, num_labels) for num_labels in self.num_labels]
12
+ self.classifiers = torch.nn.Sequential(*classifiers)
13
+
14
+ def forward(self, x):
15
+ y = self.bert(x)
16
+ y = self.dropout(y["last_hidden_state"])
17
+ output = list()
18
+
19
+ for i, classifier in enumerate(self.classifiers):
20
+ logits = classifier(y)
21
+
22
+ # Pad logits to allow Multi-GPU/DataParallel training to work
23
+ # We will truncate the padded dimensions when we compute the loss in the trainer
24
+ logits = torch.nn.ConstantPad1d((0, self.max_num_labels - logits.shape[-1]), 0)(logits)
25
+ output.append(logits)
26
+
27
+ # Return tensor of the shape B x T x L x C
28
+ # B: batch size
29
+ # T: sequence length
30
+ # L: number of tag types
31
+ # C: number of classes per tag type
32
+ output = torch.stack(output).permute((1, 2, 0, 3))
33
+ return output
34
+