# human-value-detection-deberta-baseline / modeling_enhanced_deberta.py
from typing import Optional
import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits
from transformers import PreTrainedModel
from transformers.models.deberta.configuration_deberta import DebertaConfig
from transformers.models.deberta.modeling_deberta import DebertaModel
from transformers.modeling_outputs import SequenceClassifierOutput
class ResidualBlock(nn.Module):
def __init__(self, input_dim: int, output_dim: int, num_groups: int = 8):
super().__init__()
self.linear_layers = nn.Sequential(
nn.Linear(input_dim, 512),
nn.GroupNorm(num_groups, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, output_dim),
nn.GroupNorm(num_groups, output_dim),
nn.ReLU(),
)
self.projection = (
nn.Linear(input_dim, output_dim)
if input_dim != output_dim
else nn.Identity()
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_layers(x) + self.projection(x)
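# Illustrative shape check (not part of the model): ResidualBlock(768, 256)
# maps [batch, 768] -> [batch, 256]; the linear projection keeps the skip
# connection dimensionally compatible with the main path, e.g.
#   block = ResidualBlock(768, 256)
#   assert block(torch.randn(4, 768)).shape == (4, 256)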
class EnhancedDebertaForSequenceClassification(PreTrainedModel):
"""
DeBERTa-based classifier with optional extra feature branches.
    This is an HF-compatible reimplementation of the original EnhancedDebertaModel.
For the *baseline* model on the Hub, all extra feature dims are zero,
so it behaves like "DeBERTa + linear multi-label head".
"""
config_class = DebertaConfig
    # Optional: a custom type name for this architecture.
model_type = "enhanced-deberta"
def __init__(self, config: DebertaConfig):
super().__init__(config)
self.config = config
self.num_labels = config.num_labels
# ---- Backbone ----
# Keep the attribute name "transformer" so old state_dict keys match.
self.transformer = DebertaModel(config)
# Extra feature dimensions (defaults for baseline are all zero)
num_categories = getattr(config, "num_categories", 0)
ling_feature_dim = getattr(config, "ling_feature_dim", 0)
ner_feature_dim = getattr(config, "ner_feature_dim", 0)
topic_feature_dim = getattr(config, "topic_feature_dim", 0)
multilayer = getattr(config, "multilayer", False)
residualblock = getattr(config, "residualblock", False)
previous_sentences = getattr(config, "previous_sentences", False)
num_groups = getattr(config, "num_groups", 8)
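        # Example (hypothetical values): a non-baseline checkpoint could ship
        # a config.json with "num_categories": 40, "ling_feature_dim": 12 and
        # "multilayer": true, which switches on the matching branches below.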
# ---- Lexicon branch ----
if num_categories > 0:
self.lexicon_layer = nn.Sequential(
nn.Linear(num_categories, 256),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(256, 128),
nn.ReLU(),
)
else:
self.lexicon_layer = None
# ---- Linguistic branch ----
if ling_feature_dim > 0:
self.ling_layer = nn.Sequential(
nn.Linear(ling_feature_dim, 128),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.ling_layer = None
# ---- NER branch ----
if ner_feature_dim > 0:
self.ner_layer = nn.Sequential(
nn.Linear(ner_feature_dim, 128),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.ner_layer = None
# ---- Topic branch ----
if topic_feature_dim > 0:
self.topic_layer = nn.Sequential(
nn.Linear(topic_feature_dim, 128),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.topic_layer = None
# ---- Text embedding head (optional multilayer / residual) ----
self.multilayer = multilayer
self.residualblock = residualblock
if multilayer:
if residualblock:
self.text_embedding_layer = ResidualBlock(
self.transformer.config.hidden_size, 256, num_groups=num_groups
)
else:
self.text_embedding_layer = nn.Sequential(
nn.Linear(self.transformer.config.hidden_size, 512),
nn.GroupNorm(num_groups, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, 256),
nn.GroupNorm(num_groups, 256),
nn.ReLU(),
)
hidden_size = 256
else:
self.text_embedding_layer = None
hidden_size = self.transformer.config.hidden_size
# ---- Previous-sentence labels branch ----
if previous_sentences:
# 2 previous sentences × num_labels
self.prev_label_size = 2 * self.num_labels
self.prev_label_layer = nn.Sequential(
nn.Linear(self.prev_label_size, 16),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.prev_label_size = 0
self.prev_label_layer = None
# ---- Final classification head ----
input_dim = hidden_size
if self.lexicon_layer is not None:
input_dim += 128
if self.ling_layer is not None:
input_dim += 128
if self.ner_layer is not None:
input_dim += 128
if self.topic_layer is not None:
input_dim += 128
if self.prev_label_layer is not None:
input_dim += 16
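        # Sanity check of the arithmetic (illustrative): in the baseline,
        # input_dim == config.hidden_size (768 for deberta-base); with every
        # branch enabled and multilayer=True it is 256 + 4 * 128 + 16 = 784.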
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classification_head = nn.Linear(input_dim, self.num_labels)
# label mappings (already in config, but we mirror them here)
self.id2label = getattr(config, "id2label", None)
self.label2id = getattr(config, "label2id", None)
# Initialize weights (will be overwritten by from_pretrained)
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
lexicon_features: Optional[torch.Tensor] = None,
linguistic_features: Optional[torch.Tensor] = None,
ner_features: Optional[torch.Tensor] = None,
topic_features: Optional[torch.Tensor] = None,
prev_label_features: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs,
) -> SequenceClassifierOutput:
"""
Forward pass.
Extra feature tensors (lexicon_features, linguistic_features, etc.)
are expected to be of shape [batch_size, feat_dim] when used.
"""
# Ensure integer token IDs
if input_ids is not None:
input_ids = input_ids.to(torch.long)
# ---- Transformer backbone ----
if inputs_embeds is not None:
backbone_outputs = self.transformer(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
)
else:
backbone_outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
)
# CLS representation
hidden_state = backbone_outputs.last_hidden_state
cls_embed = hidden_state[:, 0, :] # [batch, hidden]
# Optional multilayer / residual processing
if self.text_embedding_layer is not None:
text_embeddings = self.text_embedding_layer(cls_embed)
else:
text_embeddings = cls_embed
combined = text_embeddings
# ---- Lexicon branch ----
if self.lexicon_layer is not None and lexicon_features is not None:
            lexicon_features = lexicon_features.to(combined.device).float()
lexicon_output = self.lexicon_layer(lexicon_features)
combined = torch.cat([combined, lexicon_output], dim=-1)
# ---- Linguistic branch ----
if self.ling_layer is not None and linguistic_features is not None:
            linguistic_features = linguistic_features.to(combined.device).float()
ling_output = self.ling_layer(linguistic_features)
combined = torch.cat([combined, ling_output], dim=-1)
# ---- NER branch ----
if self.ner_layer is not None and ner_features is not None:
            ner_features = ner_features.to(combined.device).float()
ner_output = self.ner_layer(ner_features)
combined = torch.cat([combined, ner_output], dim=-1)
# ---- Topic branch ----
if self.topic_layer is not None and topic_features is not None:
            topic_features = topic_features.to(combined.device).float()
topic_output = self.topic_layer(topic_features)
combined = torch.cat([combined, topic_output], dim=-1)
# ---- Previous-sentence labels branch ----
if self.prev_label_layer is not None and prev_label_features is not None:
prev_label_features = prev_label_features.to(combined.device).float()
prev_output = self.prev_label_layer(prev_label_features)
combined = torch.cat([combined, prev_output], dim=-1)
combined = self.dropout(combined)
logits = self.classification_head(combined)
        loss = None
        if labels is not None:
            # Multi-label objective: an independent sigmoid/BCE term per label.
            labels = labels.float()
            if labels.dim() == 1:
                # Accept [batch] labels in the single-label (num_labels == 1) case.
                labels = labels.unsqueeze(1)
            loss = binary_cross_entropy_with_logits(logits, labels)
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=backbone_outputs.hidden_states,
attentions=backbone_outputs.attentions,
)
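# Minimal smoke test: a sketch under assumed, arbitrary hyperparameters (the
# real checkpoint is loaded with from_pretrained and its own config instead).
if __name__ == "__main__":
    config = DebertaConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        num_labels=19,  # arbitrary label count for the demo
    )
    model = EnhancedDebertaForSequenceClassification(config).eval()
    ids = torch.randint(0, 128, (2, 12))
    with torch.no_grad():
        out = model(input_ids=ids, attention_mask=torch.ones_like(ids))
    probs = torch.sigmoid(out.logits)  # independent per-label probabilities
    print(out.logits.shape)  # expected: torch.Size([2, 19])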