Guihss commited on
Commit
bb8ce2c
·
1 Parent(s): f1fe08f

Upload TypeBERTForSequenceClassification

Browse files
Files changed (3) hide show
  1. config.json +28 -0
  2. pytorch_model.bin +3 -0
  3. type_bert_model.py +68 -0
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "TypeBERTForSequenceClassification"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "type_bert_model.TypeBERTConfig",
7
+ "AutoModelForSequenceClassification": "type_bert_model.TypeBERTForSequenceClassification"
8
+ },
9
+ "id2label": {
10
+ "0": "agent",
11
+ "1": "event",
12
+ "2": "place",
13
+ "3": "item",
14
+ "4": "virtual",
15
+ "5": "concept"
16
+ },
17
+ "label2id": {
18
+ "agent": 0,
19
+ "concept": 5,
20
+ "event": 1,
21
+ "item": 3,
22
+ "place": 2,
23
+ "virtual": 4
24
+ },
25
+ "model_type": "type_bert",
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.22.1"
28
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a51f16e0417694151ffbea00afe90058d9707f82d0d59d88ed9a64230088f2fd
3
+ size 448627745
type_bert_model.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer, BertModel
2
+ from transformers import PretrainedConfig, PreTrainedModel
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class TypeBERTConfig(PretrainedConfig):
8
+ model_type = "type_bert"
9
+
10
+ def __init__(self, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.id2label = {
13
+ 0: "agent",
14
+ 1: "event",
15
+ 2: "place",
16
+ 3: "item",
17
+ 4: "virtual",
18
+ 5: "concept"
19
+ }
20
+
21
+ self.label2id = {
22
+ "agent": 0,
23
+ "event": 1,
24
+ "place": 2,
25
+ "item": 3,
26
+ "virtual": 4,
27
+ "concept": 5
28
+ }
29
+
30
+
31
+
32
+ class TypeBERTForSequenceClassification(PreTrainedModel):
33
+ config_class = TypeBERTConfig
34
+
35
+ def __init__(self, config):
36
+ super(TypeBERTForSequenceClassification, self).__init__(config)
37
+ self.bert = BertModel.from_pretrained("bert-base-uncased")
38
+ # for param in self.bert.base_model.parameters():
39
+ # param.requires_grad = False
40
+ #
41
+ # self.bert.eval()
42
+
43
+ self.tanh = nn.Tanh()
44
+
45
+ self.dff = nn.Sequential(
46
+ nn.Linear(768, 2048),
47
+ nn.ReLU(),
48
+ nn.Dropout(0.1),
49
+ nn.Linear(2048, 512),
50
+ nn.ReLU(),
51
+ nn.Dropout(0.1),
52
+ nn.Linear(512, 64),
53
+ nn.ReLU(),
54
+ nn.Dropout(0.1),
55
+ nn.Linear(64, 6),
56
+ nn.LogSoftmax(dim=1)
57
+ )
58
+
59
+ self.eval()
60
+
61
+ def forward(self, **kwargs):
62
+
63
+ a = kwargs['attention_mask']
64
+ embs = self.bert(**kwargs)['last_hidden_state']
65
+
66
+ embs *= a.unsqueeze(2)
67
+ out = embs.sum(dim=1) / a.sum(dim=1, keepdims=True)
68
+ return {'logits': self.dff(self.tanh(out))}