| from transformers import PretrainedConfig | |
class MultiHeadConfig(PretrainedConfig):
    """Configuration for a multi-head classifier built on a pretrained encoder.

    Args:
        encoder_name: Model id or path of the backbone encoder
            (default ``"microsoft/deberta-v3-small"``).
        **kwargs: Forwarded to ``PretrainedConfig``. Defaults applied here
            when not supplied by the caller:

            * ``classifier_dropout``: ``0.1``
            * ``tokenizer_class``: ``"DebertaV2TokenizerFast"``
            * ``id2label`` / ``label2id``: the binary
              ``{0: "irrelevant", 1: "relevant"}`` mapping (and its inverse),
              applied only when no ``id2label`` is given and ``num_labels``
              is absent or equal to 2.
    """

    model_type = "multihead"

    def __init__(self, encoder_name="microsoft/deberta-v3-small", **kwargs):
        self.encoder_name = encoder_name
        # PretrainedConfig.__init__ does not consume this key, so set it here
        # and remove it from kwargs to avoid assigning it twice.
        self.classifier_dropout = kwargs.pop("classifier_dropout", 0.1)

        # PretrainedConfig.__init__ reassigns id2label, label2id, num_labels
        # and tokenizer_class from kwargs (falling back to its own defaults),
        # so our defaults must go into kwargs. Setting them on self beforehand
        # would just be overwritten by the super().__init__ call below.
        kwargs.setdefault("tokenizer_class", "DebertaV2TokenizerFast")
        if "id2label" not in kwargs and kwargs.get("num_labels", 2) == 2:
            # Only use the binary label names when the caller has not asked
            # for a different label count; otherwise let the base class build
            # a mapping that matches the requested num_labels.
            kwargs["id2label"] = {0: "irrelevant", 1: "relevant"}
            kwargs.setdefault("label2id", {"irrelevant": 0, "relevant": 1})
        super().__init__(**kwargs)