citclass committed on
Commit
69f9a17
·
1 Parent(s): 5fd8b82

Upload DistilBertClassifier

Browse files
Files changed (3) hide show
  1. classifier.py +36 -0
  2. config.json +30 -0
  3. pytorch_model.bin +3 -0
classifier.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig
4
+ import transformers
5
+
6
+
7
+ class DistilBertClassifier(PreTrainedModel):
8
+
9
+ def __init__(self, bert_config, model_name='distilbert-base-uncased', tokenizer_len=30528, freeze_bert=False):
10
+
11
+
12
+ super().__init__(bert_config)
13
+ D_in, H, D_out = 256, 50, 71
14
+
15
+ self.bert = AutoModel.from_pretrained(model_name)
16
+ self.bert.resize_token_embeddings(tokenizer_len)
17
+ self.classifier = nn.Sequential(
18
+ nn.GELU(),
19
+ nn.Linear(self.bert.config.hidden_size, 300),
20
+ nn.GELU(),
21
+ nn.Dropout(0.05),
22
+ nn.Linear(300, 71)
23
+ )
24
+
25
+ if freeze_bert:
26
+ for param in self.bert.parameters():
27
+ param.requires_grad = False
28
+
29
+ def forward(self, input_ids, attention_mask):
30
+
31
+ outputs = self.bert(input_ids=input_ids,
32
+ attention_mask=attention_mask)
33
+
34
+ last_hidden_state_cls = outputs[0][:, 0, :]
35
+ logits = self.classifier(last_hidden_state_cls)
36
+ return logits
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "distilbert-base-uncased",
3
+ "activation": "gelu",
4
+ "architectures": [
5
+ "DistilBertClassifier"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "auto_map": {
9
+ "AutoModelForSequenceClassification": "classifier.DistilBertClassifier"
10
+ },
11
+ "dim": 800,
12
+ "dropout": 0.1,
13
+ "hidden_dim": 3072,
14
+ "hidden_dropout_prob": 0.1,
15
+ "initializer_range": 2,
16
+ "intermediate_size": 500,
17
+ "layer_norm_eps": 1e-07,
18
+ "max_position_embeddings": 270,
19
+ "model_type": "distilbert",
20
+ "n_heads": 12,
21
+ "n_layers": 3,
22
+ "pad_token_id": 0,
23
+ "qa_dropout": 0.1,
24
+ "seq_classif_dropout": 0.2,
25
+ "sinusoidal_pos_embds": false,
26
+ "tie_weights_": true,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.28.1",
29
+ "vocab_size": 30528
30
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6df5a46dbe135b0e83312597b1e11f174fda4db4ac797a248c367d876e2ccad
3
+ size 266511597