from typing import Optional

import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

from transformers import PreTrainedModel
from transformers.models.deberta.configuration_deberta import DebertaConfig
from transformers.models.deberta.modeling_deberta import DebertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


class ResidualBlock(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, num_groups: int = 8):
        super().__init__()
        self.linear_layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.GroupNorm(num_groups, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, output_dim),
            nn.GroupNorm(num_groups, output_dim),
            nn.ReLU(),
        )
        self.projection = (
            nn.Linear(input_dim, output_dim)
            if input_dim != output_dim
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layers(x) + self.projection(x)

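# Note (illustrative, not part of the original code): ResidualBlock(768, 256)
# maps a [batch, 768] tensor to [batch, 256]; because the widths differ, the
# skip connection goes through the learned nn.Linear projection before being
# added to the output of linear_layers.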

class EnhancedDebertaForSequenceClassification(PreTrainedModel):
    """
    DeBERTa-based classifier with optional extra feature branches.

    This is a Hugging Face-compatible reimplementation of the original
    EnhancedDebertaModel. For the *baseline* model on the Hub, all extra
    feature dims are zero, so it behaves like "DeBERTa + linear multi-label
    head".
    """

    config_class = DebertaConfig
    model_type = "enhanced-deberta"

    def __init__(self, config: DebertaConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        self.transformer = DebertaModel(config)

        # Extra-feature settings live on custom config attributes; they default
        # to 0/False, so the baseline checkpoint uses none of the extra branches.
        num_categories = getattr(config, "num_categories", 0)
        ling_feature_dim = getattr(config, "ling_feature_dim", 0)
        ner_feature_dim = getattr(config, "ner_feature_dim", 0)
        topic_feature_dim = getattr(config, "topic_feature_dim", 0)
        multilayer = getattr(config, "multilayer", False)
        residualblock = getattr(config, "residualblock", False)
        previous_sentences = getattr(config, "previous_sentences", False)
        num_groups = getattr(config, "num_groups", 8)

        # Each auxiliary branch is built only when its feature dim is configured (> 0).
        if num_categories > 0:
            self.lexicon_layer = nn.Sequential(
                nn.Linear(num_categories, 256),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(256, 128),
                nn.ReLU(),
            )
        else:
            self.lexicon_layer = None

        if ling_feature_dim > 0:
            self.ling_layer = nn.Sequential(
                nn.Linear(ling_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.ling_layer = None

        if ner_feature_dim > 0:
            self.ner_layer = nn.Sequential(
                nn.Linear(ner_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.ner_layer = None

        if topic_feature_dim > 0:
            self.topic_layer = nn.Sequential(
                nn.Linear(topic_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.topic_layer = None

        self.multilayer = multilayer
        self.residualblock = residualblock

        # Optionally project the [CLS] embedding down to 256 dims before concatenation.
        if multilayer:
            if residualblock:
                self.text_embedding_layer = ResidualBlock(
                    self.transformer.config.hidden_size, 256, num_groups=num_groups
                )
            else:
                self.text_embedding_layer = nn.Sequential(
                    nn.Linear(self.transformer.config.hidden_size, 512),
                    nn.GroupNorm(num_groups, 512),
                    nn.ReLU(),
                    nn.Dropout(0.4),
                    nn.Linear(512, 256),
                    nn.GroupNorm(num_groups, 256),
                    nn.ReLU(),
                )
            hidden_size = 256
        else:
            self.text_embedding_layer = None
            hidden_size = self.transformer.config.hidden_size

        if previous_sentences:
            self.prev_label_size = 2 * self.num_labels
            self.prev_label_layer = nn.Sequential(
                nn.Linear(self.prev_label_size, 16),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.prev_label_size = 0
            self.prev_label_layer = None

        # The classification head input grows by the output width of each active branch.
        input_dim = hidden_size
        if self.lexicon_layer is not None:
            input_dim += 128
        if self.ling_layer is not None:
            input_dim += 128
        if self.ner_layer is not None:
            input_dim += 128
        if self.topic_layer is not None:
            input_dim += 128
        if self.prev_label_layer is not None:
            input_dim += 16

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classification_head = nn.Linear(input_dim, self.num_labels)

        self.id2label = getattr(config, "id2label", None)
        self.label2id = getattr(config, "label2id", None)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        lexicon_features: Optional[torch.Tensor] = None,
        linguistic_features: Optional[torch.Tensor] = None,
        ner_features: Optional[torch.Tensor] = None,
        topic_features: Optional[torch.Tensor] = None,
        prev_label_features: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """
        Forward pass.

        Extra feature tensors (lexicon_features, linguistic_features, etc.)
        are expected to be of shape [batch_size, feat_dim] when used.
        """
        if input_ids is not None:
            input_ids = input_ids.to(torch.long)

        if inputs_embeds is not None:
            backbone_outputs = self.transformer(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
        else:
            backbone_outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        # Use the [CLS] token embedding as the sequence representation.
        hidden_state = backbone_outputs.last_hidden_state
        cls_embed = hidden_state[:, 0, :]

        if self.text_embedding_layer is not None:
            text_embeddings = self.text_embedding_layer(cls_embed)
        else:
            text_embeddings = cls_embed

        combined = text_embeddings

        # Concatenate the output of every active feature branch onto the text embedding.
        # Features are cast to float and moved to the model's device for consistency.
        if self.lexicon_layer is not None and lexicon_features is not None:
            lexicon_features = lexicon_features.to(combined.device).float()
            lexicon_output = self.lexicon_layer(lexicon_features)
            combined = torch.cat([combined, lexicon_output], dim=-1)

        if self.ling_layer is not None and linguistic_features is not None:
            linguistic_features = linguistic_features.to(combined.device).float()
            ling_output = self.ling_layer(linguistic_features)
            combined = torch.cat([combined, ling_output], dim=-1)

        if self.ner_layer is not None and ner_features is not None:
            ner_features = ner_features.to(combined.device).float()
            ner_output = self.ner_layer(ner_features)
            combined = torch.cat([combined, ner_output], dim=-1)

        if self.topic_layer is not None and topic_features is not None:
            topic_features = topic_features.to(combined.device).float()
            topic_output = self.topic_layer(topic_features)
            combined = torch.cat([combined, topic_output], dim=-1)

        if self.prev_label_layer is not None and prev_label_features is not None:
            prev_label_features = prev_label_features.to(combined.device).float()
            prev_output = self.prev_label_layer(prev_label_features)
            combined = torch.cat([combined, prev_output], dim=-1)

        combined = self.dropout(combined)
        logits = self.classification_head(combined)

        # Multi-label objective: independent binary cross-entropy per label.
        loss = None
        if labels is not None:
            labels = labels.float()
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)
            loss = binary_cross_entropy_with_logits(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
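

# Minimal usage sketch (illustrative only, not shipped with the original module):
# build a small untrained config and run a dummy batch through the model to check
# shapes. For the baseline Hub model the extra feature dims default to zero, so only
# input_ids / attention_mask are needed; a released checkpoint would instead be
# loaded with EnhancedDebertaForSequenceClassification.from_pretrained(...). The
# config sizes below are arbitrary smoke-test values.
if __name__ == "__main__":
    config = DebertaConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        num_labels=3,
    )
    model = EnhancedDebertaForSequenceClassification(config)
    model.eval()

    dummy_ids = torch.randint(0, config.vocab_size, (2, 16))
    dummy_mask = torch.ones_like(dummy_ids)
    with torch.no_grad():
        out = model(input_ids=dummy_ids, attention_mask=dummy_mask)
    print(out.logits.shape)  # expected: torch.Size([2, 3])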