from typing import Optional

import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits
from transformers import PreTrainedModel
from transformers.models.deberta.configuration_deberta import DebertaConfig
from transformers.models.deberta.modeling_deberta import DebertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


class ResidualBlock(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, num_groups: int = 8):
        super().__init__()
        self.linear_layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.GroupNorm(num_groups, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, output_dim),
            nn.GroupNorm(num_groups, output_dim),
            nn.ReLU(),
        )
        self.projection = (
            nn.Linear(input_dim, output_dim)
            if input_dim != output_dim
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layers(x) + self.projection(x)
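
# Illustrative note: ResidualBlock is applied to 2-D [batch_size, features]
# tensors, so nn.GroupNorm here normalizes groups of feature channels per
# sample. With the default num_groups=8, the hidden width 512 and output_dim
# must both be divisible by 8. A minimal sketch with placeholder dimensions
# (not values taken from the original training setup):
#
#     block = ResidualBlock(input_dim=768, output_dim=256)
#     out = block(torch.randn(4, 768))  # -> shape [4, 256]
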
class EnhancedDebertaForSequenceClassification(PreTrainedModel):
    """
    DeBERTa-based classifier with optional extra feature branches.

    This is an HF-compatible reimplementation of the original
    EnhancedDebertaModel. For the *baseline* model on the Hub, all extra
    feature dims are zero, so it behaves like
    "DeBERTa + linear multi-label head".
    """

    config_class = DebertaConfig
    # Optional: you can give it a custom type name if you like
    model_type = "enhanced-deberta"

    def __init__(self, config: DebertaConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        # ---- Backbone ----
        # Keep the attribute name "transformer" so old state_dict keys match.
        self.transformer = DebertaModel(config)

        # Extra feature dimensions (defaults for the baseline are all zero)
        num_categories = getattr(config, "num_categories", 0)
        ling_feature_dim = getattr(config, "ling_feature_dim", 0)
        ner_feature_dim = getattr(config, "ner_feature_dim", 0)
        topic_feature_dim = getattr(config, "topic_feature_dim", 0)
        multilayer = getattr(config, "multilayer", False)
        residualblock = getattr(config, "residualblock", False)
        previous_sentences = getattr(config, "previous_sentences", False)
        num_groups = getattr(config, "num_groups", 8)

        # ---- Lexicon branch ----
        if num_categories > 0:
            self.lexicon_layer = nn.Sequential(
                nn.Linear(num_categories, 256),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(256, 128),
                nn.ReLU(),
            )
        else:
            self.lexicon_layer = None

        # ---- Linguistic branch ----
        if ling_feature_dim > 0:
            self.ling_layer = nn.Sequential(
                nn.Linear(ling_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.ling_layer = None

        # ---- NER branch ----
        if ner_feature_dim > 0:
            self.ner_layer = nn.Sequential(
                nn.Linear(ner_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.ner_layer = None

        # ---- Topic branch ----
        if topic_feature_dim > 0:
            self.topic_layer = nn.Sequential(
                nn.Linear(topic_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.topic_layer = None

        # ---- Text embedding head (optional multilayer / residual) ----
        self.multilayer = multilayer
        self.residualblock = residualblock
        if multilayer:
            if residualblock:
                self.text_embedding_layer = ResidualBlock(
                    self.transformer.config.hidden_size, 256, num_groups=num_groups
                )
            else:
                self.text_embedding_layer = nn.Sequential(
                    nn.Linear(self.transformer.config.hidden_size, 512),
                    nn.GroupNorm(num_groups, 512),
                    nn.ReLU(),
                    nn.Dropout(0.4),
                    nn.Linear(512, 256),
                    nn.GroupNorm(num_groups, 256),
                    nn.ReLU(),
                )
            hidden_size = 256
        else:
            self.text_embedding_layer = None
            hidden_size = self.transformer.config.hidden_size

        # ---- Previous-sentence labels branch ----
        if previous_sentences:
            # 2 previous sentences × num_labels
            self.prev_label_size = 2 * self.num_labels
            self.prev_label_layer = nn.Sequential(
                nn.Linear(self.prev_label_size, 16),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.prev_label_size = 0
            self.prev_label_layer = None

        # ---- Final classification head ----
        input_dim = hidden_size
        if self.lexicon_layer is not None:
            input_dim += 128
        if self.ling_layer is not None:
            input_dim += 128
        if self.ner_layer is not None:
            input_dim += 128
        if self.topic_layer is not None:
            input_dim += 128
        if self.prev_label_layer is not None:
            input_dim += 16

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classification_head = nn.Linear(input_dim, self.num_labels)

        # Label mappings (already in config, but mirrored here)
        self.id2label = getattr(config, "id2label", None)
        self.label2id = getattr(config, "label2id", None)

        # Initialize weights (will be overwritten by from_pretrained)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        lexicon_features: Optional[torch.Tensor] = None,
        linguistic_features: Optional[torch.Tensor] = None,
        ner_features: Optional[torch.Tensor] = None,
        topic_features: Optional[torch.Tensor] = None,
        prev_label_features: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """
        Forward pass.

        Extra feature tensors (lexicon_features, linguistic_features, etc.)
        are expected to be of shape [batch_size, feat_dim] when used.
        """
        # Ensure integer token IDs
        if input_ids is not None:
            input_ids = input_ids.to(torch.long)

        # ---- Transformer backbone ----
        if inputs_embeds is not None:
            backbone_outputs = self.transformer(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
        else:
            backbone_outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        # CLS representation
        hidden_state = backbone_outputs.last_hidden_state
        cls_embed = hidden_state[:, 0, :]  # [batch, hidden]

        # Optional multilayer / residual processing
        if self.text_embedding_layer is not None:
            text_embeddings = self.text_embedding_layer(cls_embed)
        else:
            text_embeddings = cls_embed

        combined = text_embeddings

        # ---- Lexicon branch ----
        if self.lexicon_layer is not None and lexicon_features is not None:
            lexicon_features = lexicon_features.to(torch.float32)
            lexicon_output = self.lexicon_layer(lexicon_features)
            combined = torch.cat([combined, lexicon_output], dim=-1)

        # ---- Linguistic branch ----
        if self.ling_layer is not None and linguistic_features is not None:
            linguistic_features = linguistic_features.to(combined.device)
            ling_output = self.ling_layer(linguistic_features)
            combined = torch.cat([combined, ling_output], dim=-1)

        # ---- NER branch ----
        if self.ner_layer is not None and ner_features is not None:
            ner_features = ner_features.to(combined.device)
            ner_output = self.ner_layer(ner_features)
            combined = torch.cat([combined, ner_output], dim=-1)

        # ---- Topic branch ----
        if self.topic_layer is not None and topic_features is not None:
            topic_features = topic_features.to(combined.device)
            topic_output = self.topic_layer(topic_features)
            combined = torch.cat([combined, topic_output], dim=-1)

        # ---- Previous-sentence labels branch ----
        if self.prev_label_layer is not None and prev_label_features is not None:
            prev_label_features = prev_label_features.to(combined.device).float()
            prev_output = self.prev_label_layer(prev_label_features)
            combined = torch.cat([combined, prev_output], dim=-1)
        combined = self.dropout(combined)
        logits = self.classification_head(combined)

        loss = None
        if labels is not None:
            labels = labels.float()
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)
            loss = binary_cross_entropy_with_logits(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
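

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition).
# Assumes a freshly built DebertaConfig with randomly initialized weights;
# the label count and batch shapes below are placeholders. With all extra
# feature dims at their zero defaults, the model reduces to
# "DeBERTa + linear multi-label head", as described in the class docstring.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DebertaConfig(num_labels=5)
    model = EnhancedDebertaForSequenceClassification(config)
    model.eval()

    # Dummy batch of token IDs, shape [batch_size, seq_len].
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Multi-label probabilities via an element-wise sigmoid over the logits.
    probs = torch.sigmoid(outputs.logits)
    print(probs.shape)  # torch.Size([2, 5])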