"""Energy Intelligence Multitask Model.

Shared DistilBERT encoder with two task heads:

- NER head : token-level BIO entity tagging
- CLS head : sequence-level multi-label topic classification
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.distilbert.modeling_distilbert import DistilBertModel
from transformers.utils import ModelOutput
# Works both as a HuggingFace remote-code module (relative) and as a plain
# local file (absolute). The try/except handles both cases.
try:
from .configuration_energy_multitask import EnergyMultitaskConfig
except ImportError:
from configuration_energy_multitask import EnergyMultitaskConfig
# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------
@dataclass
class EnergyMultitaskOutput(ModelOutput):
"""Output container returned by :class:`EnergyMultitaskModel`.
Attributes
----------
loss:
Combined NER + classification loss when labels are provided.
ner_logits:
Raw NER logits of shape ``(batch, seq_len, ner_num_labels)``.
Apply ``argmax(-1)`` for predicted token tags.
cls_logits:
Raw classification logits of shape ``(batch, cls_num_labels)``.
Apply ``sigmoid`` + threshold for active topic labels.
hidden_states:
Encoder hidden states (when ``output_hidden_states=True``).
attentions:
Attention weights (when ``output_attentions=True``).
"""
loss: Optional[torch.FloatTensor] = None
ner_logits: Optional[torch.FloatTensor] = None
cls_logits: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
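
# The docstring above says how to decode each head; here is a minimal sketch
# of that decoding. It is illustrative only (not part of the checkpoint API),
# and the 0.5 threshold is an assumed default rather than a tuned value.
def decode_outputs(
    out: EnergyMultitaskOutput, threshold: float = 0.5
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return per-token tag ids and a boolean multi-hot topic mask."""
    tag_ids = out.ner_logits.argmax(-1)                      # (batch, seq_len)
    topic_mask = torch.sigmoid(out.cls_logits) >= threshold  # (batch, cls_num_labels)
    return tag_ids, topic_mask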
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class EnergyMultitaskModel(PreTrainedModel):
"""DistilBERT encoder with a shared backbone and two task heads.
NER head
--------
Token-level linear classifier over all positions in the sequence.
Uses BIO tagging scheme with 19 labels (O + 9 entity types x B/I).
Classification head
-------------------
Sequence-level multi-label classifier on the [CLS] representation.
Uses ``BCEWithLogitsLoss`` during training (10 topic labels).
Quick start
-----------
>>> from transformers import AutoTokenizer
>>> from modeling_energy_multitask import EnergyMultitaskModel
>>> from configuration_energy_multitask import EnergyMultitaskConfig
>>>
>>> model = EnergyMultitaskModel.from_pretrained(
... "QuantBridge/energy-intelligence-multitask",
... trust_remote_code=True,
... )
>>> tokenizer = AutoTokenizer.from_pretrained(
... "QuantBridge/energy-intelligence-multitask",
... trust_remote_code=True,
... )
>>> inputs = tokenizer("Crude oil prices surged", return_tensors="pt")
>>> inputs.pop("token_type_ids", None) # DistilBERT does not use these
>>> out = model(**inputs)
>>> out.ner_logits.shape # (1, seq_len, 19)
>>> out.cls_logits.shape # (1, 10)
"""
config_class = EnergyMultitaskConfig
def __init__(self, config: EnergyMultitaskConfig) -> None:
super().__init__(config)
# Shared encoder
self.distilbert = DistilBertModel(config)
self.dropout = nn.Dropout(config.dropout)
# NER head: every token -> entity tag
self.ner_classifier = nn.Linear(config.dim, config.ner_num_labels)
# Classification head: [CLS] token -> topic labels
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.cls_classifier = nn.Linear(config.dim, config.cls_num_labels)
self.seq_classif_dropout = nn.Dropout(config.seq_classif_dropout)
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
ner_labels: Optional[torch.Tensor] = None,
cls_labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> EnergyMultitaskOutput:
"""Forward pass through the shared encoder and both task heads.
Parameters
----------
input_ids:
Token ids, shape ``(batch, seq_len)``.
attention_mask:
Padding mask, shape ``(batch, seq_len)``.
ner_labels:
Integer token labels, shape ``(batch, seq_len)``.
Ignored positions should be ``-100``.
        cls_labels:
            Float multi-hot vector, shape ``(batch, cls_num_labels)``.
        return_dict:
            Accepted for signature compatibility; this model always returns
            an :class:`EnergyMultitaskOutput`.
        """
        # The encoder is always called with return_dict=True so that
        # ``hidden_states`` and ``attentions`` can be read by attribute below
        # (a tuple output would break those accesses when a caller passes
        # return_dict=False).
        encoder_outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
sequence_output = encoder_outputs[0] # (batch, seq_len, dim)
        # ── NER head: all token positions ────────────────────────────────────
ner_output = self.dropout(sequence_output)
ner_logits = self.ner_classifier(ner_output) # (batch, seq_len, ner_num_labels)
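        # Logits are produced for every position, including padding and
        # special tokens; those positions are expected to carry ``-100`` in
        # ``ner_labels`` so the loss below skips them.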
        # ── CLS head: [CLS] token only ───────────────────────────────────────
cls_token = sequence_output[:, 0] # (batch, dim)
cls_token = self.pre_classifier(cls_token)
cls_token = nn.functional.relu(cls_token)
cls_token = self.seq_classif_dropout(cls_token)
cls_logits = self.cls_classifier(cls_token) # (batch, cls_num_labels)
        # ── Loss (computed only when labels are supplied) ────────────────────
        # When both label sets are present, the NER and classification losses
        # are summed with equal weight.
        loss = None
        if ner_labels is not None:
            loss = nn.CrossEntropyLoss(ignore_index=-100)(
                ner_logits.view(-1, self.config.ner_num_labels),
                ner_labels.view(-1),
            )
        if cls_labels is not None:
            cls_loss = nn.BCEWithLogitsLoss()(cls_logits, cls_labels.float())
            loss = cls_loss if loss is None else loss + cls_loss
return EnergyMultitaskOutput(
loss=loss,
ner_logits=ner_logits,
cls_logits=cls_logits,
hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
attentions=encoder_outputs.attentions if output_attentions else None,
)