"""Custom config for the Energy Intelligence multitask model (NER + classification)."""

# The module docstring must be the first statement to become ``__doc__``;
# ``from __future__`` imports are still allowed right after it.
from __future__ import annotations

from transformers import DistilBertConfig


class EnergyMultitaskConfig(DistilBertConfig):
    """Configuration for a DistilBERT model with a shared encoder and two heads:

    - **NER head** : token-level classification (BIO entity tags)
    - **CLS head** : sequence-level multi-label classification (topic labels)

    The per-head label maps are stored as plain config attributes. JSON object
    keys are always strings, so ``*_id2label`` keys are normalised back to
    ``int`` on construction; this keeps lookups such as
    ``config.ner_id2label[0]`` working after the config is serialised to JSON
    and loaded again.
    """

    model_type = "energy_multitask"

    def __init__(
        self,
        ner_num_labels: int = 19,
        ner_id2label: dict | None = None,
        ner_label2id: dict | None = None,
        cls_num_labels: int = 10,
        cls_id2label: dict | None = None,
        cls_label2id: dict | None = None,
        seq_classif_dropout: float = 0.2,
        **kwargs,
    ) -> None:
        """Build the multitask config.

        Args:
            ner_num_labels: Number of output labels of the NER (token) head.
            ner_id2label: Mapping from NER label index to label name. Keys may
                be ``int`` or numeric strings; they are stored as ``int``.
            ner_label2id: Mapping from NER label name to index. If omitted
                while ``ner_id2label`` is given, it is derived by inverting
                ``ner_id2label``.
            cls_num_labels: Number of output labels of the classification head.
            cls_id2label: Mapping from classification label index to name;
                keys are stored as ``int``.
            cls_label2id: Mapping from classification label name to index;
                derived from ``cls_id2label`` when omitted.
            seq_classif_dropout: Value stored as ``self.seq_classif_dropout``
                (set after the parent constructor runs, so it takes precedence
                over anything the parent assigned).
            **kwargs: Forwarded to ``DistilBertConfig.__init__``.

        Raises:
            ValueError: If an ``*_id2label`` key cannot be converted to ``int``.
        """
        # NOTE(review): presumably dropped because ``use_return_dict`` is a
        # derived (read-only) property on the base config in some transformers
        # versions and forwarding it would fail — confirm against the pinned
        # transformers version.
        kwargs.pop("use_return_dict", None)
        super().__init__(**kwargs)

        self.ner_num_labels = ner_num_labels
        self.ner_id2label, self.ner_label2id = self._normalize_label_maps(
            ner_id2label, ner_label2id
        )

        self.cls_num_labels = cls_num_labels
        self.cls_id2label, self.cls_label2id = self._normalize_label_maps(
            cls_id2label, cls_label2id
        )

        self.seq_classif_dropout = seq_classif_dropout

    @staticmethod
    def _normalize_label_maps(
        id2label: dict | None, label2id: dict | None
    ) -> tuple[dict, dict]:
        """Return fresh ``(id2label, label2id)`` dicts with integer ids.

        ``None`` becomes an empty dict. ``id2label`` keys are cast to ``int``
        (they arrive as strings after a JSON round trip). A missing or empty
        ``label2id`` is rebuilt from ``id2label`` when the latter is non-empty.
        The caller's dicts are copied, never aliased.
        """
        ids_to_labels = {int(idx): name for idx, name in (id2label or {}).items()}
        labels_to_ids = dict(label2id or {})
        if ids_to_labels and not labels_to_ids:
            labels_to_ids = {name: idx for idx, name in ids_to_labels.items()}
        return ids_to_labels, labels_to_ids