not-lain committed on
Commit
d0c15af
·
verified ·
1 Parent(s): f84c0a8

Update modeling_tunbert.py

Browse files
Files changed (1) hide show
  1. modeling_tunbert.py +18 -11
modeling_tunbert.py CHANGED
@@ -2,15 +2,17 @@ import torch.nn as nn
2
  from transformers import PreTrainedModel, BertModel
3
  from transformers.modeling_outputs import SequenceClassifierOutput
4
  from .config_tunbert import TunBertConfig
 
5
  class classifier(nn.Module):
6
  def __init__(self,config):
7
  super().__init__()
8
 
9
  self.layer0 = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=True)
10
  self.layer1 = nn.Linear(in_features=config.hidden_size, out_features=config.type_vocab_size, bias=True)
11
- def forward(self,tensor):
12
- out1 = self.layer0(tensor)
13
- return self.layer1(out1)
 
14
 
15
 
16
  class TunBERT(PreTrainedModel):
@@ -22,15 +24,20 @@ class TunBERT(PreTrainedModel):
22
  self.classifier = classifier(config)
23
 
24
  def forward(self,input_ids=None,token_type_ids=None,attention_mask=None,labels=None) :
25
- outputs = self.BertModel(input_ids,token_type_ids,attention_mask)
26
- sequence_output = self.dropout(outputs.last_hidden_state)
27
- logits = self.classifier(sequence_output)
28
- loss =None
29
- if labels is not None :
 
 
 
 
 
30
  loss_func = nn.CrossEntropyLoss()
31
- loss = loss_func(logits.view(-1,self.config.type_vocab_size),labels.view(-1))
32
- return SequenceClassifierOutput(loss = loss, logits= logits, hidden_states=outputs.last_hidden_state,attentions=outputs.attentions)
33
-
34
 
35
  TunBertConfig.register_for_auto_class()
36
  TunBERT.register_for_auto_class("AutoModelForSequenceClassification")
 
2
  from transformers import PreTrainedModel, BertModel
3
  from transformers.modeling_outputs import SequenceClassifierOutput
4
  from .config_tunbert import TunBertConfig
5
+
6
  class classifier(nn.Module):
7
  def __init__(self,config):
8
  super().__init__()
9
 
10
  self.layer0 = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=True)
11
  self.layer1 = nn.Linear(in_features=config.hidden_size, out_features=config.type_vocab_size, bias=True)
12
+
13
+ def forward(self,tensor):
14
+ out1 = self.layer0(tensor)
15
+ return self.layer1(out1)
16
 
17
 
18
  class TunBERT(PreTrainedModel):
 
24
  self.classifier = classifier(config)
25
 
26
  def forward(self,input_ids=None,token_type_ids=None,attention_mask=None,labels=None) :
27
+ outputs = self.BertModel(input_ids,token_type_ids,attention_mask)
28
+ sequence_output = self.dropout(outputs.last_hidden_state)
29
+ logits = self.classifier(sequence_output)
30
+ # every sentence is surrounded by [cls] in the beginning and [sep] in the end
31
+ # the [cls] token is used in bert to identify the class of the sentence
32
+ # meaning that we need only the first token of each sentence
33
+ # and the model representation of the rest of the sentence does not concern us
34
+ logits = logits[:,0,:] # [bs, seq, class]
35
+ loss =None
36
+ if labels is not None :
37
  loss_func = nn.CrossEntropyLoss()
38
+ loss = loss_func(logits.view(-1,self.config.type_vocab_size),labels.view(-1))
39
+ return SequenceClassifierOutput(loss = loss, logits= logits, hidden_states=outputs.last_hidden_state,attentions=outputs.attentions)
40
+
41
 
42
  # Register the config class for auto-class loading (no explicit auto class is passed here).
  TunBertConfig.register_for_auto_class()
43
  # Register the model class under the "AutoModelForSequenceClassification" auto class.
  TunBERT.register_for_auto_class("AutoModelForSequenceClassification")