File size: 649 Bytes
af4c621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers.models.bert.configuration_bert import BertConfig
from typing import Optional

class BertForStanceConfig(BertConfig):
    """Configuration for a BERT-based stance classification model.

    Extends ``BertConfig`` with one extra field, ``classifier_hidden_units``,
    and pins a few settings regardless of what the caller passes through
    ``base_kwargs``:

    * ``problem_type`` is always ``"single_label_classification"``;
    * ``add_pooling_layer`` is always ``False``;
    * ``return_dict`` is always ``True``.
    """

    model_type = "bert_for_stance"

    def __init__(self,
                 *,
                 classifier_hidden_units: Optional[int] = None,
                 **base_kwargs):
        """Build the stance config.

        Args:
            classifier_hidden_units: Presumably the hidden width of the
                classification head built by the model class (not visible
                here -- confirm). ``None`` means "use ``hidden_size``".
            **base_kwargs: Forwarded unchanged to ``BertConfig.__init__``.
                Any ``problem_type``, ``add_pooling_layer`` or ``return_dict``
                supplied here is overridden below.

        Raises:
            ValueError: If ``classifier_hidden_units`` is given but is not
                positive.
        """
        super().__init__(**base_kwargs)
        # Fixed settings for this model. They are assigned after
        # super().__init__ so they take precedence over base_kwargs.
        self.problem_type = "single_label_classification"
        # NOTE(review): presumably read by the stance model to skip BERT's
        # pooler layer -- confirm against the model class.
        self.add_pooling_layer = False
        self.return_dict = True
        # Compare against None explicitly so that an invalid 0 is rejected
        # instead of being silently replaced by hidden_size.
        if classifier_hidden_units is None:
            classifier_hidden_units = self.hidden_size
        elif classifier_hidden_units <= 0:
            raise ValueError(
                "classifier_hidden_units must be positive, "
                f"got {classifier_hidden_units!r}"
            )
        self.classifier_hidden_units = classifier_hidden_units

# Register this config with the AutoConfig auto-class. Per the transformers
# custom-model docs, this makes save_pretrained record the class in the
# config's auto_map so it can be reloaded via AutoConfig with
# trust_remote_code=True.
BertForStanceConfig.register_for_auto_class(auto_class="AutoConfig")