# energy-news-classifier-ner-multitask / configuration_energy_multitask.py
# Trisham97's picture
# Upload configuration_energy_multitask.py with huggingface_hub
# de19c4a verified
from __future__ import annotations
"""Custom config for the Energy Intelligence multitask model (NER + classification)."""
from transformers import DistilBertConfig
class EnergyMultitaskConfig(DistilBertConfig):
    """DistilBERT configuration extended with settings for two task heads.

    The model this config describes shares one encoder between:

    - **NER head** : token-level classification (BIO entity tags)
    - **CLS head** : sequence-level multi-label classification (topic labels)

    Args:
        ner_num_labels: Number of NER tag classes. Defaults to 19.
        ner_id2label: Mapping from NER class id to tag name. ``None`` (or any
            other falsy value) is stored as a fresh empty dict.
        ner_label2id: Mapping from NER tag name to class id. Falsy values are
            stored as a fresh empty dict.
        cls_num_labels: Number of classification labels. Defaults to 10.
        cls_id2label: Mapping from classification id to label name. Falsy
            values are stored as a fresh empty dict.
        cls_label2id: Mapping from classification label name to id. Falsy
            values are stored as a fresh empty dict.
        seq_classif_dropout: Dropout value stored on the config after the
            parent initialiser has run. Defaults to 0.2. Presumably consumed
            by the classification head; confirm against the model code.
        **kwargs: Forwarded unchanged to ``DistilBertConfig.__init__`` once
            ``use_return_dict`` has been removed.

    Note:
        The label mappings are stored exactly as passed; keys are not
        validated or converted, and they are not cross-checked against the
        ``*_num_labels`` values.
    """

    model_type = "energy_multitask"

    def __init__(
        self,
        # --- NER head ---
        ner_num_labels: int = 19,
        ner_id2label: dict | None = None,
        ner_label2id: dict | None = None,
        # --- Classification head ---
        cls_num_labels: int = 10,
        cls_id2label: dict | None = None,
        cls_label2id: dict | None = None,
        seq_classif_dropout: float = 0.2,
        **kwargs,
    ) -> None:
        # Drop the legacy key before delegating: transformers 5+ expects
        # `return_dict`, and passing `use_return_dict` along would trigger a
        # deprecation warning.
        kwargs.pop("use_return_dict", None)
        super().__init__(**kwargs)

        # Token-level (NER) head settings.
        self.ner_num_labels = ner_num_labels
        self.ner_id2label = ner_id2label if ner_id2label else {}
        self.ner_label2id = ner_label2id if ner_label2id else {}

        # Sequence-level (classification) head settings.
        self.cls_num_labels = cls_num_labels
        self.cls_id2label = cls_id2label if cls_id2label else {}
        self.cls_label2id = cls_label2id if cls_label2id else {}

        # Assigned after super().__init__ so the explicit argument wins over
        # whatever the parent initialiser set for the same attribute.
        self.seq_classif_dropout = seq_classif_dropout