Trisham97 commited on
Commit
de19c4a
·
verified ·
1 Parent(s): e0cb87a

Upload configuration_energy_multitask.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_energy_multitask.py +39 -0
configuration_energy_multitask.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """Custom config for the Energy Intelligence multitask model (NER + classification)."""
4
+
5
+ from transformers import DistilBertConfig
6
+
7
+
8
+ class EnergyMultitaskConfig(DistilBertConfig):
9
+ """Configuration for a DistilBERT model with a shared encoder and two heads:
10
+
11
+ - **NER head** : token-level classification (BIO entity tags)
12
+ - **CLS head** : sequence-level multi-label classification (topic labels)
13
+ """
14
+
15
+ model_type = "energy_multitask"
16
+
17
+ def __init__(
18
+ self,
19
+ # --- NER head ---
20
+ ner_num_labels: int = 19,
21
+ ner_id2label: dict | None = None,
22
+ ner_label2id: dict | None = None,
23
+ # --- Classification head ---
24
+ cls_num_labels: int = 10,
25
+ cls_id2label: dict | None = None,
26
+ cls_label2id: dict | None = None,
27
+ seq_classif_dropout: float = 0.2,
28
+ **kwargs,
29
+ ) -> None:
30
+ # Silence deprecation: transformers 5+ uses return_dict not use_return_dict
31
+ kwargs.pop("use_return_dict", None)
32
+ super().__init__(**kwargs)
33
+ self.ner_num_labels = ner_num_labels
34
+ self.ner_id2label = ner_id2label or {}
35
+ self.ner_label2id = ner_label2id or {}
36
+ self.cls_num_labels = cls_num_labels
37
+ self.cls_id2label = cls_id2label or {}
38
+ self.cls_label2id = cls_label2id or {}
39
+ self.seq_classif_dropout = seq_classif_dropout