hugo-albert commited on
Commit
1b87188
·
verified ·
1 Parent(s): 098737f

Update XLMRoBERTaClassifier.py

Browse files
Files changed (1) hide show
  1. XLMRoBERTaClassifier.py +4 -5
XLMRoBERTaClassifier.py CHANGED
@@ -15,7 +15,7 @@ warnings.filterwarnings("ignore")
15
 
16
  class XLMRoBERTaClassifier(PreTrainedModel):
17
  def __init__(self, dropout=0.3, model_name='xlm-roberta-large'):
18
- self.config = AutoConfig.from_pretrained("hugo-albert/xlm-robertargument")
19
  super(XLMRoBERTaClassifier, self).__init__(self.config)
20
  self.roberta = XLMRobertaModel.from_pretrained(model_name)
21
  self.dropout = nn.Dropout(dropout)
@@ -29,9 +29,8 @@ class XLMRoBERTaClassifier(PreTrainedModel):
29
  self.final_layer = nn.Linear(128, 1)
30
 
31
  def forward(self, input_ids, attention_mask): #, extra_features):
32
- input_id, mask = torch.Tensor(input_ids).long(), torch.LongTensor(attention_mask).long()
33
- roberta_output = self.roberta(input_ids = input_id,
34
- attention_mask=mask)
35
  last_hidden_state = roberta_output.last_hidden_state
36
  conv_output = self.conv1(last_hidden_state)
37
  pool_output = self.pool(conv_output)
@@ -44,4 +43,4 @@ class XLMRoBERTaClassifier(PreTrainedModel):
44
  final_output = self.final_layer(dropout_output)
45
  sigmoid_output = self.sigmoid(final_output)
46
  sigmoid_output = torch.squeeze(sigmoid_output)
47
- return {"logits": sigmoid_output}
 
15
 
16
  class XLMRoBERTaClassifier(PreTrainedModel):
17
  def __init__(self, dropout=0.3, model_name='xlm-roberta-large'):
18
+ self.config = AutoConfig.from_pretrained("FacebookAI/xlm-roberta-large")
19
  super(XLMRoBERTaClassifier, self).__init__(self.config)
20
  self.roberta = XLMRobertaModel.from_pretrained(model_name)
21
  self.dropout = nn.Dropout(dropout)
 
29
  self.final_layer = nn.Linear(128, 1)
30
 
31
  def forward(self, input_ids, attention_mask): #, extra_features):
32
+ roberta_output = self.roberta(input_ids = input_ids,
33
+ attention_mask=attention_mask)
 
34
  last_hidden_state = roberta_output.last_hidden_state
35
  conv_output = self.conv1(last_hidden_state)
36
  pool_output = self.pool(conv_output)
 
43
  final_output = self.final_layer(dropout_output)
44
  sigmoid_output = self.sigmoid(final_output)
45
  sigmoid_output = torch.squeeze(sigmoid_output)
46
+ return sigmoid_output